This commit is contained in:
Jordon Brooks 2023-07-30 13:43:53 +01:00
parent 5bca78e687
commit 9167ff27d4
2 changed files with 46 additions and 31 deletions

View file

@ -4,16 +4,16 @@ import numpy as np
import cv2 import cv2
import argparse import argparse
import tensorflow as tf import tensorflow as tf
from video_compression_model import NUM_CHANNELS, VideoCompressionModel, PRESET_SPEED_CATEGORIES from video_compression_model import NUM_CHANNELS, VideoCompressionModel, PRESET_SPEED_CATEGORIES, VideoDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
print("GPUs Detected:", tf.config.list_physical_devices('GPU')) print("GPUs Detected:", tf.config.list_physical_devices('GPU'))
# Constants # Constants
BATCH_SIZE = 16 BATCH_SIZE = 4
EPOCHS = 40 EPOCHS = 100
LEARNING_RATE = 0.00001 LEARNING_RATE = 0.000001
TRAIN_SAMPLES = 100 TRAIN_SAMPLES = 500
MODEL_SAVE_FILE = "models/model.tf" MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints" MODEL_CHECKPOINT_DIR = "checkpoints"
CONTINUE_TRAINING = None CONTINUE_TRAINING = None
@ -23,14 +23,14 @@ def load_list(list_path):
video_details_list = json.load(json_file) video_details_list = json.load(json_file)
return video_details_list return video_details_list
def load_video_from_list(list_path): def load_video_from_list(list_path, samples = TRAIN_SAMPLES):
details_list = load_list(list_path) details_list = load_list(list_path)
all_details = [] all_details = []
num_videos = len(details_list) num_videos = len(details_list)
frames_per_video = int(TRAIN_SAMPLES / num_videos) frames_per_video = int(samples / num_videos)
print(f"Loading {frames_per_video} across {num_videos} videos") print(f"Loading {frames_per_video} frames across {num_videos} videos")
for video_details in details_list: for video_details in details_list:
VIDEO_FILE = video_details["video_file"] VIDEO_FILE = video_details["video_file"]
@ -108,26 +108,10 @@ def main():
print(f"Continue training from: {CONTINUE_TRAINING}") print(f"Continue training from: {CONTINUE_TRAINING}")
all_video_details_train = load_video_from_list("test_data/training.json") all_video_details_train = load_video_from_list("test_data/training.json")
all_video_details_val = load_video_from_list("test_data/validation.json") all_video_details_val = load_video_from_list("test_data/validation.json", TRAIN_SAMPLES / 2)
all_train_frames = [video_details["frame"] for video_details in all_video_details_train] train_generator = VideoDataGenerator(all_video_details_train, BATCH_SIZE)
all_train_compressed_frames = [video_details["compressed_frame"] for video_details in all_video_details_train] val_generator = VideoDataGenerator(all_video_details_val, BATCH_SIZE)
all_val_frames = [video_details["frame"] for video_details in all_video_details_val]
all_val_compressed_frames = [video_details["compressed_frame"] for video_details in all_video_details_val]
all_crf_train = [video_details['crf'] for video_details in all_video_details_train]
all_crf_val = [video_details['crf'] for video_details in all_video_details_val]
all_preset_speed_train = [video_details['preset_speed'] for video_details in all_video_details_train]
all_preset_speed_val = [video_details['preset_speed'] for video_details in all_video_details_val]
# Convert lists to numpy arrays
all_train_frames = np.array(all_train_frames)
all_train_compressed_frames = np.array(all_train_compressed_frames)
all_val_frames = np.array(all_val_frames)
all_val_compressed_frames = np.array(all_val_compressed_frames)
all_crf_train = np.array(all_crf_train)
all_crf_val = np.array(all_crf_val)
all_preset_speed_train = np.array(all_preset_speed_train)
all_preset_speed_val = np.array(all_preset_speed_val)
if CONTINUE_TRAINING: if CONTINUE_TRAINING:
print("loading model:", CONTINUE_TRAINING) print("loading model:", CONTINUE_TRAINING)
@ -154,11 +138,11 @@ def main():
print("\nTraining the model...") print("\nTraining the model...")
model.fit( model.fit(
{"uncompressed_frame": all_train_frames, "compressed_frame": all_train_compressed_frames, "crf": all_crf_train, "preset_speed": all_preset_speed_train}, train_generator,
all_train_compressed_frames, # Target is the compressed frame steps_per_epoch=len(train_generator),
batch_size=BATCH_SIZE,
epochs=EPOCHS, epochs=EPOCHS,
validation_data=({"uncompressed_frame": all_val_frames, "compressed_frame": all_val_compressed_frames, "crf": all_crf_val, "preset_speed": all_preset_speed_val}, all_val_compressed_frames), validation_data=val_generator,
validation_steps=len(val_generator),
callbacks=[early_stop, checkpoint_callback] callbacks=[early_stop, checkpoint_callback]
) )
print("\nTraining completed!") print("\nTraining completed!")

