# video_compression_model.py

import os

import cv2
import numpy as np
import tensorflow as tf

from featureExtraction import preprocess_frame, scale_crf, scale_speed_preset
from globalVars import HEIGHT, LOGGER, NUM_COLOUR_CHANNELS, NUM_PRESET_SPEEDS, PRESET_SPEED_CATEGORIES, WIDTH

# Optional mixed-precision training; uncomment to enable.
# from tensorflow.keras.mixed_precision import Policy
# policy = Policy('mixed_float16')
# tf.keras.mixed_precision.set_global_policy(policy)

# Directory that the "*_video_file" entries of the video metadata are resolved
# against when data_generator() is not given an explicit base_dir
# (evaluates to "test_data/validation").
DEFAULT_VIDEO_BASE_DIR = os.path.dirname("test_data/validation/validation.json")


def combine_batch(frame, crf, speed, include_controls=True, resize=True):
    """Preprocess one frame and optionally append constant CRF/speed planes.

    Args:
        frame: Raw frame, as returned by ``cv2.VideoCapture.read``.
        crf: Value broadcast over the whole CRF control plane.
        speed: Value broadcast over the whole speed-preset control plane.
        include_controls: When True, two extra channels (CRF, then speed) are
            appended after the frame's own channels.
        resize: Forwarded unchanged to ``preprocess_frame``.

    Returns:
        The preprocessed frame as a 3-D array (height, width, channels), with
        two more channels when ``include_controls`` is True.
    """
    processed_frame = preprocess_frame(frame, resize)
    height, width, _ = processed_frame.shape
    combined = [processed_frame]
    if include_controls:
        # NOTE(review): np.full infers the plane dtype from the scalar, so an
        # int ``crf`` (e.g. 0) yields an integer plane and np.concatenate may
        # promote the whole result (e.g. float32 -> float64) — confirm intended.
        crf_array = np.full((height, width, 1), crf)
        speed_array = np.full((height, width, 1), speed)
        combined.extend([crf_array, speed_array])
    return np.concatenate(combined, axis=-1)


def _video_batches(video_details, batch_size, base_dir, reference_speed):
    """Yield (inputs, targets) batches for one video pair; return the batch count.

    Both captures are released in a ``finally`` so they are not leaked when
    the consumer stops early or an exception is raised mid-video.
    """
    video_path = os.path.join(base_dir, video_details["compressed_video_file"])
    uncompressed_video_path = os.path.join(base_dir, video_details["original_video_file"])
    crf = scale_crf(video_details["crf"])
    speed = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(video_details["preset_speed"]))

    input_batch = []   # uncompressed frames + control planes (model input)
    target_batch = []  # compressed frames, no control planes (model target)
    batches_yielded = 0

    cap_compressed = cv2.VideoCapture(video_path)
    cap_uncompressed = cv2.VideoCapture(uncompressed_video_path)
    try:
        if not (cap_compressed.isOpened() and cap_uncompressed.isOpened()):
            LOGGER.warning("Could not open video pair %s / %s; skipping.",
                           video_path, uncompressed_video_path)
            return 0

        # Read both videos in lockstep; stop at the end of the shorter one.
        while True:
            ret_compressed, compressed_frame = cap_compressed.read()
            ret_uncompressed, uncompressed_frame = cap_uncompressed.read()
            if not (ret_compressed and ret_uncompressed):
                break

            # NOTE(review): ``crf``/``speed`` only reach the target frame, and
            # include_controls=False discards them there, so the model input
            # always carries crf=0 / "veryslow". Confirm whether the input
            # should instead carry this video's crf and speed.
            target_batch.append(combine_batch(compressed_frame, crf, speed, include_controls=False))
            input_batch.append(combine_batch(uncompressed_frame, 0, reference_speed))

            if len(target_batch) == batch_size:
                yield (np.array(input_batch), np.array(target_batch))
                batches_yielded += 1
                input_batch = []
                target_batch = []
    finally:
        cap_compressed.release()
        cap_uncompressed.release()

    # Frames that don't fill a whole batch are still sent, as a shorter batch.
    if target_batch:
        yield (np.array(input_batch), np.array(target_batch))
        batches_yielded += 1
    return batches_yielded


