106 lines
4.9 KiB
Python
106 lines
4.9 KiB
Python
# video_compression_model.py
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
# Names of the encoder preset speeds, listed from "ultrafast" to "veryslow".
# NOTE(review): a preset's position in this list is presumably the integer id
# passed as "preset_speed" to the model — confirm against the data pipeline.
PRESET_SPEED_CATEGORIES = [
    "ultrafast",
    "superfast",
    "veryfast",
    "faster",
    "fast",
    "medium",
    "slow",
    "slower",
    "veryslow",
]

# Number of distinct presets; used as the vocabulary size of the preset embedding.
NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)

# Channels per video frame.
NUM_CHANNELS = 3
|
|
|
|
class VideoDataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves batches of frames plus encoder settings.

    Every element of ``video_details_list`` is indexed with the keys
    ``"frame"``, ``"compressed_frame"``, ``"crf"`` and ``"preset_speed"``.
    The values under each key are stacked with ``np.array``, so they are
    assumed to be stackable within a batch (e.g. frames of identical size)
    — TODO confirm against the code that builds the list.
    """

    def __init__(self, video_details_list, batch_size):
        """Store the sample list and the batch size.

        Args:
            video_details_list: Sliceable sequence of per-sample mappings.
            batch_size: Number of samples per batch; must be at least 1.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # Newer Keras releases warn when a Sequence subclass skips the base
        # constructor; calling it without arguments is harmless on older ones.
        super().__init__()
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        self.video_details_list = video_details_list
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches (the last one may be partial)."""
        # Integer ceiling division; avoids the float round-trip through np.ceil.
        return (len(self.video_details_list) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """Build batch number ``idx``.

        Args:
            idx: Zero-based batch index in ``range(len(self))``.

        Returns:
            A tuple ``(inputs, y)``. ``inputs`` is a dict with the keys
            ``"uncompressed_frame"``, ``"compressed_frame"``, ``"crf"`` and
            ``"preset_speed"``; ``y`` is the very same array object as
            ``inputs["compressed_frame"]``.

        Raises:
            IndexError: If ``idx`` is outside ``range(len(self))``.
        """
        if not 0 <= idx < len(self):
            raise IndexError(
                f"batch index {idx} out of range for {len(self)} batches"
            )

        start = idx * self.batch_size
        batch = self.video_details_list[start:start + self.batch_size]

        frames = np.array([item["frame"] for item in batch])
        compressed_frames = np.array([item["compressed_frame"] for item in batch])
        crf_values = np.array([item["crf"] for item in batch])
        preset_speeds = np.array([item["preset_speed"] for item in batch])

        inputs = {
            "uncompressed_frame": frames,
            "compressed_frame": compressed_frames,
            "crf": crf_values,
            "preset_speed": preset_speeds,
        }
        # The training target is the compressed frame itself.
        return inputs, compressed_frames
|
|
|
|
|
|
class VideoCompressionModel(tf.keras.Model):
    """Convolutional encoder/decoder conditioned on CRF and preset speed.

    The model consumes a dict with the keys ``"uncompressed_frame"``,
    ``"compressed_frame"``, ``"crf"`` and ``"preset_speed"`` and returns a
    frame with ``NUM_CHANNELS`` channels produced by a sigmoid output layer.

    NOTE(review): the encoder halves height/width with a 2x2 max-pool and the
    decoder doubles them again, so odd spatial sizes presumably come back one
    pixel smaller than the input — confirm frame sizes are even.
    NOTE(review): the sigmoid output lies in [0, 1] while input frames are only
    cast to float32, not rescaled — confirm the targets use the same range.
    """

    def __init__(self):
        """Create the embedding, dense, encoder and decoder sub-layers."""
        super().__init__()

        # Input layer objects describing the four named inputs. They are kept
        # as attributes for compatibility; call() does not route data through them.
        self.crf_input = tf.keras.layers.InputLayer(name='crf', input_shape=(1,))
        self.preset_speed_input = tf.keras.layers.InputLayer(name='preset_speed', input_shape=(1,))
        self.uncompressed_frame_input = tf.keras.layers.InputLayer(
            name='uncompressed_frame', input_shape=(None, None, NUM_CHANNELS))
        self.compressed_frame_input = tf.keras.layers.InputLayer(
            name='compressed_frame', input_shape=(None, None, NUM_CHANNELS))

        # Embedding for the preset id and a dense layer fusing it with the CRF.
        self.embedding = tf.keras.layers.Embedding(NUM_PRESET_SPEEDS, 16)
        self.fc = tf.keras.layers.Dense(32, activation='relu')

        # Encoder: two conv + batch-norm stages, then 2x2 pooling and dropout.
        # Its input stacks both frames plus the 32 tiled conditioning channels.
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same',
                                   input_shape=(None, None, 2 * NUM_CHANNELS + 32)),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Dropout(0.3),
        ])

        # Decoder: two transposed-conv + batch-norm stages, 2x upsampling,
        # dropout, and a sigmoid conv that emits the output frame.
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.Conv2DTranspose(128, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.UpSampling2D((2, 2)),
            tf.keras.layers.Dropout(0.3),
            tf.keras.layers.Conv2D(NUM_CHANNELS, (3, 3), activation='sigmoid', padding='same'),
        ])

    def model_summary(self):
        """Print a layer summary by tracing ``call`` on symbolic inputs.

        Returns:
            Whatever ``tf.keras.Model.summary()`` returns.
        """
        x1 = tf.keras.Input(shape=(None, None, NUM_CHANNELS), name='uncompressed_frame')
        x2 = tf.keras.Input(shape=(None, None, NUM_CHANNELS), name='compressed_frame')
        x3 = tf.keras.Input(shape=(1,), name='crf')
        x4 = tf.keras.Input(shape=(1,), name='preset_speed')
        outputs = self.call({
            'uncompressed_frame': x1,
            'compressed_frame': x2,
            'crf': x3,
            'preset_speed': x4,
        })
        return tf.keras.Model(inputs=[x1, x2, x3, x4], outputs=outputs).summary()

    def call(self, inputs):
        """Run the forward pass.

        Args:
            inputs: Mapping with the keys ``'uncompressed_frame'`` and
                ``'compressed_frame'`` (rank-4 tensors, channels last),
                ``'crf'`` and ``'preset_speed'``. The two scalar inputs may be
                shaped ``(batch,)`` or ``(batch, 1)``.

        Returns:
            The decoder output tensor with ``NUM_CHANNELS`` channels.
        """
        uncompressed_frame = tf.cast(inputs['uncompressed_frame'], tf.float32)
        compressed_frame = tf.cast(inputs['compressed_frame'], tf.float32)

        # Normalise the scalar inputs to rank 2 so that both (batch,) arrays
        # from a data generator and (batch, 1) symbolic inputs are accepted.
        # CRF is cast to float32 so it can be concatenated with the embedding
        # even when it arrives as integers.
        crf = tf.reshape(tf.cast(inputs['crf'], tf.float32), [-1, 1])
        preset_speed_embedded = tf.reshape(
            self.embedding(inputs['preset_speed']),
            [-1, self.embedding.output_dim],
        )
        integrated_info = self.fc(tf.concat([crf, preset_speed_embedded], axis=-1))

        # Broadcast the conditioning vector over every pixel. The spatial size
        # is read dynamically because the static height/width may be None.
        frame_shape = tf.shape(uncompressed_frame)
        integrated_info_repeated = tf.tile(
            tf.reshape(integrated_info, [-1, 1, 1, self.fc.units]),
            [1, frame_shape[1], frame_shape[2], 1],
        )

        # Stack both frames and the conditioning channels along the channel axis.
        frames_merged = tf.concat(
            [uncompressed_frame, compressed_frame, integrated_info_repeated],
            axis=-1,
        )

        compressed_representation = self.encoder(frames_merged)
        return self.decoder(compressed_representation)
|