Removed saved checkpoints

This commit is contained in:
Jordon Brooks 2023-07-24 16:49:36 +01:00
parent c7306a9d48
commit 58fcf819ee

View file

@ -8,7 +8,6 @@ from video_compression_model import VideoCompressionModel
NUM_CHANNELS = 3 # Number of color channels in the video frames (RGB images have 3 channels)
BATCH_SIZE = 32 # Batch size used during training
EPOCHS = 20 # Number of training epochs
CHECKPOINT_FILEPATH = "models/checkpoint-{epoch:02d}.keras"
# Step 1: Data Preparation
TRAIN_VIDEO_FILE = 'native_video.mkv' # The training video file name
@ -68,21 +67,12 @@ val_targets = val_frames
# Create the "models" directory if it doesn't exist
os.makedirs("models", exist_ok=True)
# Create the ModelCheckpoint callback
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=CHECKPOINT_FILEPATH,
save_weights_only=False, # Save the entire model (including architecture)
monitor='val_loss', # Metric to monitor for saving the best model (optional)
save_best_only=True # Save only the best model based on the monitored metric (optional)
)
print("\nTraining the model...")
model.fit(
train_frames, [train_targets, tf.zeros_like(train_targets)],
batch_size=BATCH_SIZE,
epochs=EPOCHS,
validation_data=(val_frames, [val_targets, tf.zeros_like(val_targets)]),
callbacks=[model_checkpoint_callback] # Add the ModelCheckpoint callback
validation_data=(val_frames, [val_targets, tf.zeros_like(val_targets)])
)
print("\nTraining completed.")