def data_generator(videos, batch_size, base_dir=None):
    """Yield ``(inputs, targets)`` training batches forever.

    Each pass walks every entry of ``videos``, reading the compressed and the
    original video in lockstep.

    Args:
        videos: Iterable of dicts with the keys "compressed_video_file",
            "original_video_file", "crf" and "preset_speed". It is
            materialised into a list once so that one-shot iterators can be
            replayed on every pass.
        batch_size: Maximum number of frames per batch (must be >= 1).
        base_dir: Directory the video file names are resolved against.
            Defaults to DEFAULT_VIDEO_BASE_DIR.

    Yields:
        ``(inputs, targets)`` numpy arrays:
        - inputs: uncompressed frames with CRF/speed control planes appended
          (via combine_batch).
        - targets: compressed frames without control planes.
        A batch never crosses a video boundary, so the last batch of each
        video may be shorter than ``batch_size``.

    Raises:
        ValueError: If ``videos`` is empty or ``batch_size`` < 1.
        RuntimeError: If a full pass over ``videos`` produces no batches,
            e.g. because none of the files could be opened. Without this
            check the generator would spin forever without yielding.
    """
    videos = list(videos)
    if not videos:
        raise ValueError("data_generator: 'videos' must not be empty")
    if batch_size < 1:
        raise ValueError(f"data_generator: batch_size must be >= 1, got {batch_size!r}")
    if base_dir is None:
        base_dir = DEFAULT_VIDEO_BASE_DIR

    # The original footage is always described as CRF 0 / "veryslow";
    # this value is loop-invariant, so compute it once.
    reference_speed = scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow"))

    # Infinite loop: every pass replays all videos.
    while True:
        batches_this_pass = 0
        for video_details in videos:
            # ``yield from`` forwards close()/throw() to the helper, so its
            # captures are released even if the consumer stops early.
            count = yield from _video_batches(video_details, batch_size, base_dir, reference_speed)
            batches_this_pass += count
        if batches_this_pass == 0:
            raise RuntimeError(
                "data_generator: a full pass over the videos produced no frames; "
                "check that the video files exist and can be opened"
            )


class VideoCompressionModel(tf.keras.Model):
    """Convolutional encoder/decoder over frames with two control planes.

    The input has NUM_COLOUR_CHANNELS + 2 channels: the colour channels plus
    the CRF and speed-preset planes appended by combine_batch(). The output
    has NUM_COLOUR_CHANNELS channels with a sigmoid activation, so values
    lie in (0, 1).
    """

    def __init__(self, **kwargs):
        # Forward standard tf.keras.Model constructor arguments (e.g. name=...).
        super().__init__(**kwargs)
        LOGGER.debug("Initializing VideoCompressionModel.")

        # Spatial size is left dynamic; the channel count includes the CRF
        # and SPEED_PRESET control planes.
        input_shape_with_controls = (None, None, NUM_COLOUR_CHANNELS + 2)

        # Encoder: two conv + 2x2 max-pool stages (spatial size / 4, rounded up).
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=input_shape_with_controls),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.MaxPooling2D((2, 2), padding='same'),
            tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.MaxPooling2D((2, 2), padding='same')
        ])

        # Decoder: mirrors the encoder with two 2x upsampling stages.
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.Conv2DTranspose(32, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.UpSampling2D((2, 2)),
            tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.UpSampling2D((2, 2)),
            tf.keras.layers.Conv2DTranspose(NUM_COLOUR_CHANNELS, (3, 3), activation='sigmoid', padding='same')
        ])

    def call(self, inputs, training=None):
        """Encode and decode ``inputs``.

        Args:
            inputs: Batch of frames; presumably
                (batch, height, width, NUM_COLOUR_CHANNELS + 2) to match the
                encoder's InputLayer — confirm against callers.
            training: Forwarded to the encoder/decoder sub-models.

        Returns:
            Decoded frames with NUM_COLOUR_CHANNELS channels, cropped to the
            input's height and width.
        """
        encoded = self.encoder(inputs, training=training)
        decoded = self.decoder(encoded, training=training)
        # 'same' max-pooling rounds odd sizes up, so after two pools and two
        # 2x upsamples the output can exceed the input size when height/width
        # are not multiples of 4. Crop back; this is a no-op for multiples of 4.
        input_shape = tf.shape(inputs)
        return decoded[:, :input_shape[1], :input_shape[2], :]