diff --git a/DeepEncode.py b/DeepEncode.py
index 3d77d11..d63ed00 100644
--- a/DeepEncode.py
+++ b/DeepEncode.py
@@ -1,5 +1,9 @@
 # DeepEncode.py
 
+import os
+
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
+
 import tensorflow as tf
 import numpy as np
 import cv2
@@ -33,12 +37,16 @@ def predict_frame(uncompressed_frame, model, crf_value, preset_speed_value):
     crf_array = np.array([crf_value])
     preset_speed_array = np.array([preset_speed_value])
 
+    crf_array = np.expand_dims(np.array([crf_value]), axis=-1)  # Shape: (1, 1)
+    preset_speed_array = np.expand_dims(np.array([preset_speed_value]), axis=-1)  # Shape: (1, 1)
+
+    # Expand dimensions to include batch size
     uncompressed_frame = np.expand_dims(uncompressed_frame, 0)
 
     #display_frame = np.clip(cv2.cvtColor(uncompressed_frame[0], cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
     #cv2.imshow("uncomp", display_frame)
-    #cv2.waitKey(10)
+    #cv2.waitKey(0)
 
     compressed_frame = model.predict({
         "compressed_frame": uncompressed_frame,
diff --git a/global_train.py b/global_train.py
index 5dc64c4..b384df3 100644
--- a/global_train.py
+++ b/global_train.py
@@ -1,3 +1,3 @@
 import log
 
-LOGGER = log.Logger(level="INFO", logfile="training.log", reset_logfile=True)
\ No newline at end of file
+LOGGER = log.Logger(level="DEBUG", logfile="training.log", reset_logfile=True)
\ No newline at end of file
diff --git a/test_data/training/training.json b/test_data/training/training.json
index 90a5622..1453835 100644
--- a/test_data/training/training.json
+++ b/test_data/training/training.json
@@ -1,73 +1,73 @@
 [
     {
-        "video_file": "x264_crf-51_preset-ultrafast.mkv",
-        "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
+        "compressed_video_file": "x264_crf-51_preset-ultrafast.mkv",
+        "original_video_file": "../x264_crf-5_preset-veryslow.mkv",
         "crf": 51,
         "preset_speed": "ultrafast"
     },
     {
-        "video_file": "x264_crf-16_preset-veryslow.mkv",
-        "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
+        "compressed_video_file": "x264_crf-16_preset-veryslow.mkv",
+        "original_video_file": "../x264_crf-5_preset-veryslow.mkv",
         "crf": 16,
         "preset_speed": "veryslow"
     },
     {
-        "video_file": "x264_crf-18_preset-ultrafast.mkv",
-        "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
+        "compressed_video_file": "x264_crf-18_preset-ultrafast.mkv",
+        "original_video_file": "../x264_crf-5_preset-veryslow.mkv",
         "crf": 18,
         "preset_speed": "ultrafast"
     },
     {
-        "video_file": "x264_crf-18_preset-veryslow.mkv",
-        "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
+        "compressed_video_file": "x264_crf-18_preset-veryslow.mkv",
+        "original_video_file": "../x264_crf-5_preset-veryslow.mkv",
         "crf": 18,
         "preset_speed": "veryslow"
     },
     {
-        "video_file": "x264_crf-50_preset-veryslow.mkv",
-        "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
+        "compressed_video_file": "x264_crf-50_preset-veryslow.mkv",
+        "original_video_file": "../x264_crf-5_preset-veryslow.mkv",
         "crf": 50,
         "preset_speed": "veryslow"
     },
     {
-        "video_file": "x264_crf-51_preset-fast.mkv",
-        "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
+        "compressed_video_file": "x264_crf-51_preset-fast.mkv",
+        "original_video_file": "../x264_crf-5_preset-veryslow.mkv",
         "crf": 51,
         "preset_speed": "fast"
     },
     {
-        "video_file": "x264_crf-51_preset-faster.mkv",
-        "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
+        "compressed_video_file": "x264_crf-51_preset-faster.mkv",
+        "original_video_file": "../x264_crf-5_preset-veryslow.mkv",
        "crf": 51,
         "preset_speed": "faster"
     },
     {
-        "video_file": "x264_crf-51_preset-medium.mkv",
-        "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
+        "compressed_video_file": "x264_crf-51_preset-medium.mkv",
+        "original_video_file": "../x264_crf-5_preset-veryslow.mkv",
         "crf": 51,
         "preset_speed": "medium"
     },
     {
-        "video_file": "x264_crf-51_preset-slow.mkv",
-        "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
+        "compressed_video_file": "x264_crf-51_preset-slow.mkv",
+        "original_video_file": "../x264_crf-5_preset-veryslow.mkv",
         "crf": 51,
         "preset_speed": "slow"
     },
     {
-        "video_file": "x264_crf-51_preset-slower.mkv",
-        "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
+        "compressed_video_file": "x264_crf-51_preset-slower.mkv",
+        "original_video_file": "../x264_crf-5_preset-veryslow.mkv",
         "crf": 51,
         "preset_speed": "slower"
     },
     {
-        "video_file": "x264_crf-51_preset-superfast.mkv",
-        "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
+        "compressed_video_file": "x264_crf-51_preset-superfast.mkv",
+        "original_video_file": "../x264_crf-5_preset-veryslow.mkv",
         "crf": 51,
         "preset_speed": "superfast"
     },
     {
-        "video_file": "x264_crf-51_preset-veryfast.mkv",
-        "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
+        "compressed_video_file": "x264_crf-51_preset-veryfast.mkv",
+        "original_video_file": "../x264_crf-5_preset-veryslow.mkv",
         "crf": 51,
         "preset_speed": "veryfast"
     }
diff --git a/test_data/validation/validation.json b/test_data/validation/validation.json
index 7f938f2..55a58e7 100644
--- a/test_data/validation/validation.json
+++ b/test_data/validation/validation.json
@@ -1,8 +1,7 @@
 [
-    {
-        "video_file": "Scene2_x264_crf-51_preset-veryslow.mkv",
-        "uncompressed_video_file": "Scene2_x264_crf-5_preset-veryslow.mkv",
+        "compressed_video_file": "Scene2_x264_crf-51_preset-veryslow.mkv",
+        "original_video_file": "Scene2_x264_crf-5_preset-veryslow.mkv",
         "crf": 51,
         "preset_speed": "veryslow"
     }
 ]
diff --git a/train_model.py b/train_model.py
index 6a92fd7..9fe8e11 100644
--- a/train_model.py
+++ b/train_model.py
@@ -1,15 +1,14 @@
 # train_model.py
+import math
 import os
 
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
 
 import json
-import numpy as np
-import cv2
 import argparse
 
 import tensorflow as tf
 
-from video_compression_model import NUM_CHANNELS, VideoCompressionModel, PRESET_SPEED_CATEGORIES, VideoDataGenerator
+from video_compression_model import WIDTH, HEIGHT, VideoCompressionModel, PRESET_SPEED_CATEGORIES, VideoDataGenerator
 from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
 from global_train import LOGGER
@@ -17,13 +16,12 @@ from global_train import LOGGER
 # Constants
 BATCH_SIZE = 4
 EPOCHS = 100
-LEARNING_RATE = 0.000001
+LEARNING_RATE = 0.01
 TRAIN_SAMPLES = 100
 MODEL_SAVE_FILE = "models/model.tf"
 MODEL_CHECKPOINT_DIR = "checkpoints"
 EARLY_STOP = 10
-WIDTH = 638
-HEIGHT = 360
+
 
 def load_video_metadata(list_path):
     LOGGER.trace(f"Entering: load_video_metadata({list_path})")
@@ -40,92 +38,30 @@ def load_video_metadata(list_path):
         raise
 
 def load_video_samples(list_path, samples=TRAIN_SAMPLES):
-    """
-    Load video samples from the metadata list.
-
-    Args:
-    - list_path (str): Path to the metadata JSON file.
-    - samples (int): Number of total samples to be extracted.
-
-    Returns:
-    - list: Extracted video samples.
- """ - LOGGER.trace(f"Entering: load_video_samples({list_path}, {samples})" ) - + LOGGER.trace(f"Entering: load_video_samples({list_path}, {samples})") details_list = load_video_metadata(list_path) all_samples = [] num_videos = len(details_list) - frames_per_video = int(samples / num_videos) - + frames_per_video = math.ceil(samples / num_videos) LOGGER.info(f"Loading {frames_per_video} frames from {num_videos} videos") - for video_details in details_list: - video_file = video_details["video_file"] - uncompressed_video_file = video_details["uncompressed_video_file"] - crf = video_details['crf'] / 63.0 + compressed_video_file = video_details["compressed_video_file"] + original_video_file = video_details["original_video_file"] + crf = video_details['crf'] / 51 preset_speed = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed']) video_details['preset_speed'] = preset_speed - compressed_frames, uncompressed_frames = [], [] - - try: - cap = cv2.VideoCapture(os.path.join(os.path.dirname(list_path), video_file)) - cap_uncompressed = cv2.VideoCapture(os.path.join(os.path.dirname(list_path), uncompressed_video_file)) - - if not cap.isOpened() or not cap_uncompressed.isOpened(): - raise RuntimeError(f"Could not open video files {video_file} or {uncompressed_video_file}, searched under: {os.path.dirname(list_path)}") - - for _ in range(frames_per_video): - ret, frame_compressed = cap.read() - ret_uncompressed, frame = cap_uncompressed.read() - - if not ret or not ret_uncompressed: - continue - - # Check frame dimensions and resize if necessary - if frame.shape[:2] != (WIDTH, HEIGHT): - LOGGER.warn(f"Resizing video: {video_file}") - frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA) - if frame_compressed.shape[:2] != (WIDTH, HEIGHT): - LOGGER.warn(f"Resizing video: {uncompressed_video_file}") - frame_compressed = cv2.resize(frame_compressed, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA) - - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - frame_compressed = cv2.cvtColor(frame_compressed, cv2.COLOR_BGR2RGB) - - uncompressed_frames.append(normalize(frame)) - compressed_frames.append(normalize(frame_compressed)) - - all_samples.extend({ - "frame": frame, - "compressed_frame": frame_compressed, - "crf": crf, - "preset_speed": preset_speed, - "video_file": video_file - } for frame, frame_compressed in zip(uncompressed_frames, compressed_frames)) - - except Exception as e: - LOGGER.error(f"Error during video sample loading: {e}") - raise - - finally: - cap.release() - cap_uncompressed.release() + # Store video details without loading frames + all_samples.extend({ + "frames_per_video": frames_per_video, + "crf": crf, + "preset_speed": preset_speed, + "compressed_video_file": os.path.join(os.path.dirname(list_path), compressed_video_file), + "original_video_file": os.path.join(os.path.dirname(list_path), original_video_file) + } for _ in range(frames_per_video)) return all_samples -def normalize(frame): - """ - Normalize pixel values of the frame to range [0, 1]. - - Args: - - frame (ndarray): Image frame. - - Returns: - - ndarray: Normalized frame. 
- """ - LOGGER.trace(f"Normalizing frame") - return frame / 255.0 def save_model(model): try: @@ -138,6 +74,7 @@ def save_model(model): raise def main(): + global BATCH_SIZE, EPOCHS, TRAIN_SAMPLES, LEARNING_RATE, MODEL_SAVE_FILE # Argument parsing parser = argparse.ArgumentParser(description="Train the video compression model.") parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.') @@ -147,23 +84,30 @@ def main(): parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.') args = parser.parse_args() + + BATCH_SIZE = args.batch_size + EPOCHS = args.epochs + TRAIN_SAMPLES = args.training_samples + LEARNING_RATE = args.learning_rate # Display training configuration LOGGER.info("Starting the training with the given configuration.") LOGGER.info("Training configuration:") - LOGGER.info(f"Batch size: {args.batch_size}") - LOGGER.info(f"Epochs: {args.epochs}") - LOGGER.info(f"Training samples: {args.training_samples}") - LOGGER.info(f"Learning rate: {args.learning_rate}") - LOGGER.info(f"Continue training from: {args.continue_training}") + LOGGER.info(f"Batch size: {BATCH_SIZE}") + LOGGER.info(f"Epochs: {EPOCHS}") + LOGGER.info(f"Training samples: {TRAIN_SAMPLES}") + LOGGER.info(f"Learning rate: {LEARNING_RATE}") + LOGGER.info(f"Continue training from: {MODEL_SAVE_FILE}") + + LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}") # Load training and validation samples LOGGER.debug("Loading training and validation samples.") - training_samples = load_video_samples("test_data/training/training.json") - validation_samples = load_video_samples("test_data/validation/validation.json", args.training_samples // 2) + training_samples = load_video_samples("test_data/training/training.json", TRAIN_SAMPLES) + validation_samples = load_video_samples("test_data/validation/validation.json", math.ceil(TRAIN_SAMPLES / 10)) - train_generator = VideoDataGenerator(training_samples, args.batch_size) - val_generator = VideoDataGenerator(validation_samples, args.batch_size) + train_generator = VideoDataGenerator(training_samples, BATCH_SIZE) + val_generator = VideoDataGenerator(validation_samples, BATCH_SIZE) # Load or initialize model if args.continue_training: @@ -172,7 +116,7 @@ def main(): model = VideoCompressionModel() # Set optimizer and compile the model - optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate) + optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE) model.compile(loss='mean_squared_error', optimizer=optimizer) # Define checkpoints and early stopping @@ -190,7 +134,7 @@ def main(): model.fit( train_generator, steps_per_epoch=len(train_generator), - epochs=args.epochs, + epochs=EPOCHS, validation_data=val_generator, validation_steps=len(val_generator), callbacks=[early_stop, checkpoint_callback] diff --git a/video_compression_model.py b/video_compression_model.py index bb4a0b9..d54cbb9 100644 --- a/video_compression_model.py +++ b/video_compression_model.py @@ -1,5 +1,6 @@ # video_compression_model.py +import cv2 import numpy as np import tensorflow as tf @@ -8,6 +9,28 @@ from global_train import LOGGER PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"] NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES) NUM_CHANNELS = 3 +WIDTH = 638 +HEIGHT = 360 + +#from tensorflow.keras.mixed_precision import 
Policy + +#policy = Policy('mixed_float16') +#tf.keras.mixed_precision.set_global_policy(policy) + + + +def normalize(frame): + """ + Normalize pixel values of the frame to range [0, 1]. + + Args: + - frame (ndarray): Image frame. + + Returns: + - ndarray: Normalized frame. + """ + LOGGER.trace(f"Normalizing frame") + return frame / 255.0 class VideoDataGenerator(tf.keras.utils.Sequence): def __init__(self, video_details_list, batch_size): @@ -19,28 +42,59 @@ class VideoDataGenerator(tf.keras.utils.Sequence): return int(np.ceil(len(self.video_details_list) / float(self.batch_size))) def __getitem__(self, idx): - try: - start_idx = idx * self.batch_size - end_idx = (idx + 1) * self.batch_size - - batch_data = self.video_details_list[start_idx:end_idx] + start_idx = idx * self.batch_size + end_idx = (idx + 1) * self.batch_size + batch_data = self.video_details_list[start_idx:end_idx] - x1 = np.array([item["frame"] for item in batch_data]) - x2 = np.array([item["compressed_frame"] for item in batch_data]) - x3 = np.array([item["crf"] for item in batch_data]) - x4 = np.array([item["preset_speed"] for item in batch_data]) + # Determine the number of videos and frames per video + num_videos = len(batch_data) + frames_per_video = batch_data[0]['frames_per_video'] # Assuming all videos have the same number of frames - y = x2 + # Pre-allocate arrays for the batch data + x1 = np.empty((num_videos * frames_per_video, HEIGHT, WIDTH, NUM_CHANNELS)) + x2 = np.empty_like(x1) + x3 = np.empty((num_videos * frames_per_video, 1)) + x4 = np.empty_like(x3) + + # Iterate over the videos and frames, filling the pre-allocated arrays + for i, item in enumerate(batch_data): + compressed_video_file = item["compressed_video_file"] + original_video_file = item["original_video_file"] + crf = item["crf"] + preset_speed = item["preset_speed"] + + cap_compressed = cv2.VideoCapture(compressed_video_file) + cap_original = cv2.VideoCapture(original_video_file) + for j in range(frames_per_video): + compressed_ret, compressed_frame = cap_compressed.read() + original_ret, original_frame = cap_original.read() + if not compressed_ret or not original_ret: + continue + + # Check frame dimensions and resize if necessary + if original_frame.shape[:2] != (WIDTH, HEIGHT): + LOGGER.info(f"Resizing video: {original_video_file}") + original_frame = cv2.resize(original_frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA) + if compressed_frame.shape[:2] != (WIDTH, HEIGHT): + LOGGER.info(f"Resizing video: {compressed_video_file}") + compressed_frame = cv2.resize(compressed_frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA) + + original_frame = cv2.cvtColor(original_frame, cv2.COLOR_BGR2RGB) + compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_BGR2RGB) + + # Store the processed frames and metadata directly in the pre-allocated arrays + x1[i * frames_per_video + j] = normalize(original_frame) + x2[i * frames_per_video + j] = normalize(compressed_frame) + x3[i * frames_per_video + j] = crf + x4[i * frames_per_video + j] = preset_speed + + cap_original.release() + cap_compressed.release() + + y = x2 + inputs = {"uncompressed_frame": x1, "compressed_frame": x2, "crf": x3, "preset_speed": x4} + return inputs, y - inputs = {"uncompressed_frame": x1, "compressed_frame": x2, "crf": x3, "preset_speed": x4} - return inputs, y - - except IndexError: - LOGGER.error(f"Index {idx} out of bounds in VideoDataGenerator.") - raise - except Exception as e: - LOGGER.error(f"Unexpected error in VideoDataGenerator: {e}") - raise class 
VideoCompressionModel(tf.keras.Model): @@ -78,7 +132,43 @@ class VideoCompressionModel(tf.keras.Model): tf.keras.layers.Dropout(0.3), tf.keras.layers.Conv2D(NUM_CHANNELS, (3, 3), activation='sigmoid', padding='same') # Output layer for video frames ]) + + def call(self, inputs): + LOGGER.trace("Calling VideoCompressionModel.") + uncompressed_frame, compressed_frame, crf, preset_speed = inputs['uncompressed_frame'], inputs['compressed_frame'], inputs['crf'], inputs['preset_speed'] + + # Convert frames to float32 + uncompressed_frame = tf.cast(uncompressed_frame, tf.float16) + compressed_frame = tf.cast(compressed_frame, tf.float16) + + # Embedding for preset speed + preset_speed_embedded = self.embedding(preset_speed) + preset_speed_embedded = tf.keras.layers.Flatten()(preset_speed_embedded) + + # Reshaping CRF to match the shape of preset_speed_embedded + crf_expanded = tf.keras.layers.Flatten()(tf.repeat(crf, 16, axis=-1)) + + + # Concatenating the CRF and preset speed information + integrated_info = tf.keras.layers.Concatenate(axis=-1)([crf_expanded, preset_speed_embedded]) + integrated_info = self.fc(integrated_info) + + # Integrate the CRF and preset speed information into the frames as additional channels (features) + _, height, width, _ = uncompressed_frame.shape + current_shape = tf.shape(inputs["uncompressed_frame"]) + height = current_shape[1] + width = current_shape[2] + integrated_info_repeated = tf.tile(tf.reshape(integrated_info, [-1, 1, 1, 32]), [1, height, width, 1]) + + # Merge uncompressed and compressed frames + frames_merged = tf.keras.layers.Concatenate(axis=-1)([uncompressed_frame, compressed_frame, integrated_info_repeated]) + + compressed_representation = self.encoder(frames_merged) + reconstructed_frame = self.decoder(compressed_representation) + + return reconstructed_frame + def model_summary(self): try: LOGGER.info("Generating model summary.") @@ -90,34 +180,3 @@ class VideoCompressionModel(tf.keras.Model): except Exception as e: LOGGER.error(f"Unexpected error during model summary generation: {e}") raise - - def call(self, inputs): - LOGGER.trace("Calling VideoCompressionModel.") - uncompressed_frame, compressed_frame, crf, preset_speed = inputs['uncompressed_frame'], inputs['compressed_frame'], inputs['crf'], inputs['preset_speed'] - - # Convert frames to float32 - uncompressed_frame = tf.cast(uncompressed_frame, tf.float32) - compressed_frame = tf.cast(compressed_frame, tf.float32) - - # Integrate CRF and preset speed into the network - preset_speed_embedded = self.embedding(preset_speed) - crf_expanded = tf.expand_dims(crf, -1) - integrated_info = tf.keras.layers.Concatenate(axis=-1)([crf_expanded, tf.keras.layers.Flatten()(preset_speed_embedded)]) - integrated_info = self.fc(integrated_info) - - # Integrate the CRF and preset speed information into the frames as additional channels (features) - _, height, width, _ = uncompressed_frame.shape - current_shape = tf.shape(inputs["uncompressed_frame"]) - - height = current_shape[1] - width = current_shape[2] - integrated_info_repeated = tf.tile(tf.reshape(integrated_info, [-1, 1, 1, 32]), [1, height, width, 1]) - - - # Merge uncompressed and compressed frames - frames_merged = tf.keras.layers.Concatenate(axis=-1)([uncompressed_frame, compressed_frame, integrated_info_repeated]) - - compressed_representation = self.encoder(frames_merged) - reconstructed_frame = self.decoder(compressed_representation) - - return reconstructed_frame
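
Not part of the patch: a minimal sketch of how the reworked lazy-loading VideoDataGenerator above could be smoke-tested on its own, assuming a checkout with the test_data layout referenced in training.json. The sample dict mirrors the entries that load_video_samples() now builds; the frame count and file paths here are illustrative examples, not fixed values from the change.

    # sketch_generator_smoke_test.py -- illustrative only, not part of the diff
    import os

    from video_compression_model import VideoDataGenerator, PRESET_SPEED_CATEGORIES

    # One entry shaped like the dicts load_video_samples() now emits
    # (paths follow the training.json layout; adjust to your checkout).
    sample = {
        "frames_per_video": 4,
        "crf": 51 / 51,
        "preset_speed": PRESET_SPEED_CATEGORIES.index("ultrafast"),
        "compressed_video_file": os.path.join("test_data", "training", "x264_crf-51_preset-ultrafast.mkv"),
        "original_video_file": os.path.join("test_data", "x264_crf-5_preset-veryslow.mkv"),
    }

    # Four copies of the same entry, batched two at a time: each __getitem__
    # call decodes frames_per_video frames per video on the fly.
    gen = VideoDataGenerator([sample] * 4, batch_size=2)
    inputs, y = gen[0]
    print({name: arr.shape for name, arr in inputs.items()}, y.shape)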