diff --git a/video_compression_model.py b/video_compression_model.py
index 9dbbe9d..129e0d0 100644
--- a/video_compression_model.py
+++ b/video_compression_model.py
@@ -1,5 +1,6 @@
 # video_compression_model.py
 
+import gc
 import os
 import cv2
 import numpy as np
@@ -18,6 +19,7 @@ def combine_batch(frame, crf, speed, include_controls=True, resize=True):
     height, width, _ = processed_frame.shape
 
     combined = [processed_frame]
+
     if include_controls:
         crf_array = np.full((height, width, 1), crf)
         speed_array = np.full((height, width, 1), speed)
@@ -27,56 +29,52 @@ def combine_batch(frame, crf, speed, include_controls=True, resize=True):
 
 
 def data_generator(videos, batch_size):
-    # Infinite loop to keep generating batches
-    while True:
-        # Iterate over each video
-        for video_details in videos:
-            # Get the paths for compressed and original (uncompressed) video files
-            base_dir = os.path.dirname("test_data/validation/validation.json")
-            video_path = os.path.join(base_dir, video_details["compressed_video_file"])
-            uncompressed_video_path = os.path.join(base_dir, video_details["original_video_file"])
-
-            CRF = scale_crf(video_details["crf"])
-            SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(video_details["preset_speed"]))
-
-            # Open the video files
-            cap_compressed = cv2.VideoCapture(video_path)
-            cap_uncompressed = cv2.VideoCapture(uncompressed_video_path)
-
-            # Lists to store the processed frames
-            compressed_frame_batch = []  # Input data (Target)
-            uncompressed_frame_batch = []  # Target data (Training)
+    base_dir = os.path.dirname("test_data/validation/validation.json")
 
-            # Read and process frames from both videos
-            while cap_compressed.isOpened() and cap_uncompressed.isOpened():
+    while True:
+        # Lists to store the processed frames
+        compressed_frame_batch = []  # Input data (Target)
+        uncompressed_frame_batch = []  # Target data (Training)
+
+        # Get a list of video capture objects for all videos
+        caps_compressed = [cv2.VideoCapture(os.path.join(base_dir, video["compressed_video_file"])) for video in videos]
+        caps_uncompressed = [cv2.VideoCapture(os.path.join(base_dir, video["original_video_file"])) for video in videos]
+
+        # As long as any video can provide frames, keep running
+        while any(cap.isOpened() for cap in caps_compressed):
+            for idx, (cap_compressed, cap_uncompressed) in enumerate(zip(caps_compressed, caps_uncompressed)):
+                #print(f"(Video Change) Processing video {idx}")  # Print statement to indicate video change
+
                 ret_compressed, compressed_frame = cap_compressed.read()
                 ret_uncompressed, uncompressed_frame = cap_uncompressed.read()
+
                 if not ret_compressed or not ret_uncompressed:
-                    break
-
-                # Target data
+                    continue
+
+                CRF = scale_crf(videos[idx]["crf"])
+                SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(videos[idx]["preset_speed"]))
+
                 compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
-
-                # Input data
                 uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
-
-                # Append processed frames to batches
+
                 compressed_frame_batch.append(compressed_combined)
                 uncompressed_frame_batch.append(uncompressed_combined)
 
-                # If batch is complete, yield it
                 if len(compressed_frame_batch) == batch_size:
-                    yield (np.array(uncompressed_frame_batch), np.array(compressed_frame_batch))  # Yielding Training and Target data
-                    compressed_frame_batch = []
-                    uncompressed_frame_batch = []
+                    yield (np.array(uncompressed_frame_batch), np.array(compressed_frame_batch))
+                    compressed_frame_batch.clear()
+                    uncompressed_frame_batch.clear()
 
-            # Release video files
-            cap_compressed.release()
-            cap_uncompressed.release()
+        # Close all video captures at the end
+        for cap in caps_compressed + caps_uncompressed:
+            cap.release()
+
+        cv2.destroyAllWindows()
+
+        # If there are frames left that don't fill a whole batch, send them anyway
+        if len(compressed_frame_batch) > 0:
+            yield (np.array(uncompressed_frame_batch), np.array(compressed_frame_batch))
 
-        # If there are frames left that don't fill a whole batch, send them anyway
-        if len(compressed_frame_batch) > 0:
-            yield (np.array(uncompressed_frame_batch), np.array(compressed_frame_batch))
 
 class VideoCompressionModel(tf.keras.Model):
     def __init__(self):
@@ -105,10 +103,5 @@ class VideoCompressionModel(tf.keras.Model):
         ])
 
     def call(self, inputs):
-        #print("Input shape:", inputs.shape)
-        encoded = self.encoder(inputs)
-        #print("Encoded shape:", encoded.shape)
-        decoded = self.decoder(encoded)
-        #print("Decoded shape:", decoded.shape)
-        return decoded
+        return self.decoder(self.encoder(inputs))
 
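
For context, a minimal sketch of how the reworked data_generator and VideoCompressionModel might be driven from a training script; it is not part of this diff. The script name, batch size, step count, and the assumption that the validation JSON is a flat list of video records are hypothetical.

    # train_sketch.py -- hypothetical usage, not part of this diff
    import json

    from video_compression_model import VideoCompressionModel, data_generator

    BATCH_SIZE = 4         # assumed value
    STEPS_PER_EPOCH = 100  # assumed; the generator loops forever, so steps must be capped

    # Assumption: the JSON holds a list of dicts with the keys data_generator reads
    # ("compressed_video_file", "original_video_file", "crf", "preset_speed").
    with open("test_data/validation/validation.json") as f:
        videos = json.load(f)

    model = VideoCompressionModel()
    model.compile(optimizer="adam", loss="mse")

    # data_generator yields (uncompressed_batch, compressed_batch) tuples, so Keras
    # trains the model to map uncompressed frames to their compressed counterparts.
    model.fit(
        data_generator(videos, BATCH_SIZE),
        steps_per_epoch=STEPS_PER_EPOCH,
        epochs=1,
    )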