diff --git a/train_model.py b/train_model.py
index 0e6ce95..2494ebc 100644
--- a/train_model.py
+++ b/train_model.py
@@ -27,7 +27,7 @@ if gpus:
         print(e)
 
-from video_compression_model import VideoCompressionModel, data_generator
+from video_compression_model import VideoCompressionModel, create_dataset
 
 from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
 
@@ -43,7 +43,7 @@ EARLY_STOP = 5
 class GarbageCollectorCallback(Callback):
     def on_epoch_end(self, epoch, logs=None):
-        LOGGER.debug(f"Collecting garbage")
+        LOGGER.debug(f"GC")
         gc.collect()
 
 def save_model(model):
@@ -120,6 +120,10 @@ def main():
     split_index = int(0.8 * len(all_videos))
     training_videos = all_videos[:split_index]
     validation_videos = all_videos[split_index:]
+
+    training_dataset = create_dataset(training_videos, BATCH_SIZE, MAX_FRAMES)
+    validation_dataset = create_dataset(validation_videos, BATCH_SIZE, MAX_FRAMES)
+
 
     if args.continue_training:
         model = tf.keras.models.load_model(args.continue_training)
@@ -154,26 +158,24 @@ def main():
     gc_callback = GarbageCollectorCallback()
 
     # Calculate steps per epoch for training and validation
-    if MAX_FRAMES <= 0:
-        average_frames_per_video = 2880  # Given 2 minutes @ 24 fps
-    else:
-        average_frames_per_video = max(MAX_FRAMES, 0)
+    #if MAX_FRAMES <= 0:
+    #    average_frames_per_video = 2880  # Given 2 minutes @ 24 fps
+    #else:
+    #    average_frames_per_video = max(MAX_FRAMES, 0)
 
-    total_frames_train = average_frames_per_video * len(training_videos)
-    total_frames_validation = average_frames_per_video * len(validation_videos)
-    steps_per_epoch_train = total_frames_train // BATCH_SIZE
-    steps_per_epoch_validation = total_frames_validation // BATCH_SIZE
+    #total_frames_train = average_frames_per_video * len(training_videos)
+    #total_frames_validation = average_frames_per_video * len(validation_videos)
+    #steps_per_epoch_train = total_frames_train // BATCH_SIZE
+    #steps_per_epoch_validation = total_frames_validation // BATCH_SIZE
 
     gc.collect()
 
     # Train the model
     LOGGER.info("Starting model training.")
     model.fit(
-        data_generator(training_videos, BATCH_SIZE),
-        epochs=EPOCHS,
-        steps_per_epoch=steps_per_epoch_train,
-        validation_data=data_generator(validation_videos, BATCH_SIZE),  # Add validation data here
-        validation_steps=steps_per_epoch_validation,  # Add validation steps here
+        training_dataset,
+        epochs=EPOCHS,
+        validation_data=validation_dataset,  # Add validation data here
         callbacks=[early_stop, checkpoint_callback, gc_callback]
     )
     LOGGER.info("Model training completed.")
diff --git a/video_compression_model.py b/video_compression_model.py
index 129e0d0..8a82772 100644
--- a/video_compression_model.py
+++ b/video_compression_model.py
@@ -28,52 +28,86 @@ def combine_batch(frame, crf, speed, include_controls=True, resize=True):
     return np.concatenate(combined, axis=-1)
 
-def data_generator(videos, batch_size):
+def process_video(video):
     base_dir = os.path.dirname("test_data/validation/validation.json")
+
+    cap_compressed = cv2.VideoCapture(os.path.join(base_dir, video["compressed_video_file"]))
+    cap_uncompressed = cv2.VideoCapture(os.path.join(base_dir, video["original_video_file"]))
+
+    compressed_frames = []
+    uncompressed_frames = []
 
     while True:
-        # Lists to store the processed frames
-        compressed_frame_batch = []  # Input data (Target)
-        uncompressed_frame_batch = []  # Target data (Training)
+        ret_compressed, compressed_frame = cap_compressed.read()
+        ret_uncompressed, uncompressed_frame = cap_uncompressed.read()
 
-        # Get a list of video capture objects for all videos
-        caps_compressed = [cv2.VideoCapture(os.path.join(base_dir, video["compressed_video_file"])) for video in videos]
-        caps_uncompressed = [cv2.VideoCapture(os.path.join(base_dir, video["original_video_file"])) for video in videos]
+        if not ret_compressed or not ret_uncompressed:
+            break
 
-        # As long as any video can provide frames, keep running
-        while any(cap.isOpened() for cap in caps_compressed):
-            for idx, (cap_compressed, cap_uncompressed) in enumerate(zip(caps_compressed, caps_uncompressed)):
-                #print(f"(Video Change) Processing video {idx}")  # Print statement to indicate video change
-
-                ret_compressed, compressed_frame = cap_compressed.read()
-                ret_uncompressed, uncompressed_frame = cap_uncompressed.read()
+        CRF = scale_crf(video["crf"])
+        SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(video["preset_speed"]))
 
-                if not ret_compressed or not ret_uncompressed:
-                    continue
+        compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
+        uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
 
-                CRF = scale_crf(videos[idx]["crf"])
-                SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(videos[idx]["preset_speed"]))
+        compressed_frames.append(compressed_combined)
+        uncompressed_frames.append(uncompressed_combined)
 
-                compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
-                uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
+    cap_compressed.release()
+    cap_uncompressed.release()
 
-                compressed_frame_batch.append(compressed_combined)
-                uncompressed_frame_batch.append(uncompressed_combined)
+    return uncompressed_frames, compressed_frames
 
-                if len(compressed_frame_batch) == batch_size:
-                    yield (np.array(uncompressed_frame_batch), np.array(compressed_frame_batch))
-                    compressed_frame_batch.clear()
-                    uncompressed_frame_batch.clear()
 
-        # Close all video captures at the end
-        for cap in caps_compressed + caps_uncompressed:
-            cap.release()
+def frame_generator(videos, max_frames=None):
+    base_dir = "test_data/validation/"
+    for video in videos:
+        cap_compressed = cv2.VideoCapture(os.path.join(base_dir, video["compressed_video_file"]))
+        cap_uncompressed = cv2.VideoCapture(os.path.join(base_dir, video["original_video_file"]))
 
-        cv2.destroyAllWindows()
+        frame_count = 0
+        while True:
+            ret_compressed, compressed_frame = cap_compressed.read()
+            ret_uncompressed, uncompressed_frame = cap_uncompressed.read()
+
+            if not ret_compressed or not ret_uncompressed:
+                break
+
+            CRF = scale_crf(video["crf"])
+            SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(video["preset_speed"]))
+
+            compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
+            uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
+
+            yield uncompressed_combined, compressed_combined
+
+            frame_count += 1
+            if max_frames is not None and frame_count >= max_frames:
+                break
+
+        cap_compressed.release()
+        cap_uncompressed.release()
+
+
+def create_dataset(videos, batch_size, max_frames=None):
+    # Determine the output signature from one sample frame pair pulled off the generator
+    video_generator_instance = frame_generator(videos, max_frames)
+    sample_uncompressed, sample_compressed = next(video_generator_instance)
+    output_signature = (
+        tf.TensorSpec(shape=sample_uncompressed.shape, dtype=tf.float32),
+        tf.TensorSpec(shape=sample_compressed.shape, dtype=tf.float32)
+    )
+
+    dataset = tf.data.Dataset.from_generator(
+        lambda: frame_generator(videos, max_frames),  # Include max_frames argument through lambda
+        output_signature=output_signature
+    )
+
+    dataset = dataset.shuffle(100).batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
+
+    return dataset
 
-    # If there are frames left that don't fill a whole batch, send them anyway
-    if len(compressed_frame_batch) > 0:
-        yield (np.array(uncompressed_frame_batch), np.array(compressed_frame_batch))
 
 class VideoCompressionModel(tf.keras.Model):
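
Usage note (not part of the patch): a minimal sketch of how the new create_dataset pipeline would be consumed. The video entry below is a hypothetical placeholder; in train_model.py the real entries come from test_data/validation/validation.json.

    # Hedged sketch: exercise create_dataset with one illustrative video entry.
    from video_compression_model import create_dataset

    videos = [
        {
            "compressed_video_file": "clip0_crf23.mkv",   # hypothetical filename
            "original_video_file": "clip0_original.mkv",  # hypothetical filename
            "crf": 23,
            "preset_speed": "medium",
        },
    ]

    # Batches are (uncompressed, compressed) frame pairs, in the tuple order
    # yielded by frame_generator.
    dataset = create_dataset(videos, batch_size=4, max_frames=8)
    for uncompressed_batch, compressed_batch in dataset.take(1):
        print(uncompressed_batch.shape, compressed_batch.shape)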