diff --git a/test_data/validation/validation.json b/test_data/validation/validation.json
index 55a58e7..0ae23ea 100644
--- a/test_data/validation/validation.json
+++ b/test_data/validation/validation.json
@@ -4,5 +4,23 @@
         "original_video_file": "Scene2_x264_crf-5_preset-veryslow.mkv",
         "crf": 51,
         "preset_speed": "veryslow"
+    },
+    {
+        "compressed_video_file": "Scene3_x264_crf-51_preset-ultrafast.mkv",
+        "original_video_file": "Scene3.mkv",
+        "crf": 51,
+        "preset_speed": "ultrafast"
+    },
+    {
+        "compressed_video_file": "Scene4_x264_crf-51_preset-veryslow.mkv",
+        "original_video_file": "Scene4.mkv",
+        "crf": 51,
+        "preset_speed": "veryslow"
+    },
+    {
+        "compressed_video_file": "Scene5_x264_crf-51_preset-veryslow.mkv",
+        "original_video_file": "Scene5.mkv",
+        "crf": 51,
+        "preset_speed": "veryslow"
     }
 ]
diff --git a/train_model.py b/train_model.py
index 9fe8e11..2b202f0 100644
--- a/train_model.py
+++ b/train_model.py
@@ -1,29 +1,82 @@
-# train_model.py
-
-import math
-import os
-
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
-
 import json
-import argparse
+import os
+import cv2
+import numpy as np
+
+from video_compression_model import VideoCompressionModel
 
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
 import tensorflow as tf
-from video_compression_model import WIDTH, HEIGHT, VideoCompressionModel, PRESET_SPEED_CATEGORIES, VideoDataGenerator
 from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
 
 from global_train import LOGGER
 
 # Constants
-BATCH_SIZE = 4
-EPOCHS = 100
+BATCH_SIZE = 16
+EPOCHS = 5
 LEARNING_RATE = 0.01
-TRAIN_SAMPLES = 100
 MODEL_SAVE_FILE = "models/model.tf"
 MODEL_CHECKPOINT_DIR = "checkpoints"
 EARLY_STOP = 10
+NUM_CHANNELS = 3
+WIDTH = 640
+HEIGHT = 360
+
+def save_model(model):
+    try:
+        LOGGER.debug("Attempting to save the model.")
+        os.makedirs("models", exist_ok=True)
+        model.save(MODEL_SAVE_FILE, save_format='tf')
+        LOGGER.info("Model saved successfully!")
+    except Exception as e:
+        LOGGER.error(f"Error saving the model: {e}")
+        raise
+
+def extract_edge_features(frame):
+    """
+    Extract edge features using Canny edge detection.
+
+    Args:
+    - frame (ndarray): Image frame.
+
+    Returns:
+    - ndarray: Edge feature map.
+    """
+    edges = cv2.Canny(frame, threshold1=100, threshold2=200)
+    return edges.astype(np.float32) / 255.0
+
+def extract_histogram_features(frame, bins=64):
+    """
+    Extract histogram features from a frame.
+
+    Args:
+    - frame (ndarray): Image frame.
+    - bins (int): Number of bins for the histogram.
+
+    Returns:
+    - ndarray: Normalized histogram feature vector.
+    """
+    histogram, _ = np.histogram(frame.flatten(), bins=bins, range=[0, 255])
+    return histogram.astype(np.float32) / frame.size
+
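+# Illustrative sanity check for the two feature extractors above (hypothetical
+# usage, not executed during training): for a uint8 frame of shape
+# (HEIGHT, WIDTH, 3), extract_edge_features returns a (HEIGHT, WIDTH) map with
+# values in [0, 1], and extract_histogram_features returns a (bins,) vector
+# whose entries sum to ~1.0.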
+ """ + LOGGER.trace(f"Entering: load_video_metadata({list_path})") try: with open(list_path, "r") as json_file: @@ -36,42 +84,47 @@ def load_video_metadata(list_path): except json.JSONDecodeError: LOGGER.error(f"Error decoding JSON from {list_path}.") raise + +def data_generator(videos, batch_size): + while True: + for video_details in videos: + video_path = os.path.join(os.path.dirname("test_data/validation/validation.json"), video_details["compressed_video_file"]) + cap = cv2.VideoCapture(video_path) -def load_video_samples(list_path, samples=TRAIN_SAMPLES): - LOGGER.trace(f"Entering: load_video_samples({list_path}, {samples})") - details_list = load_video_metadata(list_path) - all_samples = [] - num_videos = len(details_list) - frames_per_video = math.ceil(samples / num_videos) - LOGGER.info(f"Loading {frames_per_video} frames from {num_videos} videos") - for video_details in details_list: - compressed_video_file = video_details["compressed_video_file"] - original_video_file = video_details["original_video_file"] - crf = video_details['crf'] / 51 - preset_speed = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed']) - video_details['preset_speed'] = preset_speed + feature_batch = [] + compressed_frame_batch = [] - # Store video details without loading frames - all_samples.extend({ - "frames_per_video": frames_per_video, - "crf": crf, - "preset_speed": preset_speed, - "compressed_video_file": os.path.join(os.path.dirname(list_path), compressed_video_file), - "original_video_file": os.path.join(os.path.dirname(list_path), original_video_file) - } for _ in range(frames_per_video)) + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break - return all_samples + # Check frame dimensions and resize if necessary + if frame.shape[:2] != (HEIGHT, WIDTH): + frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_NEAREST) + # Extract features + edge_feature = extract_edge_features(frame) + histogram_feature = extract_histogram_features(frame) + histogram_feature_image = np.full((HEIGHT, WIDTH), histogram_feature.mean()) # Convert histogram feature to image-like shape + combined_feature = np.stack([edge_feature, histogram_feature_image], axis=-1) + + compressed_frame = frame / 255.0 # Assuming the frame is uint8, scale to [0, 1] + + feature_batch.append(combined_feature) + compressed_frame_batch.append(compressed_frame) + + if len(feature_batch) == batch_size: + yield (np.array(feature_batch), np.array(compressed_frame_batch)) + feature_batch = [] + compressed_frame_batch = [] + + cap.release() + + # If there are frames left that don't fill a whole batch, send them anyway + if len(feature_batch) > 0: + yield (np.array(feature_batch), np.array(compressed_frame_batch)) -def save_model(model): - try: - LOGGER.debug("Attempting to save the model.") - os.makedirs("models", exist_ok=True) - model.save(MODEL_SAVE_FILE, save_format='tf') - LOGGER.info("Model saved successfully!") - except Exception as e: - LOGGER.error(f"Error saving the model: {e}") - raise def main(): global BATCH_SIZE, EPOCHS, TRAIN_SAMPLES, LEARNING_RATE, MODEL_SAVE_FILE @@ -100,25 +153,22 @@ def main(): LOGGER.info(f"Continue training from: {MODEL_SAVE_FILE}") LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}") + LOGGER.trace("Hello, World!") - # Load training and validation samples - LOGGER.debug("Loading training and validation samples.") - training_samples = load_video_samples("test_data/training/training.json", TRAIN_SAMPLES) - validation_samples = 
load_video_samples("test_data/validation/validation.json", math.ceil(TRAIN_SAMPLES / 10)) + # Load all video metadata + all_videos = load_video_metadata("test_data/validation/validation.json") - train_generator = VideoDataGenerator(training_samples, BATCH_SIZE) - val_generator = VideoDataGenerator(validation_samples, BATCH_SIZE) + # Split into training and validation + split_index = int(0.8 * len(all_videos)) + training_videos = all_videos[:split_index] + validation_videos = all_videos[split_index:] - # Load or initialize model - if args.continue_training: - model = tf.keras.models.load_model(args.continue_training) - else: - model = VideoCompressionModel() + model = VideoCompressionModel() # Set optimizer and compile the model optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE) model.compile(loss='mean_squared_error', optimizer=optimizer) - + # Define checkpoints and early stopping checkpoint_callback = ModelCheckpoint( filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch-{epoch:02d}.tf"), @@ -129,23 +179,31 @@ def main(): ) early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True) + # Calculate steps per epoch for training and validation + average_frames_per_video = 2880 # Given 2 minutes @ 24 fps + total_frames_train = average_frames_per_video * len(training_videos) + total_frames_validation = average_frames_per_video * len(validation_videos) + steps_per_epoch_train = total_frames_train // BATCH_SIZE + steps_per_epoch_validation = total_frames_validation // BATCH_SIZE + # Train the model LOGGER.info("Starting model training.") model.fit( - train_generator, - steps_per_epoch=len(train_generator), - epochs=EPOCHS, - validation_data=val_generator, - validation_steps=len(val_generator), + data_generator(training_videos, BATCH_SIZE), + epochs=EPOCHS, + steps_per_epoch=steps_per_epoch_train, + validation_data=data_generator(validation_videos, BATCH_SIZE), # Add validation data here + validation_steps=steps_per_epoch_validation, # Add validation steps here callbacks=[early_stop, checkpoint_callback] ) LOGGER.info("Model training completed.") - + save_model(model) + if __name__ == "__main__": try: main() except Exception as e: LOGGER.error(f"Unexpected error during training: {e}") - raise + raise \ No newline at end of file diff --git a/video_compression_model.py b/video_compression_model.py index d54cbb9..601dee5 100644 --- a/video_compression_model.py +++ b/video_compression_model.py @@ -9,7 +9,7 @@ from global_train import LOGGER PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"] NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES) NUM_CHANNELS = 3 -WIDTH = 638 +WIDTH = 640 HEIGHT = 360 #from tensorflow.keras.mixed_precision import Policy @@ -19,164 +19,31 @@ HEIGHT = 360 -def normalize(frame): - """ - Normalize pixel values of the frame to range [0, 1]. - - Args: - - frame (ndarray): Image frame. - - Returns: - - ndarray: Normalized frame. 
- """ - LOGGER.trace(f"Normalizing frame") - return frame / 255.0 - -class VideoDataGenerator(tf.keras.utils.Sequence): - def __init__(self, video_details_list, batch_size): - LOGGER.debug("Initializing VideoDataGenerator with batch size: {}".format(batch_size)) - self.video_details_list = video_details_list - self.batch_size = batch_size - - def __len__(self): - return int(np.ceil(len(self.video_details_list) / float(self.batch_size))) - - def __getitem__(self, idx): - start_idx = idx * self.batch_size - end_idx = (idx + 1) * self.batch_size - batch_data = self.video_details_list[start_idx:end_idx] - - # Determine the number of videos and frames per video - num_videos = len(batch_data) - frames_per_video = batch_data[0]['frames_per_video'] # Assuming all videos have the same number of frames - - # Pre-allocate arrays for the batch data - x1 = np.empty((num_videos * frames_per_video, HEIGHT, WIDTH, NUM_CHANNELS)) - x2 = np.empty_like(x1) - x3 = np.empty((num_videos * frames_per_video, 1)) - x4 = np.empty_like(x3) - - # Iterate over the videos and frames, filling the pre-allocated arrays - for i, item in enumerate(batch_data): - compressed_video_file = item["compressed_video_file"] - original_video_file = item["original_video_file"] - crf = item["crf"] - preset_speed = item["preset_speed"] - - cap_compressed = cv2.VideoCapture(compressed_video_file) - cap_original = cv2.VideoCapture(original_video_file) - for j in range(frames_per_video): - compressed_ret, compressed_frame = cap_compressed.read() - original_ret, original_frame = cap_original.read() - if not compressed_ret or not original_ret: - continue - - # Check frame dimensions and resize if necessary - if original_frame.shape[:2] != (WIDTH, HEIGHT): - LOGGER.info(f"Resizing video: {original_video_file}") - original_frame = cv2.resize(original_frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA) - if compressed_frame.shape[:2] != (WIDTH, HEIGHT): - LOGGER.info(f"Resizing video: {compressed_video_file}") - compressed_frame = cv2.resize(compressed_frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA) - - original_frame = cv2.cvtColor(original_frame, cv2.COLOR_BGR2RGB) - compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_BGR2RGB) - - # Store the processed frames and metadata directly in the pre-allocated arrays - x1[i * frames_per_video + j] = normalize(original_frame) - x2[i * frames_per_video + j] = normalize(compressed_frame) - x3[i * frames_per_video + j] = crf - x4[i * frames_per_video + j] = preset_speed - - cap_original.release() - cap_compressed.release() - - y = x2 - inputs = {"uncompressed_frame": x1, "compressed_frame": x2, "crf": x3, "preset_speed": x4} - return inputs, y - - - class VideoCompressionModel(tf.keras.Model): def __init__(self): super(VideoCompressionModel, self).__init__() LOGGER.debug("Initializing VideoCompressionModel.") - # Inputs - self.crf_input = tf.keras.layers.InputLayer(name='crf', input_shape=(1,)) - self.preset_speed_input = tf.keras.layers.InputLayer(name='preset_speed', input_shape=(1,)) - self.uncompressed_frame_input = tf.keras.layers.InputLayer(name='uncompressed_frame', input_shape=(None, None, NUM_CHANNELS)) - self.compressed_frame_input = tf.keras.layers.InputLayer(name='compressed_frame', input_shape=(None, None, NUM_CHANNELS)) - - # Embedding for speed preset and FC layer for CRF and preset speed - self.embedding = tf.keras.layers.Embedding(NUM_PRESET_SPEEDS, 16) - self.fc = tf.keras.layers.Dense(32, activation='relu') + # Add an additional channel for the histogram features + 
input_shape_with_histogram = (HEIGHT, WIDTH, 2) # 1 channel for edges, 1 for histogram - # Encoder layers self.encoder = tf.keras.Sequential([ - tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(None, None, 2 * NUM_CHANNELS + 32)), - tf.keras.layers.BatchNormalization(), - tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'), - tf.keras.layers.BatchNormalization(), - tf.keras.layers.MaxPooling2D((2, 2)), - tf.keras.layers.Dropout(0.3) + tf.keras.layers.InputLayer(input_shape=input_shape_with_histogram), + tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'), + tf.keras.layers.MaxPooling2D((2, 2), padding='same'), + tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same'), + tf.keras.layers.MaxPooling2D((2, 2), padding='same') ]) - # Decoder layers self.decoder = tf.keras.Sequential([ - tf.keras.layers.Conv2DTranspose(128, (3, 3), activation='relu', padding='same'), - tf.keras.layers.BatchNormalization(), - tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'), - tf.keras.layers.BatchNormalization(), + tf.keras.layers.Conv2DTranspose(32, (3, 3), activation='relu', padding='same'), tf.keras.layers.UpSampling2D((2, 2)), - tf.keras.layers.Dropout(0.3), - tf.keras.layers.Conv2D(NUM_CHANNELS, (3, 3), activation='sigmoid', padding='same') # Output layer for video frames + tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'), + tf.keras.layers.UpSampling2D((2, 2)), + tf.keras.layers.Conv2DTranspose(1, (3, 3), activation='sigmoid', padding='same') ]) def call(self, inputs): - LOGGER.trace("Calling VideoCompressionModel.") - uncompressed_frame, compressed_frame, crf, preset_speed = inputs['uncompressed_frame'], inputs['compressed_frame'], inputs['crf'], inputs['preset_speed'] - - # Convert frames to float32 - uncompressed_frame = tf.cast(uncompressed_frame, tf.float16) - compressed_frame = tf.cast(compressed_frame, tf.float16) - - # Embedding for preset speed - preset_speed_embedded = self.embedding(preset_speed) - preset_speed_embedded = tf.keras.layers.Flatten()(preset_speed_embedded) - - # Reshaping CRF to match the shape of preset_speed_embedded - crf_expanded = tf.keras.layers.Flatten()(tf.repeat(crf, 16, axis=-1)) - - - # Concatenating the CRF and preset speed information - integrated_info = tf.keras.layers.Concatenate(axis=-1)([crf_expanded, preset_speed_embedded]) - integrated_info = self.fc(integrated_info) - - # Integrate the CRF and preset speed information into the frames as additional channels (features) - _, height, width, _ = uncompressed_frame.shape - current_shape = tf.shape(inputs["uncompressed_frame"]) - - height = current_shape[1] - width = current_shape[2] - integrated_info_repeated = tf.tile(tf.reshape(integrated_info, [-1, 1, 1, 32]), [1, height, width, 1]) - - # Merge uncompressed and compressed frames - frames_merged = tf.keras.layers.Concatenate(axis=-1)([uncompressed_frame, compressed_frame, integrated_info_repeated]) - - compressed_representation = self.encoder(frames_merged) - reconstructed_frame = self.decoder(compressed_representation) - - return reconstructed_frame - - def model_summary(self): - try: - LOGGER.info("Generating model summary.") - x1 = tf.keras.Input(shape=(None, None, NUM_CHANNELS), name='uncompressed_frame') - x2 = tf.keras.Input(shape=(None, None, NUM_CHANNELS), name='compressed_frame') - x3 = tf.keras.Input(shape=(1,), name='crf') - x4 = tf.keras.Input(shape=(1,), name='preset_speed') - return tf.keras.Model(inputs=[x1, 
x2, x3, x4], outputs=self.call({'uncompressed_frame': x1, 'compressed_frame': x2, 'crf': x3, 'preset_speed': x4})).summary() - except Exception as e: - LOGGER.error(f"Unexpected error during model summary generation: {e}") - raise + encoded = self.encoder(inputs) + decoded = self.decoder(encoded) + return decoded
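+
+# Illustrative shape walk-through (hypothetical input, assuming the 640x360
+# constants above): a (1, HEIGHT, WIDTH, 2) feature batch is pooled by the
+# encoder down to (1, 90, 160, 32), then upsampled by the decoder back to
+# (1, HEIGHT, WIDTH, NUM_CHANNELS).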