diff --git a/train_model.py b/train_model.py index 2494ebc..74d34fc 100644 --- a/train_model.py +++ b/train_model.py @@ -39,7 +39,7 @@ DECAY_STEPS = 160 DECAY_RATE = 0.9 MODEL_SAVE_FILE = "models/model.tf" MODEL_CHECKPOINT_DIR = "checkpoints" -EARLY_STOP = 5 +EARLY_STOP = 10 class GarbageCollectorCallback(Callback): def on_epoch_end(self, epoch, logs=None): @@ -152,21 +152,10 @@ def main(): verbose=1, save_format="tf" ) - early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True) + early_stop = EarlyStopping(monitor='val_psnr', mode='max', patience=EARLY_STOP, verbose=1, restore_best_weights=True) # Custom garbage collection callback gc_callback = GarbageCollectorCallback() - - # Calculate steps per epoch for training and validation - #if MAX_FRAMES <= 0: - # average_frames_per_video = 2880 # Given 2 minutes @ 24 fps - #else: - # average_frames_per_video = max(MAX_FRAMES, 0) - - #total_frames_train = average_frames_per_video * len(training_videos) - #total_frames_validation = average_frames_per_video * len(validation_videos) - #steps_per_epoch_train = total_frames_train // BATCH_SIZE - #steps_per_epoch_validation = total_frames_validation // BATCH_SIZE gc.collect() @@ -175,7 +164,7 @@ def main(): model.fit( training_dataset, epochs=EPOCHS, - validation_data=validation_dataset, # Add validation data here + validation_data=validation_dataset, callbacks=[early_stop, checkpoint_callback, gc_callback] ) LOGGER.info("Model training completed.")