From 60c6c97071ddbd164622f2d44e45e0f40f0f2f7c Mon Sep 17 00:00:00 2001 From: Jordon Brooks Date: Sun, 30 Jul 2023 16:48:51 +0100 Subject: [PATCH] Improved model --- .gitignore | 4 + DeepEncode.py | 3 +- global_train.py | 3 + log.py | 66 +++++++++++ test_data/training.json | 74 ++++++++++++ test_data/validation.json | 8 ++ train_model.py | 226 ++++++++++++++++++++++--------------- video_compression_model.py | 55 +++++---- 8 files changed, 327 insertions(+), 112 deletions(-) create mode 100644 global_train.py create mode 100644 log.py create mode 100644 test_data/training.json create mode 100644 test_data/validation.json diff --git a/.gitignore b/.gitignore index dae2ae5..a85e841 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,7 @@ !DeepEncode.py !train_model.py !video_compression_model.py +!global_train.py +!log.py +!test_data/training.json +!test_data/validation.json \ No newline at end of file diff --git a/DeepEncode.py b/DeepEncode.py index 93fe121..e529318 100644 --- a/DeepEncode.py +++ b/DeepEncode.py @@ -41,8 +41,7 @@ def predict_frame(uncompressed_frame, model, crf_value, preset_speed_value): #cv2.waitKey(10) compressed_frame = model.predict({ - "uncompressed_frame": uncompressed_frame, - "compressed_frame": uncompressed_frame, + "compressed_frame": uncompressed_frame, "crf": crf_array, "preset_speed": preset_speed_array }) diff --git a/global_train.py b/global_train.py new file mode 100644 index 0000000..5dc64c4 --- /dev/null +++ b/global_train.py @@ -0,0 +1,3 @@ +import log + +LOGGER = log.Logger(level="INFO", logfile="training.log", reset_logfile=True) \ No newline at end of file diff --git a/log.py b/log.py new file mode 100644 index 0000000..9297f7e --- /dev/null +++ b/log.py @@ -0,0 +1,66 @@ +import datetime +import inspect + +class TerminalColors: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKCYAN = '\033[96m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + +class Logger: + LEVELS = {"TRACE": 0, "INFO": 1, "DEBUG": 2, "WARN": 3, "ERROR": 4} + COLORS = {"TRACE": TerminalColors.HEADER, "INFO": TerminalColors.OKCYAN, "DEBUG": TerminalColors.OKGREEN, + "WARN": TerminalColors.WARNING, "ERROR": TerminalColors.FAIL} + + def __init__(self, level="INFO", logfile=None, log_format="{timestamp} {level} {message}", reset_logfile=False): + self.level = level + self.logfile = logfile + self.log_format = log_format + + if reset_logfile and logfile: + with open(logfile, 'w') as file: + file.truncate(0) # This will clear the content of the file + + def _get_caller_info(self): + frame = inspect.stack()[3] + filename = frame.filename.split('/')[-1] # Extracts the last part after the final '/' + line_number = frame.lineno + return filename, line_number + + def _log_to_file(self, message): + if self.logfile: + with open(self.logfile, 'a') as file: + file.write(message + '\n') + + def _print_log(self, level_name, *args): + if Logger.LEVELS[level_name] >= Logger.LEVELS[self.level]: + timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + message = " ".join(map(str, args)) + + if level_name in ["TRACE", "DEBUG"]: + filename, line_number = self._get_caller_info() + message = f"({filename}:{line_number}) {message}" + + log_message = self.log_format.format(timestamp=timestamp, level=level_name, message=message) + print(f"{Logger.COLORS[level_name]}{log_message}{TerminalColors.ENDC}") + self._log_to_file(log_message) + + def trace(self, *args): + self._print_log("TRACE", *args) + + def info(self, *args): + self._print_log("INFO", *args) + + def warn(self, *args): + self._print_log("WARN", *args) + + def debug(self, *args): + self._print_log("DEBUG", *args) + + def error(self, *args): + self._print_log("ERROR", *args) \ No newline at end of file diff --git a/test_data/training.json b/test_data/training.json new file mode 100644 index 0000000..208bb1e --- /dev/null +++ b/test_data/training.json @@ -0,0 +1,74 @@ +[ + { + "video_file": "x264_crf-51_preset-ultrafast.mkv", + "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "crf": 51, + "preset_speed": "ultrafast" + }, + { + "video_file": "x264_crf-16_preset-veryslow.mkv", + "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "crf": 16, + "preset_speed": "veryslow" + }, + { + "video_file": "x264_crf-18_preset-ultrafast.mkv", + "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "crf": 18, + "preset_speed": "ultrafast" + }, + { + "video_file": "x264_crf-18_preset-veryslow.mkv", + "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "crf": 18, + "preset_speed": "veryslow" + }, + { + "video_file": "x264_crf-50_preset-veryslow.mkv", + "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "crf": 50, + "preset_speed": "veryslow" + }, + { + "video_file": "x264_crf-51_preset-fast.mkv", + "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "crf": 51, + "preset_speed": "fast" + }, + { + "video_file": "x264_crf-51_preset-faster.mkv", + "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "crf": 51, + "preset_speed": "faster" + }, + { + "video_file": "x264_crf-51_preset-medium.mkv", + "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "crf": 51, + "preset_speed": "medium" + }, + { + "video_file": "x264_crf-51_preset-slow.mkv", + "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "crf": 51, + "preset_speed": "slow" + }, + { + "video_file": "x264_crf-51_preset-slower.mkv", + "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "crf": 51, + "preset_speed": "slower" + }, + { + "video_file": "x264_crf-51_preset-superfast.mkv", + "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "crf": 51, + "preset_speed": "superfast" + }, + { + "video_file": "x264_crf-51_preset-veryfast.mkv", + "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "crf": 51, + "preset_speed": "veryfast" + } +] diff --git a/test_data/validation.json b/test_data/validation.json new file mode 100644 index 0000000..b8912d3 --- /dev/null +++ b/test_data/validation.json @@ -0,0 +1,8 @@ +[ + { + "video_file": "x264_crf-16_preset-veryslow.mkv", + "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "crf": 16, + "preset_speed": "veryslow" + } +] diff --git a/train_model.py b/train_model.py index d130a47..17dae8f 100644 --- a/train_model.py +++ b/train_model.py @@ -1,4 +1,9 @@ +# train_model.py + import os + +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' + import json import numpy as np import cv2 @@ -7,82 +12,122 @@ import tensorflow as tf from video_compression_model import NUM_CHANNELS, VideoCompressionModel, PRESET_SPEED_CATEGORIES, VideoDataGenerator from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint -print("GPUs Detected:", tf.config.list_physical_devices('GPU')) +from global_train import LOGGER # Constants BATCH_SIZE = 4 EPOCHS = 100 LEARNING_RATE = 0.000001 -TRAIN_SAMPLES = 500 +TRAIN_SAMPLES = 50 MODEL_SAVE_FILE = "models/model.tf" MODEL_CHECKPOINT_DIR = "checkpoints" -CONTINUE_TRAINING = None +EARLY_STOP = 10 -def load_list(list_path): - with open(list_path, "r") as json_file: - video_details_list = json.load(json_file) - return video_details_list +def load_video_metadata(list_path): + LOGGER.trace(f"Entering: load_video_metadata({list_path})") + try: + with open(list_path, "r") as json_file: + file = json.load(json_file) + LOGGER.trace(f"load_video_metadata returning: {file}") + return file + except FileNotFoundError: + LOGGER.error(f"Metadata file {list_path} not found.") + raise + except json.JSONDecodeError: + LOGGER.error(f"Error decoding JSON from {list_path}.") + raise -def load_video_from_list(list_path, samples = TRAIN_SAMPLES): - details_list = load_list(list_path) - all_details = [] +def load_video_samples(list_path, samples=TRAIN_SAMPLES): + """ + Load video samples from the metadata list. + Args: + - list_path (str): Path to the metadata JSON file. + - samples (int): Number of total samples to be extracted. + + Returns: + - list: Extracted video samples. + """ + LOGGER.trace(f"Entering: load_video_samples({list_path}, {samples})" ) + + details_list = load_video_metadata(list_path) + all_samples = [] num_videos = len(details_list) frames_per_video = int(samples / num_videos) - - print(f"Loading {frames_per_video} frames across {num_videos} videos") + + LOGGER.info(f"Loading {frames_per_video} frames from {num_videos} videos") for video_details in details_list: - VIDEO_FILE = video_details["video_file"] - UNCOMPRESSED_VIDEO_FILE = video_details["uncompressed_video_file"] - CRF = video_details['crf'] / 63.0 - PRESET_SPEED = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed']) - video_details['preset_speed'] = PRESET_SPEED + video_file = video_details["video_file"] + uncompressed_video_file = video_details["uncompressed_video_file"] + crf = video_details['crf'] / 63.0 + preset_speed = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed']) + video_details['preset_speed'] = preset_speed + + compressed_frames, uncompressed_frames = [], [] - frames = [] - frames_compressed = [] - - cap = cv2.VideoCapture(os.path.join("test_data/", VIDEO_FILE)) - cap_uncompressed = cv2.VideoCapture(os.path.join("test_data/", UNCOMPRESSED_VIDEO_FILE)) - - for _ in range(frames_per_video): - ret, frame_compressed = cap.read() - ret_uncompressed, frame = cap_uncompressed.read() - - if not ret or not ret_uncompressed: - continue - - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - frame_compressed = cv2.cvtColor(frame_compressed, cv2.COLOR_BGR2RGB) - - frames.append(preprocess(frame)) - frames_compressed.append(preprocess(frame_compressed)) - - for uncompressed_frame, compressed_frame in zip(frames, frames_compressed): - all_details.append({ - "frame": uncompressed_frame, - "compressed_frame": compressed_frame, - "crf": CRF, - "preset_speed": PRESET_SPEED, - "video_file": VIDEO_FILE - }) - - cap.release() - cap_uncompressed.release() + try: + cap = cv2.VideoCapture(os.path.join("test_data/", video_file)) + cap_uncompressed = cv2.VideoCapture(os.path.join("test_data/", uncompressed_video_file)) + + if not cap.isOpened() or not cap_uncompressed.isOpened(): + raise RuntimeError(f"Could not open video files {video_file} or {uncompressed_video_file}") - return all_details + for _ in range(frames_per_video): + ret, frame_compressed = cap.read() + ret_uncompressed, frame = cap_uncompressed.read() -def preprocess(frame): + if not ret or not ret_uncompressed: + continue + + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame_compressed = cv2.cvtColor(frame_compressed, cv2.COLOR_BGR2RGB) + + uncompressed_frames.append(normalize(frame)) + compressed_frames.append(normalize(frame_compressed)) + + all_samples.extend({ + "frame": frame, + "compressed_frame": frame_compressed, + "crf": crf, + "preset_speed": preset_speed, + "video_file": video_file + } for frame, frame_compressed in zip(uncompressed_frames, compressed_frames)) + + except Exception as e: + LOGGER.error(f"Error during video sample loading: {e}") + raise + + finally: + cap.release() + cap_uncompressed.release() + + return all_samples + +def normalize(frame): + """ + Normalize pixel values of the frame to range [0, 1]. + + Args: + - frame (ndarray): Image frame. + + Returns: + - ndarray: Normalized frame. + """ + LOGGER.trace(f"Normalizing frame") return frame / 255.0 def save_model(model): - os.makedirs("models", exist_ok=True) - model.save(MODEL_SAVE_FILE, save_format='tf') - print("Model saved successfully!") + try: + LOGGER.debug("Attempting to save the model.") + os.makedirs("models", exist_ok=True) + model.save(MODEL_SAVE_FILE, save_format='tf') + LOGGER.info("Model saved successfully!") + except Exception as e: + LOGGER.error(f"Error saving the model: {e}") + raise def main(): - global BATCH_SIZE, EPOCHS, TRAIN_SAMPLES, LEARNING_RATE, CONTINUE_TRAINING - # Argument parsing parser = argparse.ArgumentParser(description="Train the video compression model.") parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.') @@ -92,37 +137,35 @@ def main(): parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.') args = parser.parse_args() - - # Use the parsed arguments in your script - BATCH_SIZE = args.batch_size - EPOCHS = args.epochs - TRAIN_SAMPLES = args.training_samples - LEARNING_RATE = args.learning_rate - CONTINUE_TRAINING = args.continue_training - - print("Training configuration:") - print(f"Batch size: {BATCH_SIZE}") - print(f"Epochs: {EPOCHS}") - print(f"Training samples: {TRAIN_SAMPLES}") - print(f"Learning rate: {LEARNING_RATE}") - print(f"Continue training from: {CONTINUE_TRAINING}") - - all_video_details_train = load_video_from_list("test_data/training.json") - all_video_details_val = load_video_from_list("test_data/validation.json", TRAIN_SAMPLES / 2) - train_generator = VideoDataGenerator(all_video_details_train, BATCH_SIZE) - val_generator = VideoDataGenerator(all_video_details_val, BATCH_SIZE) - - if CONTINUE_TRAINING: - print("loading model:", CONTINUE_TRAINING) - model = tf.keras.models.load_model(CONTINUE_TRAINING) # Load from the specified file + # Display training configuration + LOGGER.info("Starting the training with the given configuration.") + LOGGER.info("Training configuration:") + LOGGER.info(f"Batch size: {args.batch_size}") + LOGGER.info(f"Epochs: {args.epochs}") + LOGGER.info(f"Training samples: {args.training_samples}") + LOGGER.info(f"Learning rate: {args.learning_rate}") + LOGGER.info(f"Continue training from: {args.continue_training}") + + # Load training and validation samples + LOGGER.debug("Loading training and validation samples.") + training_samples = load_video_samples("test_data/training.json") + validation_samples = load_video_samples("test_data/validation.json", args.training_samples // 2) + + train_generator = VideoDataGenerator(training_samples, args.batch_size) + val_generator = VideoDataGenerator(validation_samples, args.batch_size) + + # Load or initialize model + if args.continue_training: + model = tf.keras.models.load_model(args.continue_training) else: model = VideoCompressionModel() - - # Define the optimizer with a specific learning rate - optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE) - - os.makedirs(MODEL_CHECKPOINT_DIR, exist_ok=True) + + # Set optimizer and compile the model + optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate) + model.compile(loss='mean_squared_error', optimizer=optimizer) + + # Define checkpoints and early stopping checkpoint_callback = ModelCheckpoint( filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch-{epoch:02d}.tf"), save_weights_only=False, @@ -130,24 +173,25 @@ def main(): verbose=1, save_format="tf" ) + early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True) - #tf.config.run_functions_eagerly(True) - - model.compile(loss='mean_squared_error', optimizer=optimizer) - early_stop = EarlyStopping(monitor='val_loss', patience=5, verbose=1, restore_best_weights=True) - - print("\nTraining the model...") + # Train the model + LOGGER.info("Starting model training.") model.fit( train_generator, steps_per_epoch=len(train_generator), - epochs=EPOCHS, + epochs=args.epochs, validation_data=val_generator, validation_steps=len(val_generator), callbacks=[early_stop, checkpoint_callback] ) - print("\nTraining completed!") + LOGGER.info("Model training completed.") save_model(model) if __name__ == "__main__": - main() + try: + main() + except Exception as e: + LOGGER.error(f"Unexpected error during training: {e}") + raise diff --git a/video_compression_model.py b/video_compression_model.py index fa22a4c..bb4a0b9 100644 --- a/video_compression_model.py +++ b/video_compression_model.py @@ -3,12 +3,15 @@ import numpy as np import tensorflow as tf +from global_train import LOGGER + PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"] NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES) NUM_CHANNELS = 3 class VideoDataGenerator(tf.keras.utils.Sequence): def __init__(self, video_details_list, batch_size): + LOGGER.debug("Initializing VideoDataGenerator with batch size: {}".format(batch_size)) self.video_details_list = video_details_list self.batch_size = batch_size @@ -16,25 +19,34 @@ class VideoDataGenerator(tf.keras.utils.Sequence): return int(np.ceil(len(self.video_details_list) / float(self.batch_size))) def __getitem__(self, idx): - start_idx = idx * self.batch_size - end_idx = (idx + 1) * self.batch_size + try: + start_idx = idx * self.batch_size + end_idx = (idx + 1) * self.batch_size + + batch_data = self.video_details_list[start_idx:end_idx] + + x1 = np.array([item["frame"] for item in batch_data]) + x2 = np.array([item["compressed_frame"] for item in batch_data]) + x3 = np.array([item["crf"] for item in batch_data]) + x4 = np.array([item["preset_speed"] for item in batch_data]) + + y = x2 + + inputs = {"uncompressed_frame": x1, "compressed_frame": x2, "crf": x3, "preset_speed": x4} + return inputs, y - batch_data = self.video_details_list[start_idx:end_idx] - - x1 = np.array([item["frame"] for item in batch_data]) - x2 = np.array([item["compressed_frame"] for item in batch_data]) - x3 = np.array([item["crf"] for item in batch_data]) - x4 = np.array([item["preset_speed"] for item in batch_data]) - - y = x2 - - inputs = {"uncompressed_frame": x1, "compressed_frame": x2, "crf": x3, "preset_speed": x4} - return inputs, y + except IndexError: + LOGGER.error(f"Index {idx} out of bounds in VideoDataGenerator.") + raise + except Exception as e: + LOGGER.error(f"Unexpected error in VideoDataGenerator: {e}") + raise class VideoCompressionModel(tf.keras.Model): def __init__(self): super(VideoCompressionModel, self).__init__() + LOGGER.debug("Initializing VideoCompressionModel.") # Inputs self.crf_input = tf.keras.layers.InputLayer(name='crf', input_shape=(1,)) @@ -68,14 +80,19 @@ class VideoCompressionModel(tf.keras.Model): ]) def model_summary(self): - x1 = tf.keras.Input(shape=(None, None, NUM_CHANNELS), name='uncompressed_frame') - x2 = tf.keras.Input(shape=(None, None, NUM_CHANNELS), name='compressed_frame') - x3 = tf.keras.Input(shape=(1,), name='crf') - x4 = tf.keras.Input(shape=(1,), name='preset_speed') - return tf.keras.Model(inputs=[x1, x2, x3, x4], outputs=self.call({'uncompressed_frame': x1, 'compressed_frame': x2, 'crf': x3, 'preset_speed': x4})).summary() - + try: + LOGGER.info("Generating model summary.") + x1 = tf.keras.Input(shape=(None, None, NUM_CHANNELS), name='uncompressed_frame') + x2 = tf.keras.Input(shape=(None, None, NUM_CHANNELS), name='compressed_frame') + x3 = tf.keras.Input(shape=(1,), name='crf') + x4 = tf.keras.Input(shape=(1,), name='preset_speed') + return tf.keras.Model(inputs=[x1, x2, x3, x4], outputs=self.call({'uncompressed_frame': x1, 'compressed_frame': x2, 'crf': x3, 'preset_speed': x4})).summary() + except Exception as e: + LOGGER.error(f"Unexpected error during model summary generation: {e}") + raise def call(self, inputs): + LOGGER.trace("Calling VideoCompressionModel.") uncompressed_frame, compressed_frame, crf, preset_speed = inputs['uncompressed_frame'], inputs['compressed_frame'], inputs['crf'], inputs['preset_speed'] # Convert frames to float32