import tensorflow as tf

# Preset-speed category names. Presumably a preset's position in this list is
# the integer id passed to the model as `preset_speed` -- confirm against the
# data pipeline.
PRESET_SPEED_CATEGORIES = [
    "ultrafast",
    "superfast",
    "veryfast",
    "faster",
    "fast",
    "medium",
    "slow",
    "slower",
    "veryslow",
]
NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
NUM_FRAMES = 5  # Number of consecutive frames in a sequence
NUM_CHANNELS = 3  # Number of color channels in the video frames (RGB images have 3 channels)

# Mixed precision is currently disabled; uncomment to enable it globally.
# policy = tf.keras.mixed_precision.Policy('mixed_float16')
# tf.keras.mixed_precision.set_global_policy(policy)


class VideoCompressionModel(tf.keras.Model):
    """3D-convolutional encoder/decoder conditioned on CRF and preset speed.

    The CRF value and a learned embedding of the preset-speed id are
    broadcast over every position of the input frame volume and appended as
    extra channels before the volume is encoded and decoded.

    Args:
        NUM_CHANNELS: Number of channels per input frame; also the number of
            channels of the reconstructed output.
        NUM_FRAMES: Stored on the instance; not used by the layers themselves.
        regularization_factor: L2 factor applied to the embedding table and
            to every convolution kernel.
        embedding_dim: Width of the preset-speed embedding vector.
    """

    def __init__(
        self,
        NUM_CHANNELS=3,
        NUM_FRAMES=5,
        regularization_factor=1e-4,
        embedding_dim=16,
    ):
        super().__init__()
        self.NUM_CHANNELS = NUM_CHANNELS
        self.NUM_FRAMES = NUM_FRAMES
        self.embedding_dim = embedding_dim

        # Regularization
        self.regularizer = tf.keras.regularizers.l2(regularization_factor)

        # Embedding layer for preset_speed
        self.preset_embedding = tf.keras.layers.Embedding(
            NUM_PRESET_SPEEDS,
            embedding_dim,
            embeddings_regularizer=self.regularizer,
        )

        # Encoder layers. The input carries the frame channels plus one CRF
        # channel plus `embedding_dim` preset channels.
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.Conv3D(
                32,
                (3, 3, 3),
                activation='relu',
                padding='same',
                input_shape=(None, None, None, NUM_CHANNELS + 1 + embedding_dim),
                kernel_regularizer=self.regularizer,
            ),
            tf.keras.layers.MaxPooling3D((2, 2, 2)),
            # Add more encoder layers as needed
        ])

        # Decoder layers
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.Conv3DTranspose(
                32,
                (3, 3, 3),
                activation='relu',
                padding='same',
                kernel_regularizer=self.regularizer,
            ),
            tf.keras.layers.UpSampling3D((2, 2, 2)),
            # Add more decoder layers as needed
            # Output layer for video frames
            tf.keras.layers.Conv3D(
                NUM_CHANNELS,
                (3, 3, 3),
                activation='sigmoid',
                padding='same',
                kernel_regularizer=self.regularizer,
            ),
        ])

    def call(self, inputs):
        """Run the conditioned encoder/decoder and return one decoded slice.

        Args:
            inputs: Mapping with the keys
                ``"frames"``: rank-5, channels-last tensor (axis 1 is
                presumably time -- confirm against the data pipeline);
                ``"crf"``: one value per batch element, any numeric dtype;
                ``"preset_speed"``: one integer preset id per batch element.

        Returns:
            The last slice along axis 1 of the decoder output, i.e. a rank-4
            tensor with ``NUM_CHANNELS`` sigmoid-activated channels.

        Note:
            The pooling layer halves axes 1-3 (rounding down) and the
            upsampling layer doubles them, so any of those axes with an odd
            length comes back one element shorter than it went in.
        """
        frames = inputs["frames"]

        # CRF is frequently supplied as an integer; tf.concat rejects mixed
        # dtypes, so align it with the frames before joining.
        crf = tf.cast(inputs["crf"], frames.dtype)

        # Convert preset_speed to embeddings. The reshape below flattens any
        # extra singleton axes, so no separate Flatten layer is needed.
        preset_vectors = self.preset_embedding(inputs["preset_speed"])

        # One conditioning vector per batch element: [crf, preset embedding].
        conditioning = tf.concat(
            [
                tf.reshape(crf, (-1, 1, 1, 1, 1)),
                tf.reshape(preset_vectors, (-1, 1, 1, 1, self.embedding_dim)),
            ],
            axis=-1,
        )

        # Repeat the conditioning vector over axes 1-3 of the frame volume
        # and append it as extra channels.
        frames_shape = tf.shape(frames)
        conditioning = tf.tile(
            conditioning,
            [1, frames_shape[1], frames_shape[2], frames_shape[3], 1],
        )
        frames = tf.concat([frames, conditioning], axis=-1)

        # Encoding the video frames
        compressed_representation = self.encoder(frames)

        # Decoding to generate compressed video frames
        reconstructed_frames = self.decoder(compressed_representation)

        return reconstructed_frames[:, -1, :, :, :]