# video_compression_model.py
"""Keras model conditioned on video-encoder settings (CRF and speed preset).

``VideoCompressionModel`` concatenates an uncompressed frame, a compressed
frame and a per-pixel broadcast of features derived from the CRF and speed
preset, runs the result through a convolutional encoder/decoder, and returns
a ``NUM_CHANNELS``-channel image whose values lie in (0, 1) (sigmoid output).
"""
import tensorflow as tf

# Speed preset names. The model consumes a preset as its integer index into
# this list; the index is looked up in an Embedding of size NUM_PRESET_SPEEDS.
PRESET_SPEED_CATEGORIES = [
    "ultrafast",
    "superfast",
    "veryfast",
    "faster",
    "fast",
    "medium",
    "slow",
    "slower",
    "veryslow",
]
NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
NUM_CHANNELS = 3
# Width of the preset-speed embedding vector.
PRESET_EMBEDDING_DIM = 16
# Width of the fused CRF/preset feature vector. It is appended to every pixel
# as this many extra channels, so the encoder's input depth depends on it.
CONDITIONING_UNITS = 32


class VideoCompressionModel(tf.keras.Model):
    """Convolutional encoder/decoder conditioned on CRF and speed preset.

    ``call`` takes a dict with keys:

    * ``'uncompressed_frame'`` and ``'compressed_frame'``: image batches of
      shape ``(batch, height, width, NUM_CHANNELS)``; cast to float32.
    * ``'crf'``: one CRF value per example, shape ``(batch,)`` or
      ``(batch, 1)``; cast to float32.
    * ``'preset_speed'``: one integer index into ``PRESET_SPEED_CATEGORIES``
      per example, shape ``(batch,)`` or ``(batch, 1)``.

    Height and width are read at run time, so they may be unknown when the
    model is traced.

    NOTE(review): MaxPooling2D((2, 2)) followed by UpSampling2D((2, 2)) means
    an odd height or width comes back one pixel smaller than the input —
    confirm callers only feed even-sized frames.
    NOTE(review): frames are cast to float32 but not rescaled, while the
    output is a sigmoid in (0, 1); presumably callers pass frames already
    scaled to [0, 1] — confirm.
    """

    def __init__(self, **kwargs):
        """Create the sub-layers.

        Args:
            **kwargs: Optional; forwarded to ``tf.keras.Model`` (e.g. ``name``).
        """
        super().__init__(**kwargs)

        # Declarations of the expected per-example input shapes. ``call``
        # reads its tensors straight from the ``inputs`` dict and does not
        # route them through these layers.
        self.crf_input = tf.keras.layers.InputLayer(
            name='crf', input_shape=(1,))
        self.preset_speed_input = tf.keras.layers.InputLayer(
            name='preset_speed', input_shape=(1,))
        self.uncompressed_frame_input = tf.keras.layers.InputLayer(
            name='uncompressed_frame', input_shape=(None, None, NUM_CHANNELS))
        self.compressed_frame_input = tf.keras.layers.InputLayer(
            name='compressed_frame', input_shape=(None, None, NUM_CHANNELS))

        # Embedding for the speed preset, then a Dense layer that fuses the
        # CRF with the preset embedding into a CONDITIONING_UNITS vector.
        self.embedding = tf.keras.layers.Embedding(
            NUM_PRESET_SPEEDS, PRESET_EMBEDDING_DIM)
        self.fc = tf.keras.layers.Dense(CONDITIONING_UNITS, activation='relu')

        # Encoder: input depth is both frames plus the broadcast conditioning
        # channels; MaxPooling2D halves the spatial size.
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.Conv2D(
                64, (3, 3), activation='relu', padding='same',
                input_shape=(None, None,
                             2 * NUM_CHANNELS + CONDITIONING_UNITS)),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Conv2D(
                128, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Dropout(0.3),
        ])

        # Decoder: UpSampling2D restores the spatial size; the last Conv2D
        # maps back to NUM_CHANNELS with a sigmoid.
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.Conv2DTranspose(
                128, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Conv2DTranspose(
                64, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.UpSampling2D((2, 2)),
            tf.keras.layers.Dropout(0.3),
            # Output layer for video frames.
            tf.keras.layers.Conv2D(
                NUM_CHANNELS, (3, 3), activation='sigmoid', padding='same'),
        ])

    def call(self, inputs, training=None):
        """Run the conditioned encoder/decoder.

        Args:
            inputs: Dict with ``'uncompressed_frame'``, ``'compressed_frame'``,
                ``'crf'`` and ``'preset_speed'`` (see the class docstring).
            training: Optional bool forwarded to the encoder and decoder so
                BatchNormalization and Dropout use the right mode; ``None``
                lets Keras take it from the surrounding call context.

        Returns:
            Tensor of shape ``(batch, H', W', NUM_CHANNELS)`` with values in
            (0, 1), where ``H'``/``W'`` equal the input size rounded down to
            an even number.
        """
        uncompressed_frame = tf.cast(inputs['uncompressed_frame'], tf.float32)
        compressed_frame = tf.cast(inputs['compressed_frame'], tf.float32)

        # One CRF value per example: normalise (batch,) and (batch, 1) both
        # to (batch, 1), and cast so it can be concatenated with the float
        # embedding. (expand_dims on a (batch, 1) input produced rank 3.)
        crf = tf.reshape(tf.cast(inputs['crf'], tf.float32), [-1, 1])

        # One preset index per example -> (batch, PRESET_EMBEDDING_DIM). A
        # plain reshape keeps the last dimension static so ``fc`` can build.
        preset_embedded = tf.reshape(
            self.embedding(inputs['preset_speed']), [-1, PRESET_EMBEDDING_DIM])

        conditioning = self.fc(tf.concat([crf, preset_embedded], axis=-1))

        # Height/width are None in the static shape when the model is traced
        # with variable-size frames, so read them from the dynamic shape.
        frame_shape = tf.shape(uncompressed_frame)
        height, width = frame_shape[1], frame_shape[2]
        conditioning_map = tf.tile(
            tf.reshape(conditioning, [-1, 1, 1, CONDITIONING_UNITS]),
            [1, height, width, 1])

        # Stack both frames and the conditioning map along the channel axis.
        frames_merged = tf.concat(
            [uncompressed_frame, compressed_frame, conditioning_map], axis=-1)

        compressed_representation = self.encoder(
            frames_merged, training=training)
        return self.decoder(compressed_representation, training=training)