update good
This commit is contained in:
parent
a8acf220b2
commit
9f34cf8074
1 changed files with 3 additions and 14 deletions
|
@ -39,7 +39,7 @@ DECAY_STEPS = 160
|
||||||
DECAY_RATE = 0.9
|
DECAY_RATE = 0.9
|
||||||
MODEL_SAVE_FILE = "models/model.tf"
|
MODEL_SAVE_FILE = "models/model.tf"
|
||||||
MODEL_CHECKPOINT_DIR = "checkpoints"
|
MODEL_CHECKPOINT_DIR = "checkpoints"
|
||||||
EARLY_STOP = 5
|
EARLY_STOP = 10
|
||||||
|
|
||||||
class GarbageCollectorCallback(Callback):
|
class GarbageCollectorCallback(Callback):
|
||||||
def on_epoch_end(self, epoch, logs=None):
|
def on_epoch_end(self, epoch, logs=None):
|
||||||
|
@ -152,21 +152,10 @@ def main():
|
||||||
verbose=1,
|
verbose=1,
|
||||||
save_format="tf"
|
save_format="tf"
|
||||||
)
|
)
|
||||||
early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
|
early_stop = EarlyStopping(monitor='val_psnr', mode='max', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
|
||||||
|
|
||||||
# Custom garbage collection callback
|
# Custom garbage collection callback
|
||||||
gc_callback = GarbageCollectorCallback()
|
gc_callback = GarbageCollectorCallback()
|
||||||
|
|
||||||
# Calculate steps per epoch for training and validation
|
|
||||||
#if MAX_FRAMES <= 0:
|
|
||||||
# average_frames_per_video = 2880 # Given 2 minutes @ 24 fps
|
|
||||||
#else:
|
|
||||||
# average_frames_per_video = max(MAX_FRAMES, 0)
|
|
||||||
|
|
||||||
#total_frames_train = average_frames_per_video * len(training_videos)
|
|
||||||
#total_frames_validation = average_frames_per_video * len(validation_videos)
|
|
||||||
#steps_per_epoch_train = total_frames_train // BATCH_SIZE
|
|
||||||
#steps_per_epoch_validation = total_frames_validation // BATCH_SIZE
|
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
@ -175,7 +164,7 @@ def main():
|
||||||
model.fit(
|
model.fit(
|
||||||
training_dataset,
|
training_dataset,
|
||||||
epochs=EPOCHS,
|
epochs=EPOCHS,
|
||||||
validation_data=validation_dataset, # Add validation data here
|
validation_data=validation_dataset,
|
||||||
callbacks=[early_stop, checkpoint_callback, gc_callback]
|
callbacks=[early_stop, checkpoint_callback, gc_callback]
|
||||||
)
|
)
|
||||||
LOGGER.info("Model training completed.")
|
LOGGER.info("Model training completed.")
|
||||||
|
|
Reference in a new issue