updated
This commit is contained in:
parent
93ccce5ec1
commit
ed5eb91578
6 changed files with 181 additions and 171 deletions
|
@ -1,5 +1,6 @@
|
|||
# video_compression_model.py
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
@ -8,6 +9,28 @@ from global_train import LOGGER
|
|||
PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
|
||||
NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
|
||||
NUM_CHANNELS = 3
|
||||
WIDTH = 638
|
||||
HEIGHT = 360
|
||||
|
||||
#from tensorflow.keras.mixed_precision import Policy
|
||||
|
||||
#policy = Policy('mixed_float16')
|
||||
#tf.keras.mixed_precision.set_global_policy(policy)
|
||||
|
||||
|
||||
|
||||
def normalize(frame):
|
||||
"""
|
||||
Normalize pixel values of the frame to range [0, 1].
|
||||
|
||||
Args:
|
||||
- frame (ndarray): Image frame.
|
||||
|
||||
Returns:
|
||||
- ndarray: Normalized frame.
|
||||
"""
|
||||
LOGGER.trace(f"Normalizing frame")
|
||||
return frame / 255.0
|
||||
|
||||
class VideoDataGenerator(tf.keras.utils.Sequence):
|
||||
def __init__(self, video_details_list, batch_size):
|
||||
|
@ -19,28 +42,59 @@ class VideoDataGenerator(tf.keras.utils.Sequence):
|
|||
return int(np.ceil(len(self.video_details_list) / float(self.batch_size)))
|
||||
|
||||
def __getitem__(self, idx):
|
||||
try:
|
||||
start_idx = idx * self.batch_size
|
||||
end_idx = (idx + 1) * self.batch_size
|
||||
|
||||
batch_data = self.video_details_list[start_idx:end_idx]
|
||||
start_idx = idx * self.batch_size
|
||||
end_idx = (idx + 1) * self.batch_size
|
||||
batch_data = self.video_details_list[start_idx:end_idx]
|
||||
|
||||
x1 = np.array([item["frame"] for item in batch_data])
|
||||
x2 = np.array([item["compressed_frame"] for item in batch_data])
|
||||
x3 = np.array([item["crf"] for item in batch_data])
|
||||
x4 = np.array([item["preset_speed"] for item in batch_data])
|
||||
# Determine the number of videos and frames per video
|
||||
num_videos = len(batch_data)
|
||||
frames_per_video = batch_data[0]['frames_per_video'] # Assuming all videos have the same number of frames
|
||||
|
||||
y = x2
|
||||
# Pre-allocate arrays for the batch data
|
||||
x1 = np.empty((num_videos * frames_per_video, HEIGHT, WIDTH, NUM_CHANNELS))
|
||||
x2 = np.empty_like(x1)
|
||||
x3 = np.empty((num_videos * frames_per_video, 1))
|
||||
x4 = np.empty_like(x3)
|
||||
|
||||
# Iterate over the videos and frames, filling the pre-allocated arrays
|
||||
for i, item in enumerate(batch_data):
|
||||
compressed_video_file = item["compressed_video_file"]
|
||||
original_video_file = item["original_video_file"]
|
||||
crf = item["crf"]
|
||||
preset_speed = item["preset_speed"]
|
||||
|
||||
cap_compressed = cv2.VideoCapture(compressed_video_file)
|
||||
cap_original = cv2.VideoCapture(original_video_file)
|
||||
for j in range(frames_per_video):
|
||||
compressed_ret, compressed_frame = cap_compressed.read()
|
||||
original_ret, original_frame = cap_original.read()
|
||||
if not compressed_ret or not original_ret:
|
||||
continue
|
||||
|
||||
# Check frame dimensions and resize if necessary
|
||||
if original_frame.shape[:2] != (WIDTH, HEIGHT):
|
||||
LOGGER.info(f"Resizing video: {original_video_file}")
|
||||
original_frame = cv2.resize(original_frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA)
|
||||
if compressed_frame.shape[:2] != (WIDTH, HEIGHT):
|
||||
LOGGER.info(f"Resizing video: {compressed_video_file}")
|
||||
compressed_frame = cv2.resize(compressed_frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA)
|
||||
|
||||
original_frame = cv2.cvtColor(original_frame, cv2.COLOR_BGR2RGB)
|
||||
compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Store the processed frames and metadata directly in the pre-allocated arrays
|
||||
x1[i * frames_per_video + j] = normalize(original_frame)
|
||||
x2[i * frames_per_video + j] = normalize(compressed_frame)
|
||||
x3[i * frames_per_video + j] = crf
|
||||
x4[i * frames_per_video + j] = preset_speed
|
||||
|
||||
cap_original.release()
|
||||
cap_compressed.release()
|
||||
|
||||
y = x2
|
||||
inputs = {"uncompressed_frame": x1, "compressed_frame": x2, "crf": x3, "preset_speed": x4}
|
||||
return inputs, y
|
||||
|
||||
inputs = {"uncompressed_frame": x1, "compressed_frame": x2, "crf": x3, "preset_speed": x4}
|
||||
return inputs, y
|
||||
|
||||
except IndexError:
|
||||
LOGGER.error(f"Index {idx} out of bounds in VideoDataGenerator.")
|
||||
raise
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Unexpected error in VideoDataGenerator: {e}")
|
||||
raise
|
||||
|
||||
|
||||
class VideoCompressionModel(tf.keras.Model):
|
||||
|
@ -78,7 +132,43 @@ class VideoCompressionModel(tf.keras.Model):
|
|||
tf.keras.layers.Dropout(0.3),
|
||||
tf.keras.layers.Conv2D(NUM_CHANNELS, (3, 3), activation='sigmoid', padding='same') # Output layer for video frames
|
||||
])
|
||||
|
||||
def call(self, inputs):
|
||||
LOGGER.trace("Calling VideoCompressionModel.")
|
||||
uncompressed_frame, compressed_frame, crf, preset_speed = inputs['uncompressed_frame'], inputs['compressed_frame'], inputs['crf'], inputs['preset_speed']
|
||||
|
||||
# Convert frames to float32
|
||||
uncompressed_frame = tf.cast(uncompressed_frame, tf.float16)
|
||||
compressed_frame = tf.cast(compressed_frame, tf.float16)
|
||||
|
||||
# Embedding for preset speed
|
||||
preset_speed_embedded = self.embedding(preset_speed)
|
||||
preset_speed_embedded = tf.keras.layers.Flatten()(preset_speed_embedded)
|
||||
|
||||
# Reshaping CRF to match the shape of preset_speed_embedded
|
||||
crf_expanded = tf.keras.layers.Flatten()(tf.repeat(crf, 16, axis=-1))
|
||||
|
||||
|
||||
# Concatenating the CRF and preset speed information
|
||||
integrated_info = tf.keras.layers.Concatenate(axis=-1)([crf_expanded, preset_speed_embedded])
|
||||
integrated_info = self.fc(integrated_info)
|
||||
|
||||
# Integrate the CRF and preset speed information into the frames as additional channels (features)
|
||||
_, height, width, _ = uncompressed_frame.shape
|
||||
current_shape = tf.shape(inputs["uncompressed_frame"])
|
||||
|
||||
height = current_shape[1]
|
||||
width = current_shape[2]
|
||||
integrated_info_repeated = tf.tile(tf.reshape(integrated_info, [-1, 1, 1, 32]), [1, height, width, 1])
|
||||
|
||||
# Merge uncompressed and compressed frames
|
||||
frames_merged = tf.keras.layers.Concatenate(axis=-1)([uncompressed_frame, compressed_frame, integrated_info_repeated])
|
||||
|
||||
compressed_representation = self.encoder(frames_merged)
|
||||
reconstructed_frame = self.decoder(compressed_representation)
|
||||
|
||||
return reconstructed_frame
|
||||
|
||||
def model_summary(self):
|
||||
try:
|
||||
LOGGER.info("Generating model summary.")
|
||||
|
@ -90,34 +180,3 @@ class VideoCompressionModel(tf.keras.Model):
|
|||
except Exception as e:
|
||||
LOGGER.error(f"Unexpected error during model summary generation: {e}")
|
||||
raise
|
||||
|
||||
def call(self, inputs):
|
||||
LOGGER.trace("Calling VideoCompressionModel.")
|
||||
uncompressed_frame, compressed_frame, crf, preset_speed = inputs['uncompressed_frame'], inputs['compressed_frame'], inputs['crf'], inputs['preset_speed']
|
||||
|
||||
# Convert frames to float32
|
||||
uncompressed_frame = tf.cast(uncompressed_frame, tf.float32)
|
||||
compressed_frame = tf.cast(compressed_frame, tf.float32)
|
||||
|
||||
# Integrate CRF and preset speed into the network
|
||||
preset_speed_embedded = self.embedding(preset_speed)
|
||||
crf_expanded = tf.expand_dims(crf, -1)
|
||||
integrated_info = tf.keras.layers.Concatenate(axis=-1)([crf_expanded, tf.keras.layers.Flatten()(preset_speed_embedded)])
|
||||
integrated_info = self.fc(integrated_info)
|
||||
|
||||
# Integrate the CRF and preset speed information into the frames as additional channels (features)
|
||||
_, height, width, _ = uncompressed_frame.shape
|
||||
current_shape = tf.shape(inputs["uncompressed_frame"])
|
||||
|
||||
height = current_shape[1]
|
||||
width = current_shape[2]
|
||||
integrated_info_repeated = tf.tile(tf.reshape(integrated_info, [-1, 1, 1, 32]), [1, height, width, 1])
|
||||
|
||||
|
||||
# Merge uncompressed and compressed frames
|
||||
frames_merged = tf.keras.layers.Concatenate(axis=-1)([uncompressed_frame, compressed_frame, integrated_info_repeated])
|
||||
|
||||
compressed_representation = self.encoder(frames_merged)
|
||||
reconstructed_frame = self.decoder(compressed_representation)
|
||||
|
||||
return reconstructed_frame
|
||||
|
|
Reference in a new issue