Removed saved checkpoints
This commit is contained in:
parent
c7306a9d48
commit
58fcf819ee
1 changed files with 1 additions and 11 deletions
|
@ -8,7 +8,6 @@ from video_compression_model import VideoCompressionModel
|
|||
NUM_CHANNELS = 3 # Number of color channels in the video frames (RGB images have 3 channels)
|
||||
BATCH_SIZE = 32 # Batch size used during training
|
||||
EPOCHS = 20 # Number of training epochs
|
||||
CHECKPOINT_FILEPATH = "models/checkpoint-{epoch:02d}.keras"
|
||||
|
||||
# Step 1: Data Preparation
|
||||
TRAIN_VIDEO_FILE = 'native_video.mkv' # The training video file name
|
||||
|
@ -68,21 +67,12 @@ val_targets = val_frames
|
|||
# Create the "models" directory if it doesn't exist
|
||||
os.makedirs("models", exist_ok=True)
|
||||
|
||||
# Create the ModelCheckpoint callback
|
||||
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||
filepath=CHECKPOINT_FILEPATH,
|
||||
save_weights_only=False, # Save the entire model (including architecture)
|
||||
monitor='val_loss', # Metric to monitor for saving the best model (optional)
|
||||
save_best_only=True # Save only the best model based on the monitored metric (optional)
|
||||
)
|
||||
|
||||
print("\nTraining the model...")
|
||||
model.fit(
|
||||
train_frames, [train_targets, tf.zeros_like(train_targets)],
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS,
|
||||
validation_data=(val_frames, [val_targets, tf.zeros_like(val_targets)]),
|
||||
callbacks=[model_checkpoint_callback] # Add the ModelCheckpoint callback
|
||||
validation_data=(val_frames, [val_targets, tf.zeros_like(val_targets)])
|
||||
)
|
||||
print("\nTraining completed.")
|
||||
|
||||
|
|
Reference in a new issue