diff --git a/train_model.py b/train_model.py index 74aaae6..bddebfd 100644 --- a/train_model.py +++ b/train_model.py @@ -90,7 +90,11 @@ def main(): training_videos = all_videos[:split_index] validation_videos = all_videos[split_index:] - model = VideoCompressionModel() + if args.continue_training: + model = tf.keras.models.load_model(args.continue_training) + else: + model = VideoCompressionModel() + # Set optimizer and compile the model optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)