update
This commit is contained in:
parent
3ea1568ad3
commit
7787d0584e
4 changed files with 35 additions and 25 deletions
|
@ -32,14 +32,19 @@ from video_compression_model import VideoCompressionModel, data_generator
|
|||
from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
|
||||
|
||||
# Constants
|
||||
BATCH_SIZE = 16
|
||||
BATCH_SIZE = 25
|
||||
EPOCHS = 100
|
||||
LEARNING_RATE = 0.001
|
||||
DECAY_STEPS = 40
|
||||
LEARNING_RATE = 0.0001
|
||||
DECAY_STEPS = 160
|
||||
DECAY_RATE = 0.9
|
||||
MODEL_SAVE_FILE = "models/model.tf"
|
||||
MODEL_CHECKPOINT_DIR = "checkpoints"
|
||||
EARLY_STOP = 5
|
||||
|
||||
class GarbageCollectorCallback(Callback):
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
LOGGER.debug(f"Collecting garbage")
|
||||
gc.collect()
|
||||
|
||||
def save_model(model):
|
||||
try:
|
||||
|
@ -144,6 +149,9 @@ def main():
|
|||
save_format="tf"
|
||||
)
|
||||
early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
|
||||
|
||||
# Custom garbage collection callback
|
||||
gc_callback = GarbageCollectorCallback()
|
||||
|
||||
# Calculate steps per epoch for training and validation
|
||||
if MAX_FRAMES <= 0:
|
||||
|
@ -164,7 +172,7 @@ def main():
|
|||
steps_per_epoch=steps_per_epoch_train,
|
||||
validation_data=data_generator(validation_videos, BATCH_SIZE), # Add validation data here
|
||||
validation_steps=steps_per_epoch_validation, # Add validation steps here
|
||||
callbacks=[early_stop, checkpoint_callback]
|
||||
callbacks=[early_stop, checkpoint_callback, gc_callback]
|
||||
)
|
||||
LOGGER.info("Model training completed.")
|
||||
|
||||
|
|
Reference in a new issue