From 491aaf402ed64423e8e61becc690dcf18e2d0ac4 Mon Sep 17 00:00:00 2001
From: Jordon Brooks
Date: Sun, 13 Aug 2023 15:56:51 +0100
Subject: [PATCH] Max Frames now works

---
 train_model.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/train_model.py b/train_model.py
index bddebfd..e08b268 100644
--- a/train_model.py
+++ b/train_model.py
@@ -11,12 +11,12 @@ from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
 
 from video_compression_model import VideoCompressionModel, data_generator
 
-from globalVars import HEIGHT, WIDTH, LOGGER
+from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
 
 # Constants
 BATCH_SIZE = 16
-EPOCHS = 5
-LEARNING_RATE = 0.01
+EPOCHS = 100
+LEARNING_RATE = 0.001
 MODEL_SAVE_FILE = "models/model.tf"
 MODEL_CHECKPOINT_DIR = "checkpoints"
 EARLY_STOP = 10
@@ -58,18 +58,20 @@ def load_video_metadata(list_path):
 
 
 def main():
-    global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE
+    global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES
     # Argument parsing
     parser = argparse.ArgumentParser(description="Train the video compression model.")
     parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
     parser.add_argument('-e', '--epochs', type=int, default=EPOCHS, help='Number of epochs for training.')
     parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
     parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
+    parser.add_argument('-m', '--max_frames', type=int, default=MAX_FRAMES, help='Maximum number of frames to use per video.')
     args = parser.parse_args()
 
     BATCH_SIZE = args.batch_size
     EPOCHS = args.epochs
     LEARNING_RATE = args.learning_rate
+    MAX_FRAMES = args.max_frames
 
     # Display training configuration
     LOGGER.info("Starting the training with the given configuration.")
@@ -77,10 +79,10 @@ def main():
     LOGGER.info(f"Batch size: {BATCH_SIZE}")
     LOGGER.info(f"Epochs: {EPOCHS}")
     LOGGER.info(f"Learning rate: {LEARNING_RATE}")
+    LOGGER.info(f"Max Frames: {MAX_FRAMES}")
     LOGGER.info(f"Continue training from: {MODEL_SAVE_FILE}")
     LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")
 
-    LOGGER.trace("Hello, World!")
 
     # Load all video metadata
     all_videos = load_video_metadata("test_data/validation/validation.json")
@@ -111,7 +113,11 @@ def main():
     early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
 
     # Calculate steps per epoch for training and validation
-    average_frames_per_video = 2880 # Given 2 minutes @ 24 fps
+    if MAX_FRAMES <= 0:
+        average_frames_per_video = 2880 # Given 2 minutes @ 24 fps
+    else:
+        average_frames_per_video = max(MAX_FRAMES, 0)
+
     total_frames_train = average_frames_per_video * len(training_videos)
     total_frames_validation = average_frames_per_video * len(validation_videos)
     steps_per_epoch_train = total_frames_train // BATCH_SIZE
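
For reference, an invocation exercising the options this patch defines might look like the following (flag names come from the argparse calls above; the numeric values are illustrative only, not taken from the commit):

    python train_model.py -b 16 -e 100 -l 0.001 -m 500

Omitting -m leaves MAX_FRAMES at its default imported from globalVars, and per the last hunk a max_frames value of 0 or below falls back to the 2880-frame estimate (2 minutes at 24 fps) when computing steps per epoch.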