updated
This commit is contained in:
parent
93ccce5ec1
commit
ed5eb91578
6 changed files with 181 additions and 171 deletions
|
@ -1,5 +1,9 @@
|
||||||
# DeepEncode.py
|
# DeepEncode.py
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
|
@ -33,12 +37,16 @@ def predict_frame(uncompressed_frame, model, crf_value, preset_speed_value):
|
||||||
crf_array = np.array([crf_value])
|
crf_array = np.array([crf_value])
|
||||||
preset_speed_array = np.array([preset_speed_value])
|
preset_speed_array = np.array([preset_speed_value])
|
||||||
|
|
||||||
|
crf_array = np.expand_dims(np.array([crf_value]), axis=-1) # Shape: (1, 1)
|
||||||
|
preset_speed_array = np.expand_dims(np.array([preset_speed_value]), axis=-1) # Shape: (1, 1)
|
||||||
|
|
||||||
|
|
||||||
# Expand dimensions to include batch size
|
# Expand dimensions to include batch size
|
||||||
uncompressed_frame = np.expand_dims(uncompressed_frame, 0)
|
uncompressed_frame = np.expand_dims(uncompressed_frame, 0)
|
||||||
|
|
||||||
#display_frame = np.clip(cv2.cvtColor(uncompressed_frame[0], cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
|
#display_frame = np.clip(cv2.cvtColor(uncompressed_frame[0], cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
|
||||||
#cv2.imshow("uncomp", display_frame)
|
#cv2.imshow("uncomp", display_frame)
|
||||||
#cv2.waitKey(10)
|
#cv2.waitKey(0)
|
||||||
|
|
||||||
compressed_frame = model.predict({
|
compressed_frame = model.predict({
|
||||||
"compressed_frame": uncompressed_frame,
|
"compressed_frame": uncompressed_frame,
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
import log
|
import log
|
||||||
|
|
||||||
LOGGER = log.Logger(level="INFO", logfile="training.log", reset_logfile=True)
|
LOGGER = log.Logger(level="DEBUG", logfile="training.log", reset_logfile=True)
|
|
@ -1,73 +1,73 @@
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-51_preset-ultrafast.mkv",
|
"compressed_video_file": "x264_crf-51_preset-ultrafast.mkv",
|
||||||
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
"original_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 51,
|
"crf": 51,
|
||||||
"preset_speed": "ultrafast"
|
"preset_speed": "ultrafast"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-16_preset-veryslow.mkv",
|
"compressed_video_file": "x264_crf-16_preset-veryslow.mkv",
|
||||||
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
"original_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 16,
|
"crf": 16,
|
||||||
"preset_speed": "veryslow"
|
"preset_speed": "veryslow"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-18_preset-ultrafast.mkv",
|
"compressed_video_file": "x264_crf-18_preset-ultrafast.mkv",
|
||||||
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
"original_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 18,
|
"crf": 18,
|
||||||
"preset_speed": "ultrafast"
|
"preset_speed": "ultrafast"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-18_preset-veryslow.mkv",
|
"compressed_video_file": "x264_crf-18_preset-veryslow.mkv",
|
||||||
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
"original_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 18,
|
"crf": 18,
|
||||||
"preset_speed": "veryslow"
|
"preset_speed": "veryslow"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-50_preset-veryslow.mkv",
|
"compressed_video_file": "x264_crf-50_preset-veryslow.mkv",
|
||||||
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
"original_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 50,
|
"crf": 50,
|
||||||
"preset_speed": "veryslow"
|
"preset_speed": "veryslow"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-51_preset-fast.mkv",
|
"compressed_video_file": "x264_crf-51_preset-fast.mkv",
|
||||||
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
"original_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 51,
|
"crf": 51,
|
||||||
"preset_speed": "fast"
|
"preset_speed": "fast"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-51_preset-faster.mkv",
|
"compressed_video_file": "x264_crf-51_preset-faster.mkv",
|
||||||
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
"original_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 51,
|
"crf": 51,
|
||||||
"preset_speed": "faster"
|
"preset_speed": "faster"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-51_preset-medium.mkv",
|
"compressed_video_file": "x264_crf-51_preset-medium.mkv",
|
||||||
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
"original_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 51,
|
"crf": 51,
|
||||||
"preset_speed": "medium"
|
"preset_speed": "medium"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-51_preset-slow.mkv",
|
"compressed_video_file": "x264_crf-51_preset-slow.mkv",
|
||||||
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
"original_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 51,
|
"crf": 51,
|
||||||
"preset_speed": "slow"
|
"preset_speed": "slow"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-51_preset-slower.mkv",
|
"compressed_video_file": "x264_crf-51_preset-slower.mkv",
|
||||||
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
"original_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 51,
|
"crf": 51,
|
||||||
"preset_speed": "slower"
|
"preset_speed": "slower"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-51_preset-superfast.mkv",
|
"compressed_video_file": "x264_crf-51_preset-superfast.mkv",
|
||||||
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
"original_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 51,
|
"crf": 51,
|
||||||
"preset_speed": "superfast"
|
"preset_speed": "superfast"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-51_preset-veryfast.mkv",
|
"compressed_video_file": "x264_crf-51_preset-veryfast.mkv",
|
||||||
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
"original_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 51,
|
"crf": 51,
|
||||||
"preset_speed": "veryfast"
|
"preset_speed": "veryfast"
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
[
|
[
|
||||||
|
|
||||||
{
|
{
|
||||||
"video_file": "Scene2_x264_crf-51_preset-veryslow.mkv",
|
"compressed_video_file": "Scene2_x264_crf-51_preset-veryslow.mkv",
|
||||||
"uncompressed_video_file": "Scene2_x264_crf-5_preset-veryslow.mkv",
|
"original_video_file": "Scene2_x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 51,
|
"crf": 51,
|
||||||
"preset_speed": "veryslow"
|
"preset_speed": "veryslow"
|
||||||
}
|
}
|
||||||
|
|
128
train_model.py
128
train_model.py
|
@ -1,15 +1,14 @@
|
||||||
# train_model.py
|
# train_model.py
|
||||||
|
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
|
|
||||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
import argparse
|
import argparse
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from video_compression_model import NUM_CHANNELS, VideoCompressionModel, PRESET_SPEED_CATEGORIES, VideoDataGenerator
|
from video_compression_model import WIDTH, HEIGHT, VideoCompressionModel, PRESET_SPEED_CATEGORIES, VideoDataGenerator
|
||||||
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
|
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
|
||||||
|
|
||||||
from global_train import LOGGER
|
from global_train import LOGGER
|
||||||
|
@ -17,13 +16,12 @@ from global_train import LOGGER
|
||||||
# Constants
|
# Constants
|
||||||
BATCH_SIZE = 4
|
BATCH_SIZE = 4
|
||||||
EPOCHS = 100
|
EPOCHS = 100
|
||||||
LEARNING_RATE = 0.000001
|
LEARNING_RATE = 0.01
|
||||||
TRAIN_SAMPLES = 100
|
TRAIN_SAMPLES = 100
|
||||||
MODEL_SAVE_FILE = "models/model.tf"
|
MODEL_SAVE_FILE = "models/model.tf"
|
||||||
MODEL_CHECKPOINT_DIR = "checkpoints"
|
MODEL_CHECKPOINT_DIR = "checkpoints"
|
||||||
EARLY_STOP = 10
|
EARLY_STOP = 10
|
||||||
WIDTH = 638
|
|
||||||
HEIGHT = 360
|
|
||||||
|
|
||||||
def load_video_metadata(list_path):
|
def load_video_metadata(list_path):
|
||||||
LOGGER.trace(f"Entering: load_video_metadata({list_path})")
|
LOGGER.trace(f"Entering: load_video_metadata({list_path})")
|
||||||
|
@ -40,92 +38,30 @@ def load_video_metadata(list_path):
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def load_video_samples(list_path, samples=TRAIN_SAMPLES):
|
def load_video_samples(list_path, samples=TRAIN_SAMPLES):
|
||||||
"""
|
LOGGER.trace(f"Entering: load_video_samples({list_path}, {samples})")
|
||||||
Load video samples from the metadata list.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
- list_path (str): Path to the metadata JSON file.
|
|
||||||
- samples (int): Number of total samples to be extracted.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- list: Extracted video samples.
|
|
||||||
"""
|
|
||||||
LOGGER.trace(f"Entering: load_video_samples({list_path}, {samples})" )
|
|
||||||
|
|
||||||
details_list = load_video_metadata(list_path)
|
details_list = load_video_metadata(list_path)
|
||||||
all_samples = []
|
all_samples = []
|
||||||
num_videos = len(details_list)
|
num_videos = len(details_list)
|
||||||
frames_per_video = int(samples / num_videos)
|
frames_per_video = math.ceil(samples / num_videos)
|
||||||
|
|
||||||
LOGGER.info(f"Loading {frames_per_video} frames from {num_videos} videos")
|
LOGGER.info(f"Loading {frames_per_video} frames from {num_videos} videos")
|
||||||
|
|
||||||
for video_details in details_list:
|
for video_details in details_list:
|
||||||
video_file = video_details["video_file"]
|
compressed_video_file = video_details["compressed_video_file"]
|
||||||
uncompressed_video_file = video_details["uncompressed_video_file"]
|
original_video_file = video_details["original_video_file"]
|
||||||
crf = video_details['crf'] / 63.0
|
crf = video_details['crf'] / 51
|
||||||
preset_speed = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed'])
|
preset_speed = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed'])
|
||||||
video_details['preset_speed'] = preset_speed
|
video_details['preset_speed'] = preset_speed
|
||||||
|
|
||||||
compressed_frames, uncompressed_frames = [], []
|
# Store video details without loading frames
|
||||||
|
all_samples.extend({
|
||||||
try:
|
"frames_per_video": frames_per_video,
|
||||||
cap = cv2.VideoCapture(os.path.join(os.path.dirname(list_path), video_file))
|
"crf": crf,
|
||||||
cap_uncompressed = cv2.VideoCapture(os.path.join(os.path.dirname(list_path), uncompressed_video_file))
|
"preset_speed": preset_speed,
|
||||||
|
"compressed_video_file": os.path.join(os.path.dirname(list_path), compressed_video_file),
|
||||||
if not cap.isOpened() or not cap_uncompressed.isOpened():
|
"original_video_file": os.path.join(os.path.dirname(list_path), original_video_file)
|
||||||
raise RuntimeError(f"Could not open video files {video_file} or {uncompressed_video_file}, searched under: {os.path.dirname(list_path)}")
|
} for _ in range(frames_per_video))
|
||||||
|
|
||||||
for _ in range(frames_per_video):
|
|
||||||
ret, frame_compressed = cap.read()
|
|
||||||
ret_uncompressed, frame = cap_uncompressed.read()
|
|
||||||
|
|
||||||
if not ret or not ret_uncompressed:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check frame dimensions and resize if necessary
|
|
||||||
if frame.shape[:2] != (WIDTH, HEIGHT):
|
|
||||||
LOGGER.warn(f"Resizing video: {video_file}")
|
|
||||||
frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA)
|
|
||||||
if frame_compressed.shape[:2] != (WIDTH, HEIGHT):
|
|
||||||
LOGGER.warn(f"Resizing video: {uncompressed_video_file}")
|
|
||||||
frame_compressed = cv2.resize(frame_compressed, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA)
|
|
||||||
|
|
||||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
||||||
frame_compressed = cv2.cvtColor(frame_compressed, cv2.COLOR_BGR2RGB)
|
|
||||||
|
|
||||||
uncompressed_frames.append(normalize(frame))
|
|
||||||
compressed_frames.append(normalize(frame_compressed))
|
|
||||||
|
|
||||||
all_samples.extend({
|
|
||||||
"frame": frame,
|
|
||||||
"compressed_frame": frame_compressed,
|
|
||||||
"crf": crf,
|
|
||||||
"preset_speed": preset_speed,
|
|
||||||
"video_file": video_file
|
|
||||||
} for frame, frame_compressed in zip(uncompressed_frames, compressed_frames))
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
LOGGER.error(f"Error during video sample loading: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
finally:
|
|
||||||
cap.release()
|
|
||||||
cap_uncompressed.release()
|
|
||||||
|
|
||||||
return all_samples
|
return all_samples
|
||||||
|
|
||||||
def normalize(frame):
|
|
||||||
"""
|
|
||||||
Normalize pixel values of the frame to range [0, 1].
|
|
||||||
|
|
||||||
Args:
|
|
||||||
- frame (ndarray): Image frame.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- ndarray: Normalized frame.
|
|
||||||
"""
|
|
||||||
LOGGER.trace(f"Normalizing frame")
|
|
||||||
return frame / 255.0
|
|
||||||
|
|
||||||
def save_model(model):
|
def save_model(model):
|
||||||
try:
|
try:
|
||||||
|
@ -138,6 +74,7 @@ def save_model(model):
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
global BATCH_SIZE, EPOCHS, TRAIN_SAMPLES, LEARNING_RATE, MODEL_SAVE_FILE
|
||||||
# Argument parsing
|
# Argument parsing
|
||||||
parser = argparse.ArgumentParser(description="Train the video compression model.")
|
parser = argparse.ArgumentParser(description="Train the video compression model.")
|
||||||
parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
|
parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
|
||||||
|
@ -147,23 +84,30 @@ def main():
|
||||||
parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
|
parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
BATCH_SIZE = args.batch_size
|
||||||
|
EPOCHS = args.epochs
|
||||||
|
TRAIN_SAMPLES = args.training_samples
|
||||||
|
LEARNING_RATE = args.learning_rate
|
||||||
|
|
||||||
# Display training configuration
|
# Display training configuration
|
||||||
LOGGER.info("Starting the training with the given configuration.")
|
LOGGER.info("Starting the training with the given configuration.")
|
||||||
LOGGER.info("Training configuration:")
|
LOGGER.info("Training configuration:")
|
||||||
LOGGER.info(f"Batch size: {args.batch_size}")
|
LOGGER.info(f"Batch size: {BATCH_SIZE}")
|
||||||
LOGGER.info(f"Epochs: {args.epochs}")
|
LOGGER.info(f"Epochs: {EPOCHS}")
|
||||||
LOGGER.info(f"Training samples: {args.training_samples}")
|
LOGGER.info(f"Training samples: {TRAIN_SAMPLES}")
|
||||||
LOGGER.info(f"Learning rate: {args.learning_rate}")
|
LOGGER.info(f"Learning rate: {LEARNING_RATE}")
|
||||||
LOGGER.info(f"Continue training from: {args.continue_training}")
|
LOGGER.info(f"Continue training from: {MODEL_SAVE_FILE}")
|
||||||
|
|
||||||
|
LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")
|
||||||
|
|
||||||
# Load training and validation samples
|
# Load training and validation samples
|
||||||
LOGGER.debug("Loading training and validation samples.")
|
LOGGER.debug("Loading training and validation samples.")
|
||||||
training_samples = load_video_samples("test_data/training/training.json")
|
training_samples = load_video_samples("test_data/training/training.json", TRAIN_SAMPLES)
|
||||||
validation_samples = load_video_samples("test_data/validation/validation.json", args.training_samples // 2)
|
validation_samples = load_video_samples("test_data/validation/validation.json", math.ceil(TRAIN_SAMPLES / 10))
|
||||||
|
|
||||||
train_generator = VideoDataGenerator(training_samples, args.batch_size)
|
train_generator = VideoDataGenerator(training_samples, BATCH_SIZE)
|
||||||
val_generator = VideoDataGenerator(validation_samples, args.batch_size)
|
val_generator = VideoDataGenerator(validation_samples, BATCH_SIZE)
|
||||||
|
|
||||||
# Load or initialize model
|
# Load or initialize model
|
||||||
if args.continue_training:
|
if args.continue_training:
|
||||||
|
@ -172,7 +116,7 @@ def main():
|
||||||
model = VideoCompressionModel()
|
model = VideoCompressionModel()
|
||||||
|
|
||||||
# Set optimizer and compile the model
|
# Set optimizer and compile the model
|
||||||
optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
|
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
|
||||||
model.compile(loss='mean_squared_error', optimizer=optimizer)
|
model.compile(loss='mean_squared_error', optimizer=optimizer)
|
||||||
|
|
||||||
# Define checkpoints and early stopping
|
# Define checkpoints and early stopping
|
||||||
|
@ -190,7 +134,7 @@ def main():
|
||||||
model.fit(
|
model.fit(
|
||||||
train_generator,
|
train_generator,
|
||||||
steps_per_epoch=len(train_generator),
|
steps_per_epoch=len(train_generator),
|
||||||
epochs=args.epochs,
|
epochs=EPOCHS,
|
||||||
validation_data=val_generator,
|
validation_data=val_generator,
|
||||||
validation_steps=len(val_generator),
|
validation_steps=len(val_generator),
|
||||||
callbacks=[early_stop, checkpoint_callback]
|
callbacks=[early_stop, checkpoint_callback]
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# video_compression_model.py
|
# video_compression_model.py
|
||||||
|
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
@ -8,6 +9,28 @@ from global_train import LOGGER
|
||||||
PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
|
PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
|
||||||
NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
|
NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
|
||||||
NUM_CHANNELS = 3
|
NUM_CHANNELS = 3
|
||||||
|
WIDTH = 638
|
||||||
|
HEIGHT = 360
|
||||||
|
|
||||||
|
#from tensorflow.keras.mixed_precision import Policy
|
||||||
|
|
||||||
|
#policy = Policy('mixed_float16')
|
||||||
|
#tf.keras.mixed_precision.set_global_policy(policy)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def normalize(frame):
|
||||||
|
"""
|
||||||
|
Normalize pixel values of the frame to range [0, 1].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- frame (ndarray): Image frame.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- ndarray: Normalized frame.
|
||||||
|
"""
|
||||||
|
LOGGER.trace(f"Normalizing frame")
|
||||||
|
return frame / 255.0
|
||||||
|
|
||||||
class VideoDataGenerator(tf.keras.utils.Sequence):
|
class VideoDataGenerator(tf.keras.utils.Sequence):
|
||||||
def __init__(self, video_details_list, batch_size):
|
def __init__(self, video_details_list, batch_size):
|
||||||
|
@ -19,28 +42,59 @@ class VideoDataGenerator(tf.keras.utils.Sequence):
|
||||||
return int(np.ceil(len(self.video_details_list) / float(self.batch_size)))
|
return int(np.ceil(len(self.video_details_list) / float(self.batch_size)))
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
try:
|
start_idx = idx * self.batch_size
|
||||||
start_idx = idx * self.batch_size
|
end_idx = (idx + 1) * self.batch_size
|
||||||
end_idx = (idx + 1) * self.batch_size
|
batch_data = self.video_details_list[start_idx:end_idx]
|
||||||
|
|
||||||
batch_data = self.video_details_list[start_idx:end_idx]
|
|
||||||
|
|
||||||
x1 = np.array([item["frame"] for item in batch_data])
|
# Determine the number of videos and frames per video
|
||||||
x2 = np.array([item["compressed_frame"] for item in batch_data])
|
num_videos = len(batch_data)
|
||||||
x3 = np.array([item["crf"] for item in batch_data])
|
frames_per_video = batch_data[0]['frames_per_video'] # Assuming all videos have the same number of frames
|
||||||
x4 = np.array([item["preset_speed"] for item in batch_data])
|
|
||||||
|
|
||||||
y = x2
|
# Pre-allocate arrays for the batch data
|
||||||
|
x1 = np.empty((num_videos * frames_per_video, HEIGHT, WIDTH, NUM_CHANNELS))
|
||||||
|
x2 = np.empty_like(x1)
|
||||||
|
x3 = np.empty((num_videos * frames_per_video, 1))
|
||||||
|
x4 = np.empty_like(x3)
|
||||||
|
|
||||||
|
# Iterate over the videos and frames, filling the pre-allocated arrays
|
||||||
|
for i, item in enumerate(batch_data):
|
||||||
|
compressed_video_file = item["compressed_video_file"]
|
||||||
|
original_video_file = item["original_video_file"]
|
||||||
|
crf = item["crf"]
|
||||||
|
preset_speed = item["preset_speed"]
|
||||||
|
|
||||||
|
cap_compressed = cv2.VideoCapture(compressed_video_file)
|
||||||
|
cap_original = cv2.VideoCapture(original_video_file)
|
||||||
|
for j in range(frames_per_video):
|
||||||
|
compressed_ret, compressed_frame = cap_compressed.read()
|
||||||
|
original_ret, original_frame = cap_original.read()
|
||||||
|
if not compressed_ret or not original_ret:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check frame dimensions and resize if necessary
|
||||||
|
if original_frame.shape[:2] != (WIDTH, HEIGHT):
|
||||||
|
LOGGER.info(f"Resizing video: {original_video_file}")
|
||||||
|
original_frame = cv2.resize(original_frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA)
|
||||||
|
if compressed_frame.shape[:2] != (WIDTH, HEIGHT):
|
||||||
|
LOGGER.info(f"Resizing video: {compressed_video_file}")
|
||||||
|
compressed_frame = cv2.resize(compressed_frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA)
|
||||||
|
|
||||||
|
original_frame = cv2.cvtColor(original_frame, cv2.COLOR_BGR2RGB)
|
||||||
|
compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
|
# Store the processed frames and metadata directly in the pre-allocated arrays
|
||||||
|
x1[i * frames_per_video + j] = normalize(original_frame)
|
||||||
|
x2[i * frames_per_video + j] = normalize(compressed_frame)
|
||||||
|
x3[i * frames_per_video + j] = crf
|
||||||
|
x4[i * frames_per_video + j] = preset_speed
|
||||||
|
|
||||||
|
cap_original.release()
|
||||||
|
cap_compressed.release()
|
||||||
|
|
||||||
|
y = x2
|
||||||
|
inputs = {"uncompressed_frame": x1, "compressed_frame": x2, "crf": x3, "preset_speed": x4}
|
||||||
|
return inputs, y
|
||||||
|
|
||||||
inputs = {"uncompressed_frame": x1, "compressed_frame": x2, "crf": x3, "preset_speed": x4}
|
|
||||||
return inputs, y
|
|
||||||
|
|
||||||
except IndexError:
|
|
||||||
LOGGER.error(f"Index {idx} out of bounds in VideoDataGenerator.")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
LOGGER.error(f"Unexpected error in VideoDataGenerator: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
class VideoCompressionModel(tf.keras.Model):
|
class VideoCompressionModel(tf.keras.Model):
|
||||||
|
@ -78,7 +132,43 @@ class VideoCompressionModel(tf.keras.Model):
|
||||||
tf.keras.layers.Dropout(0.3),
|
tf.keras.layers.Dropout(0.3),
|
||||||
tf.keras.layers.Conv2D(NUM_CHANNELS, (3, 3), activation='sigmoid', padding='same') # Output layer for video frames
|
tf.keras.layers.Conv2D(NUM_CHANNELS, (3, 3), activation='sigmoid', padding='same') # Output layer for video frames
|
||||||
])
|
])
|
||||||
|
|
||||||
|
def call(self, inputs):
|
||||||
|
LOGGER.trace("Calling VideoCompressionModel.")
|
||||||
|
uncompressed_frame, compressed_frame, crf, preset_speed = inputs['uncompressed_frame'], inputs['compressed_frame'], inputs['crf'], inputs['preset_speed']
|
||||||
|
|
||||||
|
# Convert frames to float32
|
||||||
|
uncompressed_frame = tf.cast(uncompressed_frame, tf.float16)
|
||||||
|
compressed_frame = tf.cast(compressed_frame, tf.float16)
|
||||||
|
|
||||||
|
# Embedding for preset speed
|
||||||
|
preset_speed_embedded = self.embedding(preset_speed)
|
||||||
|
preset_speed_embedded = tf.keras.layers.Flatten()(preset_speed_embedded)
|
||||||
|
|
||||||
|
# Reshaping CRF to match the shape of preset_speed_embedded
|
||||||
|
crf_expanded = tf.keras.layers.Flatten()(tf.repeat(crf, 16, axis=-1))
|
||||||
|
|
||||||
|
|
||||||
|
# Concatenating the CRF and preset speed information
|
||||||
|
integrated_info = tf.keras.layers.Concatenate(axis=-1)([crf_expanded, preset_speed_embedded])
|
||||||
|
integrated_info = self.fc(integrated_info)
|
||||||
|
|
||||||
|
# Integrate the CRF and preset speed information into the frames as additional channels (features)
|
||||||
|
_, height, width, _ = uncompressed_frame.shape
|
||||||
|
current_shape = tf.shape(inputs["uncompressed_frame"])
|
||||||
|
|
||||||
|
height = current_shape[1]
|
||||||
|
width = current_shape[2]
|
||||||
|
integrated_info_repeated = tf.tile(tf.reshape(integrated_info, [-1, 1, 1, 32]), [1, height, width, 1])
|
||||||
|
|
||||||
|
# Merge uncompressed and compressed frames
|
||||||
|
frames_merged = tf.keras.layers.Concatenate(axis=-1)([uncompressed_frame, compressed_frame, integrated_info_repeated])
|
||||||
|
|
||||||
|
compressed_representation = self.encoder(frames_merged)
|
||||||
|
reconstructed_frame = self.decoder(compressed_representation)
|
||||||
|
|
||||||
|
return reconstructed_frame
|
||||||
|
|
||||||
def model_summary(self):
|
def model_summary(self):
|
||||||
try:
|
try:
|
||||||
LOGGER.info("Generating model summary.")
|
LOGGER.info("Generating model summary.")
|
||||||
|
@ -90,34 +180,3 @@ class VideoCompressionModel(tf.keras.Model):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOGGER.error(f"Unexpected error during model summary generation: {e}")
|
LOGGER.error(f"Unexpected error during model summary generation: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def call(self, inputs):
|
|
||||||
LOGGER.trace("Calling VideoCompressionModel.")
|
|
||||||
uncompressed_frame, compressed_frame, crf, preset_speed = inputs['uncompressed_frame'], inputs['compressed_frame'], inputs['crf'], inputs['preset_speed']
|
|
||||||
|
|
||||||
# Convert frames to float32
|
|
||||||
uncompressed_frame = tf.cast(uncompressed_frame, tf.float32)
|
|
||||||
compressed_frame = tf.cast(compressed_frame, tf.float32)
|
|
||||||
|
|
||||||
# Integrate CRF and preset speed into the network
|
|
||||||
preset_speed_embedded = self.embedding(preset_speed)
|
|
||||||
crf_expanded = tf.expand_dims(crf, -1)
|
|
||||||
integrated_info = tf.keras.layers.Concatenate(axis=-1)([crf_expanded, tf.keras.layers.Flatten()(preset_speed_embedded)])
|
|
||||||
integrated_info = self.fc(integrated_info)
|
|
||||||
|
|
||||||
# Integrate the CRF and preset speed information into the frames as additional channels (features)
|
|
||||||
_, height, width, _ = uncompressed_frame.shape
|
|
||||||
current_shape = tf.shape(inputs["uncompressed_frame"])
|
|
||||||
|
|
||||||
height = current_shape[1]
|
|
||||||
width = current_shape[2]
|
|
||||||
integrated_info_repeated = tf.tile(tf.reshape(integrated_info, [-1, 1, 1, 32]), [1, height, width, 1])
|
|
||||||
|
|
||||||
|
|
||||||
# Merge uncompressed and compressed frames
|
|
||||||
frames_merged = tf.keras.layers.Concatenate(axis=-1)([uncompressed_frame, compressed_frame, integrated_info_repeated])
|
|
||||||
|
|
||||||
compressed_representation = self.encoder(frames_merged)
|
|
||||||
reconstructed_frame = self.decoder(compressed_representation)
|
|
||||||
|
|
||||||
return reconstructed_frame
|
|
||||||
|
|
Reference in a new issue