# train_model.py
"""
Train the video compression model.

TODO:
- Add more different videos with different parameters into the training set.
- Add different scenes with the same parameters
"""
import argparse
import gc
import json
import os

# Must be set before TensorFlow is imported for the first time, so it is done
# ahead of every import that might pull TensorFlow in.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, Callback

# Enable on-demand GPU memory allocation. TensorFlow raises RuntimeError if
# this is attempted after the GPUs have already been initialised.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

# Project imports come after the GPU configuration above so that none of them
# can initialise TensorFlow devices first.
from featureExtraction import psnr
from video_compression_model import VideoCompressionModel, data_generator
from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER

# Constants
BATCH_SIZE = 16
EPOCHS = 100
LEARNING_RATE = 0.001
DECAY_STEPS = 40
DECAY_RATE = 0.9
MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints"
EARLY_STOP = 5
# Frame count assumed per video when MAX_FRAMES is not positive
# (2 minutes @ 24 fps).
DEFAULT_FRAMES_PER_VIDEO = 2880


def save_model(model):
    """
    Save the trained model to MODEL_SAVE_FILE in TensorFlow SavedModel format.

    Args:
    - model: Keras model to save.

    Raises:
    - Exception: Any error raised while saving is logged and re-raised.
    """
    try:
        LOGGER.debug("Attempting to save the model.")
        os.makedirs("models", exist_ok=True)
        model.save(MODEL_SAVE_FILE, save_format='tf')
        LOGGER.info("Model saved successfully!")
    except Exception as e:
        LOGGER.error(f"Error saving the model: {e}")
        raise


def load_video_metadata(list_path):
    """
    Load video metadata from a JSON file.

    Args:
    - list_path (str): Path to the JSON file containing video metadata.

    Returns:
    - The decoded JSON content (main() slices it like a list, one entry
      per video).

    Raises:
    - FileNotFoundError: If list_path does not exist.
    - json.JSONDecodeError: If the file is not valid JSON.
    """
    LOGGER.trace(f"Entering: load_video_metadata({list_path})")
    try:
        # JSON interchange files are UTF-8; do not depend on the locale.
        with open(list_path, "r", encoding="utf-8") as json_file:
            metadata = json.load(json_file)
    except FileNotFoundError:
        LOGGER.error(f"Metadata file {list_path} not found.")
        raise
    except json.JSONDecodeError:
        LOGGER.error(f"Error decoding JSON from {list_path}.")
        raise
    LOGGER.trace(f"load_video_metadata returning: {metadata}")
    return metadata


def _steps_per_epoch(num_videos, frames_per_video, batch_size):
    """Return the number of batches covering the estimated frames (at least 1)."""
    return max(1, (num_videos * frames_per_video) // batch_size)


def main():
    """
    Parse the command line, build (or reload) the model and train it.

    The parsed hyper-parameters are written back to the module-level
    constants, the metadata list is split 80/20 into training and validation
    sets, and the trained model is saved with save_model().

    Raises:
    - ValueError: If the metadata does not contain enough videos to form a
      non-empty training split.
    """
    global BATCH_SIZE, EPOCHS, LEARNING_RATE, MAX_FRAMES, DECAY_STEPS, DECAY_RATE

    # Argument parsing
    parser = argparse.ArgumentParser(description="Train the video compression model.")
    parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE,
                        help='Batch size for training.')
    parser.add_argument('-e', '--epochs', type=int, default=EPOCHS,
                        help='Number of epochs for training.')
    parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE,
                        help='Learning rate for training.')
    parser.add_argument('-c', '--continue_training', type=str, nargs='?',
                        const=MODEL_SAVE_FILE, default=None,
                        help='Path to the saved model to continue training. '
                             'If used without a value, defaults to the MODEL_SAVE_FILE.')
    parser.add_argument('-m', '--max_frames', type=int, default=MAX_FRAMES,
                        help='Maximum number of frames per video.')
    parser.add_argument('-ds', '--decay_steps', type=int, default=DECAY_STEPS,
                        help='Decay steps for the learning rate schedule.')
    parser.add_argument('-dr', '--decay_rate', type=float, default=DECAY_RATE,
                        help='Decay rate for training.')
    args = parser.parse_args()

    # A non-positive batch size would otherwise surface later as a
    # ZeroDivisionError (or nonsense step counts) in the step calculation.
    if args.batch_size <= 0:
        parser.error("--batch_size must be a positive integer.")

    BATCH_SIZE = args.batch_size
    EPOCHS = args.epochs
    LEARNING_RATE = args.learning_rate
    # NOTE(review): this rebinds only this module's MAX_FRAMES. The value is
    # not passed to data_generator below, so confirm the generator actually
    # honours --max_frames.
    MAX_FRAMES = args.max_frames
    DECAY_RATE = args.decay_rate
    DECAY_STEPS = args.decay_steps

    # Display training configuration
    LOGGER.info("Starting the training with the given configuration.")
    LOGGER.info("Training configuration:")
    LOGGER.info(f"Batch size: {BATCH_SIZE}")
    LOGGER.info(f"Epochs: {EPOCHS}")
    LOGGER.info(f"Learning rate: {LEARNING_RATE}")
    LOGGER.info(f"Max Frames: {MAX_FRAMES}")
    # Report the path that will really be loaded (None when starting fresh).
    LOGGER.info(f"Continue training from: {args.continue_training}")
    LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")

    # Load all video metadata
    # NOTE(review): both splits are taken from the validation metadata file;
    # confirm this is the intended data source.
    all_videos = load_video_metadata("test_data/validation/validation.json")

    # Split into training and validation
    split_index = int(0.8 * len(all_videos))
    training_videos = all_videos[:split_index]
    validation_videos = all_videos[split_index:]
    if not training_videos:
        # Happens with fewer than two entries: the 80% split floors to zero.
        raise ValueError(
            f"Not enough videos to train: metadata contains {len(all_videos)} entries."
        )

    if args.continue_training:
        model = tf.keras.models.load_model(args.continue_training)
    else:
        model = VideoCompressionModel()

    # Define exponential decay schedule
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=LEARNING_RATE,
        decay_steps=DECAY_STEPS,
        decay_rate=DECAY_RATE,
        staircase=False
    )

    # Set optimizer and compile the model
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
    model.compile(loss='mse', optimizer=optimizer, metrics=[psnr])

    # Define checkpoints and early stopping
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch-{epoch:02d}.tf"),
        save_weights_only=False,
        save_best_only=False,
        verbose=1,
        save_format="tf"
    )
    early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1,
                               restore_best_weights=True)

    # Calculate steps per epoch for training and validation
    if MAX_FRAMES <= 0:
        average_frames_per_video = DEFAULT_FRAMES_PER_VIDEO
    else:
        average_frames_per_video = MAX_FRAMES
    # Clamp to at least one step so small datasets do not floor to zero.
    steps_per_epoch_train = _steps_per_epoch(
        len(training_videos), average_frames_per_video, BATCH_SIZE)
    steps_per_epoch_validation = _steps_per_epoch(
        len(validation_videos), average_frames_per_video, BATCH_SIZE)

    # Train the model
    LOGGER.info("Starting model training.")
    model.fit(
        data_generator(training_videos, BATCH_SIZE),
        epochs=EPOCHS,
        steps_per_epoch=steps_per_epoch_train,
        validation_data=data_generator(validation_videos, BATCH_SIZE),
        validation_steps=steps_per_epoch_validation,
        callbacks=[early_stop, checkpoint_callback]
    )
    LOGGER.info("Model training completed.")

    save_model(model)


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        LOGGER.error(f"Unexpected error during training: {e}")
        raise