update good
This commit is contained in:
parent
a8acf220b2
commit
9f34cf8074
1 changed files with 3 additions and 14 deletions
|
@ -39,7 +39,7 @@ DECAY_STEPS = 160
|
|||
DECAY_RATE = 0.9
|
||||
MODEL_SAVE_FILE = "models/model.tf"
|
||||
MODEL_CHECKPOINT_DIR = "checkpoints"
|
||||
EARLY_STOP = 5
|
||||
EARLY_STOP = 10
|
||||
|
||||
class GarbageCollectorCallback(Callback):
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
|
@ -152,21 +152,10 @@ def main():
|
|||
verbose=1,
|
||||
save_format="tf"
|
||||
)
|
||||
early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
|
||||
early_stop = EarlyStopping(monitor='val_psnr', mode='max', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
|
||||
|
||||
# Custom garbage collection callback
|
||||
gc_callback = GarbageCollectorCallback()
|
||||
|
||||
# Calculate steps per epoch for training and validation
|
||||
#if MAX_FRAMES <= 0:
|
||||
# average_frames_per_video = 2880 # Given 2 minutes @ 24 fps
|
||||
#else:
|
||||
# average_frames_per_video = max(MAX_FRAMES, 0)
|
||||
|
||||
#total_frames_train = average_frames_per_video * len(training_videos)
|
||||
#total_frames_validation = average_frames_per_video * len(validation_videos)
|
||||
#steps_per_epoch_train = total_frames_train // BATCH_SIZE
|
||||
#steps_per_epoch_validation = total_frames_validation // BATCH_SIZE
|
||||
|
||||
gc.collect()
|
||||
|
||||
|
@ -175,7 +164,7 @@ def main():
|
|||
model.fit(
|
||||
training_dataset,
|
||||
epochs=EPOCHS,
|
||||
validation_data=validation_dataset, # Add validation data here
|
||||
validation_data=validation_dataset,
|
||||
callbacks=[early_stop, checkpoint_callback, gc_callback]
|
||||
)
|
||||
LOGGER.info("Model training completed.")
|
||||
|
|
Reference in a new issue