# train_model.py
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

import json
import argparse

import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

from video_compression_model import NUM_CHANNELS, VideoCompressionModel, PRESET_SPEED_CATEGORIES, VideoDataGenerator
from global_train import LOGGER

# Constants
BATCH_SIZE = 4
EPOCHS = 100
LEARNING_RATE = 0.000001
TRAIN_SAMPLES = 100
MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints"
EARLY_STOP = 10
WIDTH = 638
HEIGHT = 360


def load_video_metadata(list_path):
    """
    Load the metadata JSON file describing the training videos.

    Args:
    - list_path (str): Path to the metadata JSON file.

    Returns:
    - list: Parsed metadata entries.
    """
    LOGGER.trace(f"Entering: load_video_metadata({list_path})")
    try:
        with open(list_path, "r") as json_file:
            data = json.load(json_file)
        LOGGER.trace(f"load_video_metadata returning: {data}")
        return data
    except FileNotFoundError:
        LOGGER.error(f"Metadata file {list_path} not found.")
        raise
    except json.JSONDecodeError:
        LOGGER.error(f"Error decoding JSON from {list_path}.")
        raise


def load_video_samples(list_path, samples=TRAIN_SAMPLES):
    """
    Load video samples from the metadata list.

    Args:
    - list_path (str): Path to the metadata JSON file.
    - samples (int): Total number of samples to extract across all videos.

    Returns:
    - list: Extracted video samples.
    """
    LOGGER.trace(f"Entering: load_video_samples({list_path}, {samples})")
    details_list = load_video_metadata(list_path)
    all_samples = []
    num_videos = len(details_list)
    frames_per_video = samples // num_videos
    LOGGER.info(f"Loading {frames_per_video} frames from each of {num_videos} videos")

    for video_details in details_list:
        video_file = video_details["video_file"]
        uncompressed_video_file = video_details["uncompressed_video_file"]
        crf = video_details['crf'] / 63.0  # Scale CRF to [0, 1]
        preset_speed = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed'])
        video_details['preset_speed'] = preset_speed

        compressed_frames, uncompressed_frames = [], []
        base_dir = os.path.dirname(list_path)
        cap = cv2.VideoCapture(os.path.join(base_dir, video_file))
        cap_uncompressed = cv2.VideoCapture(os.path.join(base_dir, uncompressed_video_file))
        try:
            if not cap.isOpened() or not cap_uncompressed.isOpened():
                raise RuntimeError(f"Could not open video files {video_file} or {uncompressed_video_file}, searched under: {base_dir}")

            for _ in range(frames_per_video):
                ret, frame_compressed = cap.read()
                ret_uncompressed, frame = cap_uncompressed.read()
                if not ret or not ret_uncompressed:
                    continue

                # Check frame dimensions and resize if necessary.
                # frame.shape[:2] is (height, width), so compare against (HEIGHT, WIDTH).
                if frame.shape[:2] != (HEIGHT, WIDTH):
                    LOGGER.warning(f"Resizing video: {uncompressed_video_file}")
                    frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA)
                if frame_compressed.shape[:2] != (HEIGHT, WIDTH):
                    LOGGER.warning(f"Resizing video: {video_file}")
                    frame_compressed = cv2.resize(frame_compressed, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA)

                # OpenCV decodes frames as BGR; convert to RGB before training.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame_compressed = cv2.cvtColor(frame_compressed, cv2.COLOR_BGR2RGB)

                uncompressed_frames.append(normalize(frame))
                compressed_frames.append(normalize(frame_compressed))

            all_samples.extend({
                "frame": frame,
                "compressed_frame": frame_compressed,
                "crf": crf,
                "preset_speed": preset_speed,
                "video_file": video_file
            } for frame, frame_compressed in zip(uncompressed_frames, compressed_frames))
        except Exception as e:
            LOGGER.error(f"Error during video sample loading: {e}")
            raise
        finally:
            cap.release()
            cap_uncompressed.release()

    return all_samples
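
# Illustrative metadata layout consumed by load_video_samples above. The field
# names come straight from the loader; the file names and values are hypothetical.
# Dividing crf by 63.0 assumes the 0-63 CRF range (as used by VP9/AV1), and
# preset_speed must match an entry in PRESET_SPEED_CATEGORIES.
#
# [
#   {
#     "video_file": "clip0_compressed.mkv",
#     "uncompressed_video_file": "clip0_raw.mkv",
#     "crf": 30,
#     "preset_speed": "medium"
#   }
# ]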

def normalize(frame):
    """
    Normalize pixel values of the frame to the range [0, 1].

    Args:
    - frame (ndarray): Image frame.

    Returns:
    - ndarray: Normalized frame.
    """
    LOGGER.trace("Normalizing frame")
    return frame / 255.0


def save_model(model):
    try:
        LOGGER.debug("Attempting to save the model.")
        os.makedirs("models", exist_ok=True)
        model.save(MODEL_SAVE_FILE, save_format='tf')
        LOGGER.info("Model saved successfully!")
    except Exception as e:
        LOGGER.error(f"Error saving the model: {e}")
        raise


def main():
    # Argument parsing
    parser = argparse.ArgumentParser(description="Train the video compression model.")
    parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
    parser.add_argument('-e', '--epochs', type=int, default=EPOCHS, help='Number of epochs for training.')
    parser.add_argument('-s', '--training_samples', type=int, default=TRAIN_SAMPLES, help='Number of training samples.')
    parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
    parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None,
                        help='Path to the saved model to continue training. If used without a value, defaults to MODEL_SAVE_FILE.')
    args = parser.parse_args()

    # Display training configuration
    LOGGER.info("Starting the training with the given configuration.")
    LOGGER.info("Training configuration:")
    LOGGER.info(f"Batch size: {args.batch_size}")
    LOGGER.info(f"Epochs: {args.epochs}")
    LOGGER.info(f"Training samples: {args.training_samples}")
    LOGGER.info(f"Learning rate: {args.learning_rate}")
    LOGGER.info(f"Continue training from: {args.continue_training}")

    # Load training and validation samples
    LOGGER.debug("Loading training and validation samples.")
    training_samples = load_video_samples("test_data/training/training.json", args.training_samples)
    validation_samples = load_video_samples("test_data/validation/validation.json", args.training_samples // 2)

    train_generator = VideoDataGenerator(training_samples, args.batch_size)
    val_generator = VideoDataGenerator(validation_samples, args.batch_size)

    # Load or initialize model
    if args.continue_training:
        model = tf.keras.models.load_model(args.continue_training)
    else:
        model = VideoCompressionModel()

    # Set optimizer and compile the model
    optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
    model.compile(loss='mean_squared_error', optimizer=optimizer)

    # Define checkpoints and early stopping.
    # ModelCheckpoint takes no save_format argument; the format is inferred from the filepath.
    os.makedirs(MODEL_CHECKPOINT_DIR, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch-{epoch:02d}.tf"),
        save_weights_only=False,
        save_best_only=False,
        verbose=1
    )
    early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True)

    # Train the model
    LOGGER.info("Starting model training.")
    model.fit(
        train_generator,
        steps_per_epoch=len(train_generator),
        epochs=args.epochs,
        validation_data=val_generator,
        validation_steps=len(val_generator),
        callbacks=[early_stop, checkpoint_callback]
    )
    LOGGER.info("Model training completed.")

    save_model(model)


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        LOGGER.error(f"Unexpected error during training: {e}")
        raise
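
# Example invocations, built from the argparse flags defined in main().
# Paths and values are illustrative, not prescriptive:
#   python train_model.py -b 8 -e 50 -l 1e-5 -s 200
#   python train_model.py --continue_training             # resumes from MODEL_SAVE_FILE
#   python train_model.py -c checkpoints/epoch-05.tf      # resumes from a specific checkpoint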