Improved model
This commit is contained in:
parent
9167ff27d4
commit
60c6c97071
8 changed files with 327 additions and 112 deletions
|
@ -3,12 +3,15 @@
|
|||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from global_train import LOGGER
|
||||
|
||||
# Encoder preset names.
# NOTE(review): the ordering looks like fastest-to-slowest, and presumably an
# index into this list is what the "preset_speed" feature carries - confirm.
PRESET_SPEED_CATEGORIES = [
    "ultrafast",
    "superfast",
    "veryfast",
    "faster",
    "fast",
    "medium",
    "slow",
    "slower",
    "veryslow",
]

# Number of distinct presets; derived so it cannot drift from the list above.
NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)

# Size of the last (channel) axis of the frame inputs declared in this module.
NUM_CHANNELS = 3
|
||||
|
||||
class VideoDataGenerator(tf.keras.utils.Sequence):
|
||||
def __init__(self, video_details_list, batch_size):
    """Store the sample list and the batch size used to slice it.

    Args:
        video_details_list: Sliceable collection of per-sample records.
            The batching code in this class reads the "frame",
            "compressed_frame", "crf" and "preset_speed" keys of each record.
        batch_size: Number of records per batch.
    """
    LOGGER.debug(f"Initializing VideoDataGenerator with batch size: {batch_size}")
    self.video_details_list = video_details_list
    self.batch_size = batch_size
|
||||
|
||||
|
@ -16,25 +19,34 @@ class VideoDataGenerator(tf.keras.utils.Sequence):
|
|||
return int(np.ceil(len(self.video_details_list) / float(self.batch_size)))
|
||||
|
||||
def __getitem__(self, idx):
    """Return batch number ``idx`` as an ``(inputs, targets)`` pair.

    Args:
        idx: Zero-based batch index.

    Returns:
        A tuple ``(inputs, y)``. ``inputs`` is a dict of NumPy arrays keyed by
        "uncompressed_frame", "compressed_frame", "crf" and "preset_speed";
        ``y`` is the very same array object as ``inputs["compressed_frame"]``.

    Raises:
        IndexError: If ``idx`` does not address an existing batch.
        Exception: Any other failure (for example a record missing one of the
            expected keys) is logged and re-raised unchanged.
    """
    try:
        # Ceil division: a trailing partial batch still counts as a batch.
        num_batches = int(np.ceil(len(self.video_details_list) / float(self.batch_size)))

        # Slicing never raises IndexError; without this explicit check an
        # out-of-range index would silently produce empty arrays and the
        # IndexError handler below could never run.
        if not 0 <= idx < num_batches:
            raise IndexError(
                f"batch index {idx} is out of range for {num_batches} batches"
            )

        start_idx = idx * self.batch_size
        end_idx = start_idx + self.batch_size
        batch_data = self.video_details_list[start_idx:end_idx]

        x1 = np.array([item["frame"] for item in batch_data])
        x2 = np.array([item["compressed_frame"] for item in batch_data])
        x3 = np.array([item["crf"] for item in batch_data])
        x4 = np.array([item["preset_speed"] for item in batch_data])

        # The target is the compressed frame itself (same object, not a copy).
        y = x2

        inputs = {"uncompressed_frame": x1, "compressed_frame": x2, "crf": x3, "preset_speed": x4}
        return inputs, y
    except IndexError:
        LOGGER.error(f"Index {idx} out of bounds in VideoDataGenerator.")
        raise
    except Exception as e:
        LOGGER.error(f"Unexpected error in VideoDataGenerator: {e}")
        raise
|
||||
|
||||
|
||||
class VideoCompressionModel(tf.keras.Model):
|
||||
def __init__(self):
|
||||
super(VideoCompressionModel, self).__init__()
|
||||
LOGGER.debug("Initializing VideoCompressionModel.")
|
||||
|
||||
# Inputs
|
||||
self.crf_input = tf.keras.layers.InputLayer(name='crf', input_shape=(1,))
|
||||
|
@ -68,14 +80,19 @@ class VideoCompressionModel(tf.keras.Model):
|
|||
])
|
||||
|
||||
def model_summary(self):
    """Print a layer summary of this model for its four named inputs.

    Symbolic inputs are created for "uncompressed_frame", "compressed_frame",
    "crf" and "preset_speed", routed through ``self.call`` and wrapped in a
    temporary functional ``tf.keras.Model`` whose ``summary()`` is invoked.

    Returns:
        Whatever ``tf.keras.Model.summary()`` returns.

    Raises:
        Exception: Any failure while building the wrapper model is logged and
            re-raised unchanged.
    """
    try:
        LOGGER.info("Generating model summary.")

        # Frames have unspecified height/width and NUM_CHANNELS channels.
        frame_shape = (None, None, NUM_CHANNELS)

        # Built once so the keys handed to call() and the list handed to the
        # wrapper model cannot drift apart; dict order fixes the input order.
        symbolic_inputs = {
            'uncompressed_frame': tf.keras.Input(shape=frame_shape, name='uncompressed_frame'),
            'compressed_frame': tf.keras.Input(shape=frame_shape, name='compressed_frame'),
            'crf': tf.keras.Input(shape=(1,), name='crf'),
            'preset_speed': tf.keras.Input(shape=(1,), name='preset_speed'),
        }

        wrapper = tf.keras.Model(
            inputs=list(symbolic_inputs.values()),
            outputs=self.call(symbolic_inputs),
        )
        return wrapper.summary()
    except Exception as e:
        LOGGER.error(f"Unexpected error during model summary generation: {e}")
        raise
|
||||
|
||||
def call(self, inputs):
|
||||
LOGGER.trace("Calling VideoCompressionModel.")
|
||||
uncompressed_frame, compressed_frame, crf, preset_speed = inputs['uncompressed_frame'], inputs['compressed_frame'], inputs['crf'], inputs['preset_speed']
|
||||
|
||||
# Convert frames to float32
|
||||
|
|
Reference in a new issue