This commit is contained in:
Jordon Brooks 2023-07-30 13:43:53 +01:00
parent 5bca78e687
commit 9167ff27d4
2 changed files with 46 additions and 31 deletions

View file

@@ -1,11 +1,37 @@
# video_compression_model.py
import numpy as np
import tensorflow as tf
# Preset speed names, listed from "ultrafast" through "veryslow".
# NOTE(review): these look like encoder (x264/ffmpeg-style) presets -- confirm.
PRESET_SPEED_CATEGORIES = [
    "ultrafast",
    "superfast",
    "veryfast",
    "faster",
    "fast",
    "medium",
    "slow",
    "slower",
    "veryslow",
]

# Number of distinct preset speed categories, derived from the list above.
NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)

# Channels per frame (presumably RGB -- confirm against the frame loader).
NUM_CHANNELS = 3
class VideoDataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves batches from a list of frame records.

    Every record in ``video_details_list`` is read with the keys
    ``"frame"``, ``"compressed_frame"``, ``"crf"`` and ``"preset_speed"``.
    """

    def __init__(self, video_details_list, batch_size, **kwargs):
        """Store the records and the batch size.

        Args:
            video_details_list: Sliceable sequence of records (see the class
                docstring for the keys that are read from each record).
            batch_size: Number of records per batch; must be positive.
            **kwargs: Optional arguments forwarded unchanged to the
                ``Sequence`` base-class constructor.

        Raises:
            ValueError: If ``batch_size`` is zero or negative.
        """
        # Fail early: a zero batch size would otherwise surface later as a
        # ZeroDivisionError in __len__, and a negative one as a negative length.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size!r}")
        # Let the base class initialise its own state; newer Keras versions
        # warn when a Sequence subclass skips this call.
        super().__init__(**kwargs)
        self.video_details_list = video_details_list
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches, counting a final partial batch."""
        return int(np.ceil(len(self.video_details_list) / float(self.batch_size)))

    def __getitem__(self, idx):
        """Build batch number ``idx``.

        Args:
            idx: Zero-based batch index.

        Returns:
            A ``(inputs, y)`` tuple. ``inputs`` is a dict of NumPy arrays keyed
            by ``"uncompressed_frame"``, ``"compressed_frame"``, ``"crf"`` and
            ``"preset_speed"``; ``y`` is the same array object as
            ``inputs["compressed_frame"]``.
        """
        start_idx = idx * self.batch_size
        end_idx = start_idx + self.batch_size
        # Slicing clamps at the end of the list, so the last batch may be short.
        batch_data = self.video_details_list[start_idx:end_idx]

        frames = np.array([item["frame"] for item in batch_data])
        compressed_frames = np.array([item["compressed_frame"] for item in batch_data])
        crf_values = np.array([item["crf"] for item in batch_data])
        preset_speeds = np.array([item["preset_speed"] for item in batch_data])

        inputs = {
            "uncompressed_frame": frames,
            "compressed_frame": compressed_frames,
            "crf": crf_values,
            "preset_speed": preset_speeds,
        }
        # The compressed frames serve as both a model input and the target.
        return inputs, compressed_frames
class VideoCompressionModel(tf.keras.Model):
def __init__(self):
super(VideoCompressionModel, self).__init__()
@@ -64,8 +90,13 @@ class VideoCompressionModel(tf.keras.Model):
# Integrate the CRF and preset speed information into the frames as additional channels (features)
_, height, width, _ = uncompressed_frame.shape
current_shape = tf.shape(inputs["uncompressed_frame"])
height = current_shape[1]
width = current_shape[2]
integrated_info_repeated = tf.tile(tf.reshape(integrated_info, [-1, 1, 1, 32]), [1, height, width, 1])
# Merge uncompressed and compressed frames
frames_merged = tf.keras.layers.Concatenate(axis=-1)([uncompressed_frame, compressed_frame, integrated_info_repeated])