This commit is contained in:
Jordon Brooks 2023-08-17 01:57:53 +01:00
parent 3ea1568ad3
commit 7787d0584e
4 changed files with 35 additions and 25 deletions

View file

@ -32,14 +32,19 @@ from video_compression_model import VideoCompressionModel, data_generator
from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
# Constants
BATCH_SIZE = 16
BATCH_SIZE = 25
EPOCHS = 100
LEARNING_RATE = 0.001
DECAY_STEPS = 40
LEARNING_RATE = 0.0001
DECAY_STEPS = 160
DECAY_RATE = 0.9
MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints"
EARLY_STOP = 5
class GarbageCollectorCallback(Callback):
def on_epoch_end(self, epoch, logs=None):
LOGGER.debug(f"Collecting garbage")
gc.collect()
def save_model(model):
try:
@ -144,6 +149,9 @@ def main():
save_format="tf"
)
early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
# Custom garbage collection callback
gc_callback = GarbageCollectorCallback()
# Calculate steps per epoch for training and validation
if MAX_FRAMES <= 0:
@ -164,7 +172,7 @@ def main():
steps_per_epoch=steps_per_epoch_train,
validation_data=data_generator(validation_videos, BATCH_SIZE), # Add validation data here
validation_steps=steps_per_epoch_validation, # Add validation steps here
callbacks=[early_stop, checkpoint_callback]
callbacks=[early_stop, checkpoint_callback, gc_callback]
)
LOGGER.info("Model training completed.")