View file

@ -1,11 +1,37 @@
# video_compression_model.py # video_compression_model.py
import numpy as np
import tensorflow as tf import tensorflow as tf
PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"] PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES) NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
NUM_CHANNELS = 3 NUM_CHANNELS = 3
class VideoDataGenerator(tf.keras.utils.Sequence):
    """Serve frame records to Keras in fixed-size batches.

    Every record in ``video_details_list`` is indexed with the keys
    ``"frame"``, ``"compressed_frame"``, ``"crf"`` and ``"preset_speed"``.
    Batches are assembled on demand in ``__getitem__`` rather than being
    converted to arrays up front.
    """

    def __init__(self, video_details_list, batch_size):
        """Store the records and the batch size.

        Args:
            video_details_list: Sliceable sequence of per-frame records
                (mappings holding the four keys listed on the class).
            batch_size: Number of records per batch; must be at least 1.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # Let the Keras base class set up its own state; newer Keras
        # releases warn when a subclass skips this call.
        super().__init__()
        if batch_size < 1:
            # Fail here instead of with a ZeroDivisionError in __len__.
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        self.video_details_list = video_details_list
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches, counting a final partial batch."""
        total = len(self.video_details_list)
        # Integer ceiling division: avoids the float round-trip.
        return (total + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """Build batch number ``idx``.

        Returns:
            A ``(inputs, target)`` tuple. ``inputs`` is a dict with the keys
            ``"uncompressed_frame"``, ``"compressed_frame"``, ``"crf"`` and
            ``"preset_speed"``; ``target`` is the compressed-frame array
            (the same object stored under ``inputs["compressed_frame"]``).
        """
        start_idx = idx * self.batch_size
        batch_data = self.video_details_list[start_idx:start_idx + self.batch_size]

        frames = np.array([item["frame"] for item in batch_data])
        compressed_frames = np.array([item["compressed_frame"] for item in batch_data])
        crf_values = np.array([item["crf"] for item in batch_data])
        preset_speeds = np.array([item["preset_speed"] for item in batch_data])

        inputs = {
            "uncompressed_frame": frames,
            "compressed_frame": compressed_frames,
            "crf": crf_values,
            "preset_speed": preset_speeds,
        }
        # The training target is the compressed frame itself.
        return inputs, compressed_frames
class VideoCompressionModel(tf.keras.Model): class VideoCompressionModel(tf.keras.Model):
def __init__(self): def __init__(self):
super(VideoCompressionModel, self).__init__() super(VideoCompressionModel, self).__init__()
@ -64,8 +90,13 @@ class VideoCompressionModel(tf.keras.Model):
# Integrate the CRF and preset speed information into the frames as additional channels (features) # Integrate the CRF and preset speed information into the frames as additional channels (features)
_, height, width, _ = uncompressed_frame.shape _, height, width, _ = uncompressed_frame.shape
current_shape = tf.shape(inputs["uncompressed_frame"])
height = current_shape[1]
width = current_shape[2]
integrated_info_repeated = tf.tile(tf.reshape(integrated_info, [-1, 1, 1, 32]), [1, height, width, 1]) integrated_info_repeated = tf.tile(tf.reshape(integrated_info, [-1, 1, 1, 32]), [1, height, width, 1])
# Merge uncompressed and compressed frames # Merge uncompressed and compressed frames
frames_merged = tf.keras.layers.Concatenate(axis=-1)([uncompressed_frame, compressed_frame, integrated_info_repeated]) frames_merged = tf.keras.layers.Concatenate(axis=-1)([uncompressed_frame, compressed_frame, integrated_info_repeated])