diff --git a/train_model.py b/train_model.py index 388d5c1..d130a47 100644 --- a/train_model.py +++ b/train_model.py @@ -4,16 +4,16 @@ import numpy as np import cv2 import argparse import tensorflow as tf -from video_compression_model import NUM_CHANNELS, VideoCompressionModel, PRESET_SPEED_CATEGORIES +from video_compression_model import NUM_CHANNELS, VideoCompressionModel, PRESET_SPEED_CATEGORIES, VideoDataGenerator from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint print("GPUs Detected:", tf.config.list_physical_devices('GPU')) # Constants -BATCH_SIZE = 16 -EPOCHS = 40 -LEARNING_RATE = 0.00001 -TRAIN_SAMPLES = 100 +BATCH_SIZE = 4 +EPOCHS = 100 +LEARNING_RATE = 0.000001 +TRAIN_SAMPLES = 500 MODEL_SAVE_FILE = "models/model.tf" MODEL_CHECKPOINT_DIR = "checkpoints" CONTINUE_TRAINING = None @@ -23,14 +23,14 @@ def load_list(list_path): video_details_list = json.load(json_file) return video_details_list -def load_video_from_list(list_path): +def load_video_from_list(list_path, samples = TRAIN_SAMPLES): details_list = load_list(list_path) all_details = [] num_videos = len(details_list) - frames_per_video = int(TRAIN_SAMPLES / num_videos) + frames_per_video = int(samples / num_videos) - print(f"Loading {frames_per_video} across {num_videos} videos") + print(f"Loading {frames_per_video} frames across {num_videos} videos") for video_details in details_list: VIDEO_FILE = video_details["video_file"] @@ -108,26 +108,10 @@ def main(): print(f"Continue training from: {CONTINUE_TRAINING}") all_video_details_train = load_video_from_list("test_data/training.json") - all_video_details_val = load_video_from_list("test_data/validation.json") + all_video_details_val = load_video_from_list("test_data/validation.json", TRAIN_SAMPLES / 2) - all_train_frames = [video_details["frame"] for video_details in all_video_details_train] - all_train_compressed_frames = [video_details["compressed_frame"] for video_details in all_video_details_train] - all_val_frames = [video_details["frame"] for video_details in all_video_details_val] - all_val_compressed_frames = [video_details["compressed_frame"] for video_details in all_video_details_val] - all_crf_train = [video_details['crf'] for video_details in all_video_details_train] - all_crf_val = [video_details['crf'] for video_details in all_video_details_val] - all_preset_speed_train = [video_details['preset_speed'] for video_details in all_video_details_train] - all_preset_speed_val = [video_details['preset_speed'] for video_details in all_video_details_val] - - # Convert lists to numpy arrays - all_train_frames = np.array(all_train_frames) - all_train_compressed_frames = np.array(all_train_compressed_frames) - all_val_frames = np.array(all_val_frames) - all_val_compressed_frames = np.array(all_val_compressed_frames) - all_crf_train = np.array(all_crf_train) - all_crf_val = np.array(all_crf_val) - all_preset_speed_train = np.array(all_preset_speed_train) - all_preset_speed_val = np.array(all_preset_speed_val) + train_generator = VideoDataGenerator(all_video_details_train, BATCH_SIZE) + val_generator = VideoDataGenerator(all_video_details_val, BATCH_SIZE) if CONTINUE_TRAINING: print("loading model:", CONTINUE_TRAINING) @@ -154,11 +138,11 @@ def main(): print("\nTraining the model...") model.fit( - {"uncompressed_frame": all_train_frames, "compressed_frame": all_train_compressed_frames, "crf": all_crf_train, "preset_speed": all_preset_speed_train}, - all_train_compressed_frames, # Target is the compressed frame - batch_size=BATCH_SIZE, + train_generator, + steps_per_epoch=len(train_generator), epochs=EPOCHS, - validation_data=({"uncompressed_frame": all_val_frames, "compressed_frame": all_val_compressed_frames, "crf": all_crf_val, "preset_speed": all_preset_speed_val}, all_val_compressed_frames), + validation_data=val_generator, + validation_steps=len(val_generator), callbacks=[early_stop, checkpoint_callback] ) print("\nTraining completed!") diff --git a/video_compression_model.py b/video_compression_model.py index 795be62..fa22a4c 100644 --- a/video_compression_model.py +++ b/video_compression_model.py @@ -1,11 +1,37 @@ # video_compression_model.py +import numpy as np import tensorflow as tf PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"] NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES) NUM_CHANNELS = 3 +class VideoDataGenerator(tf.keras.utils.Sequence): + def __init__(self, video_details_list, batch_size): + self.video_details_list = video_details_list + self.batch_size = batch_size + + def __len__(self): + return int(np.ceil(len(self.video_details_list) / float(self.batch_size))) + + def __getitem__(self, idx): + start_idx = idx * self.batch_size + end_idx = (idx + 1) * self.batch_size + + batch_data = self.video_details_list[start_idx:end_idx] + + x1 = np.array([item["frame"] for item in batch_data]) + x2 = np.array([item["compressed_frame"] for item in batch_data]) + x3 = np.array([item["crf"] for item in batch_data]) + x4 = np.array([item["preset_speed"] for item in batch_data]) + + y = x2 + + inputs = {"uncompressed_frame": x1, "compressed_frame": x2, "crf": x3, "preset_speed": x4} + return inputs, y + + class VideoCompressionModel(tf.keras.Model): def __init__(self): super(VideoCompressionModel, self).__init__() @@ -64,8 +90,13 @@ class VideoCompressionModel(tf.keras.Model): # Integrate the CRF and preset speed information into the frames as additional channels (features) _, height, width, _ = uncompressed_frame.shape + current_shape = tf.shape(inputs["uncompressed_frame"]) + + height = current_shape[1] + width = current_shape[2] integrated_info_repeated = tf.tile(tf.reshape(integrated_info, [-1, 1, 1, 32]), [1, height, width, 1]) + # Merge uncompressed and compressed frames frames_merged = tf.keras.layers.Concatenate(axis=-1)([uncompressed_frame, compressed_frame, integrated_info_repeated])