model now uses tensorflow dataset generator

This commit is contained in:
Jordon Brooks 2023-08-18 00:42:17 +01:00
parent ba6c132c67
commit f06d3ae504
2 changed files with 84 additions and 48 deletions

View file

@ -27,7 +27,7 @@ if gpus:
print(e)
from video_compression_model import VideoCompressionModel, data_generator
from video_compression_model import VideoCompressionModel, create_dataset
from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
@ -43,7 +43,7 @@ EARLY_STOP = 5
class GarbageCollectorCallback(Callback):
def on_epoch_end(self, epoch, logs=None):
LOGGER.debug(f"Collecting garbage")
LOGGER.debug(f"GC")
gc.collect()
def save_model(model):
@ -120,6 +120,10 @@ def main():
split_index = int(0.8 * len(all_videos))
training_videos = all_videos[:split_index]
validation_videos = all_videos[split_index:]
training_dataset = create_dataset(training_videos, BATCH_SIZE, MAX_FRAMES)
validation_dataset = create_dataset(validation_videos, BATCH_SIZE, MAX_FRAMES)
if args.continue_training:
model = tf.keras.models.load_model(args.continue_training)
@ -154,26 +158,24 @@ def main():
gc_callback = GarbageCollectorCallback()
# Calculate steps per epoch for training and validation
if MAX_FRAMES <= 0:
average_frames_per_video = 2880 # Given 2 minutes @ 24 fps
else:
average_frames_per_video = max(MAX_FRAMES, 0)
#if MAX_FRAMES <= 0:
# average_frames_per_video = 2880 # Given 2 minutes @ 24 fps
#else:
# average_frames_per_video = max(MAX_FRAMES, 0)
total_frames_train = average_frames_per_video * len(training_videos)
total_frames_validation = average_frames_per_video * len(validation_videos)
steps_per_epoch_train = total_frames_train // BATCH_SIZE
steps_per_epoch_validation = total_frames_validation // BATCH_SIZE
#total_frames_train = average_frames_per_video * len(training_videos)
#total_frames_validation = average_frames_per_video * len(validation_videos)
#steps_per_epoch_train = total_frames_train // BATCH_SIZE
#steps_per_epoch_validation = total_frames_validation // BATCH_SIZE
gc.collect()
# Train the model
LOGGER.info("Starting model training.")
model.fit(
data_generator(training_videos, BATCH_SIZE),
epochs=EPOCHS,
steps_per_epoch=steps_per_epoch_train,
validation_data=data_generator(validation_videos, BATCH_SIZE), # Add validation data here
validation_steps=steps_per_epoch_validation, # Add validation steps here
training_dataset,
epochs=EPOCHS,
validation_data=validation_dataset, # Add validation data here
callbacks=[early_stop, checkpoint_callback, gc_callback]
)
LOGGER.info("Model training completed.")