update good

This commit is contained in:
Jordon Brooks 2023-08-18 01:57:24 +01:00
parent a8acf220b2
commit 9f34cf8074

View file

@ -39,7 +39,7 @@ DECAY_STEPS = 160
DECAY_RATE = 0.9 DECAY_RATE = 0.9
MODEL_SAVE_FILE = "models/model.tf" MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints" MODEL_CHECKPOINT_DIR = "checkpoints"
EARLY_STOP = 5 EARLY_STOP = 10
class GarbageCollectorCallback(Callback): class GarbageCollectorCallback(Callback):
def on_epoch_end(self, epoch, logs=None): def on_epoch_end(self, epoch, logs=None):
@ -152,21 +152,10 @@ def main():
verbose=1, verbose=1,
save_format="tf" save_format="tf"
) )
early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True) early_stop = EarlyStopping(monitor='val_psnr', mode='max', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
# Custom garbage collection callback # Custom garbage collection callback
gc_callback = GarbageCollectorCallback() gc_callback = GarbageCollectorCallback()
# Calculate steps per epoch for training and validation
#if MAX_FRAMES <= 0:
# average_frames_per_video = 2880 # Given 2 minutes @ 24 fps
#else:
# average_frames_per_video = max(MAX_FRAMES, 0)
#total_frames_train = average_frames_per_video * len(training_videos)
#total_frames_validation = average_frames_per_video * len(validation_videos)
#steps_per_epoch_train = total_frames_train // BATCH_SIZE
#steps_per_epoch_validation = total_frames_validation // BATCH_SIZE
gc.collect() gc.collect()
@ -175,7 +164,7 @@ def main():
model.fit( model.fit(
training_dataset, training_dataset,
epochs=EPOCHS, epochs=EPOCHS,
validation_data=validation_dataset, # Add validation data here validation_data=validation_dataset,
callbacks=[early_stop, checkpoint_callback, gc_callback] callbacks=[early_stop, checkpoint_callback, gc_callback]
) )
LOGGER.info("Model training completed.") LOGGER.info("Model training completed.")