From db43239b3d25c15561fc4833a79da9c58c88e714 Mon Sep 17 00:00:00 2001
From: Jordon Brooks
Date: Wed, 23 Aug 2023 00:54:06 +0100
Subject: [PATCH] working

---
 DeepEncode.py              | 181 ++++++++++++++++++++++++-------------
 featureExtraction.py       |  75 ++++++++-------
 globalVars.py              |  36 +++++++-
 train_model.py             | 138 +++++++++++++++++++---------
 video_compression_model.py |  78 ++++++----------
 5 files changed, 311 insertions(+), 197 deletions(-)

diff --git a/DeepEncode.py b/DeepEncode.py
index bf8232b..9bc605a 100644
--- a/DeepEncode.py
+++ b/DeepEncode.py
@@ -1,90 +1,145 @@
-# DeepEncode.py
-
 import os
+import argparse
+import cv2
+import numpy as np
 
-from featureExtraction import combined, preprocess_frame, psnr, scale_crf, scale_speed_preset, ssim
-from globalVars import PRESET_SPEED_CATEGORIES
-
+# Set TensorFlow log level before any other imports
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
 
 import tensorflow as tf
-import numpy as np
-import cv2
 
+from featureExtraction import combined, combined_loss, psnr, scale_crf, scale_speed_preset, ssim
+from globalVars import PRESET_SPEED_CATEGORIES, clear_screen
 from video_compression_model import VideoCompressionModel, combine_batch
 
 # Constants
 COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
 MAX_FRAMES = 200  # Limit the number of frames processed
-CRF = 51
-SPEED = PRESET_SPEED_CATEGORIES.index("ultrafast")
-
-# Load the trained model
-MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr, 'ssim': ssim, 'combined': combined})
-
-# Load the uncompressed video
+CRF = 10
+SPEED = "ultrafast"
+MODEL_PATH = 'models/model.tf'
 UNCOMPRESSED_VIDEO_FILE = 'test_data/x264_crf-5_preset-veryslow.mkv'
+DISPLAY_OUTPUT = False
+CROP_DIMENSIONS = None
+
+
+
+def parse_arguments():
+    global COMPRESSED_VIDEO_FILE, MAX_FRAMES, CRF, SPEED, MODEL_PATH, UNCOMPRESSED_VIDEO_FILE, DISPLAY_OUTPUT, CROP_DIMENSIONS
+    parser = argparse.ArgumentParser(description='Deep Encoding of Videos')
+    parser.add_argument('-o', '--compressed_video_file', default=COMPRESSED_VIDEO_FILE, help='Path to the compressed video file')
+    parser.add_argument('-m', '--max_frames', type=int, default=MAX_FRAMES, help='Maximum number of frames to process')
+    parser.add_argument('-c', '--crf', type=int, default=CRF, help='CRF value for video compression')
+    parser.add_argument('-s', '--speed', default=SPEED, choices=PRESET_SPEED_CATEGORIES, help='Video compression speed category')
+    parser.add_argument('-p', '--model_path', default=MODEL_PATH, help='Path to the trained model')
+    parser.add_argument('-i', '--uncompressed_video_file', default=UNCOMPRESSED_VIDEO_FILE, help='Path to the uncompressed video file')
+    parser.add_argument('-d', '--display_output', action='store_true', default=DISPLAY_OUTPUT, help='Display real-time output to screen')
+    parser.add_argument('--keep_black_bars', action='store_true', help='Keep black bars from the video', default=False)
+
+    args = parser.parse_args()
+
+    COMPRESSED_VIDEO_FILE = args.compressed_video_file
+    MAX_FRAMES = args.max_frames
+    CRF = args.crf
+    SPEED = args.speed
+    MODEL_PATH = args.model_path
+    UNCOMPRESSED_VIDEO_FILE = args.uncompressed_video_file
+    DISPLAY_OUTPUT = args.display_output
+
+    if not args.keep_black_bars:
+        CROP_DIMENSIONS = find_crop_dimensions(UNCOMPRESSED_VIDEO_FILE)
+
+def crop_black_bars(frame):
+    # Convert to grayscale for easier processing
+    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+
+    # Threshold the image to make everything below a certain gray value black, and everything else white
+    _, thresh = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
+
+    # Find the contours of the white regions
+    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+    # Find the bounding box that contains all the contours
+    x_min = y_min = float('inf')
+    x_max = y_max = 0
+    for contour in contours:
+        x, y, w, h = cv2.boundingRect(contour)
+        x_min = min(x_min, x)
+        y_min = min(y_min, y)
+        x_max = max(x_max, x + w)
+        y_max = max(y_max, y + h)
+
+    return x_min, y_min, x_max, y_max
+
+def find_crop_dimensions(video_file):
+    cap = cv2.VideoCapture(video_file)
+    while True:
+        ret, frame = cap.read()
+        if not ret:
+            print("Error: Unable to find a non-black frame.")
+            cap.release()
+            exit()
+
+        # Check if the frame is entirely black
+        if np.any(frame > 0):
+            x_min, y_min, x_max, y_max = crop_black_bars(frame)
+            cap.release()
+            return x_min, y_min, x_max, y_max
+
 def load_frame_from_video(video_file, frame_num):
     cap = cv2.VideoCapture(video_file)
     cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
     ret, frame = cap.read()
-    if not ret:
-        return None
     cap.release()
-
-    return frame
+    return frame if ret else None
 
-def predict_frame(uncompressed_frame):
-    #display_frame = np.clip(cv2.cvtColor(uncompressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
-    #cv2.imshow("uncomp", uncompressed_frame)
-    scaled_crf = scale_crf(CRF)
-    scaled_speed = scale_speed_preset(SPEED)
-
+def predict_frame(uncompressed_frame, model, crf, speed):
+    scaled_crf = scale_crf(crf)
+    scaled_speed = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(speed))
     frame = combine_batch(uncompressed_frame, scaled_crf, scaled_speed, resize=False)
-
-    compressed_frame = MODEL.predict([np.expand_dims(frame, axis=0)])[0]
-
-    compressed_frame = compressed_frame[:, :, :3]  # Keep only the first 3 channels (BGR)
-
-    compressed_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
-
-    cv2.imshow("comp", compressed_frame)
-    cv2.waitKey(1)
-
-    return compressed_frame
+    compressed_frame = model.predict([np.expand_dims(frame, axis=0)])[0]
+    return np.clip(compressed_frame[:, :, :3] * 255.0, 0, 255).astype(np.uint8)
 
-cap = cv2.VideoCapture(UNCOMPRESSED_VIDEO_FILE)
-total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-height, width = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-fps = int(cap.get(cv2.CAP_PROP_FPS))
-cap.release()
-fourcc = cv2.VideoWriter_fourcc(*'XVID')
-out = cv2.VideoWriter(COMPRESSED_VIDEO_FILE, fourcc, fps, (width, height), True)
-
-if not out.isOpened():
-    print("Error: VideoWriter could not be opened.")
-    exit()
-
-if MAX_FRAMES != 0 and total_frames > MAX_FRAMES:
-    total_frames = MAX_FRAMES
-
-for i in range(total_frames):
-    uncompressed_frame = load_frame_from_video(UNCOMPRESSED_VIDEO_FILE, frame_num=i)
-    compressed_frame = predict_frame(uncompressed_frame)
+def main():
+    model = tf.keras.models.load_model(MODEL_PATH, custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr, 'ssim': ssim, 'combined': combined, 'combined_loss': combined_loss})
+    cap = cv2.VideoCapture(UNCOMPRESSED_VIDEO_FILE)
 
-    compressed_frame = cv2.resize(compressed_frame, (width, height))
+    total_frames = min(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), MAX_FRAMES)
+    height, width, fps = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FPS))
 
-    #compressed_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
-
-    #compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_RGB2BGR)
-
-    out.write(compressed_frame)
+    cap.release()
 
-    #if i % 10 == 0:  # Print progress every 10 frames
-    #    print(f"Processed {i} / {total_frames} frames")
+    fourcc = cv2.VideoWriter_fourcc(*'XVID')
+    out = cv2.VideoWriter(COMPRESSED_VIDEO_FILE, fourcc, fps, (width, height), True)
 
-out.release()
-print("Compression completed.")
+    if not out.isOpened():
+        print("Error: VideoWriter could not be opened.")
+        exit()
 
+    for i in range(total_frames):
+        uncompressed_frame = load_frame_from_video(UNCOMPRESSED_VIDEO_FILE, frame_num=i)
+
+        if CROP_DIMENSIONS:
+            x_min, y_min, x_max, y_max = CROP_DIMENSIONS
+            uncompressed_frame = uncompressed_frame[y_min:y_max, x_min:x_max]
+
+        compressed_frame = predict_frame(uncompressed_frame, model, CRF, SPEED)
+        compressed_frame = cv2.resize(compressed_frame, (width, height))
+
+        out.write(compressed_frame)
+
+        if DISPLAY_OUTPUT:
+            cv2.imshow('Compressed Video', compressed_frame)
+            if cv2.waitKey(1) & 0xFF == ord('q'):
+                break
+
+    out.release()
+    print("Compression completed.")
+
+
+if __name__ == '__main__':
+    clear_screen()
+    parse_arguments()
+    main()
diff --git a/featureExtraction.py b/featureExtraction.py
index 8b5311e..e289e66 100644
--- a/featureExtraction.py
+++ b/featureExtraction.py
@@ -9,51 +9,21 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
 import tensorflow as tf
 from tensorflow.keras import backend as K
 
-from globalVars import HEIGHT, NUM_PRESET_SPEEDS, WIDTH
+from globalVars import HEIGHT, LOGGER, NUM_PRESET_SPEEDS, WIDTH
 
 
 def scale_crf(crf):
     return crf / 51
 
+
 def scale_speed_preset(speed):
     return speed / NUM_PRESET_SPEEDS
 
-def extract_edge_features(frame):
-    """
-    Extract edge features using Canny edge detection.
-
-    Args:
-    - frame (ndarray): Image frame.
-
-    Returns:
-    - ndarray: Edge feature map.
-    """
-    edges = cv2.Canny(frame, threshold1=100, threshold2=200)
-    return edges.astype(np.float32) / 255.0
-
-def extract_histogram_features(frame, bins=64):
-    """
-    Extract histogram features from a frame with 3 channels.
-
-    Args:
-    - frame (ndarray): Image frame with shape (height, width, 3).
-    - bins (int): Number of bins for the histogram.
-
-    Returns:
-    - ndarray: Normalized histogram feature vector.
- """ - feature_vector = [] - for channel in range(3): - histogram, _ = np.histogram(frame[:,:,channel].flatten(), bins=bins, range=[0, 255]) - normalized_histogram = histogram.astype(np.float32) / frame[:,:,channel].size - feature_vector.extend(normalized_histogram) - - return np.array(feature_vector) - - def psnr(y_true, y_pred): + #LOGGER.info(f"[psnr function] y_true: {y_true.shape}, y_pred: {y_pred.shape}") max_pixel = 1.0 - return 10.0 * K.log((max_pixel ** 2) / (K.mean(K.square(y_pred - y_true)))) / K.log(10.0) + mse = K.mean(K.square(y_pred - y_true)) + return 20.0 * K.log(max_pixel / K.sqrt(mse)) / K.log(10.0) def ssim(y_true, y_pred): @@ -64,14 +34,41 @@ def combined(y_true, y_pred): return (psnr(y_true, y_pred) + ssim(y_true, y_pred)) / 2 -def preprocess_frame(frame, resize=True): - #Preprocesses a single frame, cropping it if needed +def combined_loss(y_true, y_pred): + return -combined(y_true, y_pred) # The goal is to maximize the combined value + + +def detect_noise(image, threshold=15): + # Convert to grayscale if it's a color image + if len(image.shape) == 3: + image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + + # Compute the standard deviation + std_dev = np.std(image) + + # If the standard deviation is higher than a threshold, it might be considered noisy + return std_dev > threshold + + +def frame_difference(frame1, frame2): + # Ensure both frames are of the same size and type + if frame1.shape != frame2.shape: + raise ValueError("Frames must have the same dimensions and number of channels") + + # Calculate the absolute difference between the frames + difference = cv2.absdiff(frame1, frame2) + + return difference + + +def preprocess_frame(frame, resize=True, scale=True): # Check frame dimensions and resize if necessary if resize and frame.shape[:2] != (HEIGHT, WIDTH): frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_LINEAR) + if scale: # Scale frame to [0, 1] - compressed_frame = frame / 255.0 + frame = frame / 255.0 - return compressed_frame + return frame diff --git a/globalVars.py b/globalVars.py index 626cb38..b57cc47 100644 --- a/globalVars.py +++ b/globalVars.py @@ -1,6 +1,9 @@ # gobalVars.py +import json import log +import platform +import os LOGGER = log.Logger(level="TRACE", logfile="training.log", reset_logfile=True) @@ -9,4 +12,35 @@ NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES) NUM_COLOUR_CHANNELS = 3 WIDTH = 640 HEIGHT = 360 -MAX_FRAMES = 0 \ No newline at end of file +MAX_FRAMES = 0 + +def clear_screen(): + system_name = platform.system() + if system_name == "Windows": + os.system('cls') + else: + os.system('clear') + +def load_video_metadata(list_path): + """ + Load video metadata from a JSON file. + + Args: + - json_path (str): Path to the JSON file containing video metadata. + + Returns: + - list: List of dictionaries, each containing video details. 
+ """ + + LOGGER.trace(f"Entering: load_video_metadata({list_path})") + try: + with open(list_path, "r") as json_file: + file = json.load(json_file) + LOGGER.trace(f"load_video_metadata returning: {file}") + return file + except FileNotFoundError: + LOGGER.error(f"Metadata file {list_path} not found.") + raise + except json.JSONDecodeError: + LOGGER.error(f"Error decoding JSON from {list_path}.") + raise \ No newline at end of file diff --git a/train_model.py b/train_model.py index e9e4f8c..e465752 100644 --- a/train_model.py +++ b/train_model.py @@ -7,17 +7,25 @@ TODO: """ import argparse -import json import os import random +import shutil +import cv2 +import subprocess +import signal -from featureExtraction import combined, psnr, ssim +import numpy as np + +from featureExtraction import combined, combined_loss, psnr, ssim os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' import gc import tensorflow as tf -from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, Callback +from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, Callback, TensorBoard +from tensorflow.keras import backend as K +from tensorflow.summary import image as tf_image_summary + gpus = tf.config.experimental.list_physical_devices('GPU') if gpus: @@ -30,7 +38,7 @@ if gpus: from video_compression_model import VideoCompressionModel, create_dataset -from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER +from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER, clear_screen, load_video_metadata # Constants BATCH_SIZE = 25 @@ -41,50 +49,71 @@ DECAY_RATE = 0.9 MODEL_SAVE_FILE = "models/model.tf" MODEL_CHECKPOINT_DIR = "checkpoints" EARLY_STOP = 10 +RANDOM_SEED = 4576 +MODEL = None +LOG_DIR = './logs' + + +class ImageLoggingCallback(Callback): + def __init__(self, validation_dataset, log_dir): + super().__init__() + self.validation_dataset = validation_dataset + self.log_dir = log_dir + self.writer = tf.summary.create_file_writer(self.log_dir) + + def convert_images(self, images): + converted = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images] + return np.stack(converted, axis=0) + + def on_epoch_end(self, epoch, logs=None): + itter = iter(self.validation_dataset) + random_idx = np.random.randint(0, BATCH_SIZE) + + # Loop through the dataset until the chosen index + for i, data in enumerate(self.validation_dataset): + if i == random_idx: + validation_data = data + break + + batch_input_images, batch_gt_labels = validation_data + + batch_input_images = np.clip(batch_input_images[:, :, :, :3] * 255.0, 0, 255).astype(np.uint8) + batch_gt_labels = np.clip(batch_gt_labels * 255.0, 0, 255).astype(np.uint8) + + reconstructed_frame = MODEL.predict(validation_data[0]) + reconstructed_frame = np.clip(reconstructed_frame * 255.0, 0, 255).astype(np.uint8) + + batch_input_images = self.convert_images(batch_input_images) + batch_gt_labels = self.convert_images(batch_gt_labels) + reconstructed_frame = self.convert_images(reconstructed_frame) + + # Log images to TensorBoard + with self.writer.as_default(): + tf.summary.image("Input Images", batch_input_images, step=epoch, max_outputs=1) + tf.summary.image("Ground Truth Labels", batch_gt_labels, step=epoch, max_outputs=1) + tf.summary.image("Reconstructed Frame", reconstructed_frame, step=epoch, max_outputs=3) + self.writer.flush() + + class GarbageCollectorCallback(Callback): def on_epoch_end(self, epoch, logs=None): LOGGER.debug(f"GC") gc.collect() -def save_model(model): +def save_model(): try: LOGGER.debug("Attempting to save the model.") 
         os.makedirs("models", exist_ok=True)
-        model.save(MODEL_SAVE_FILE, save_format='tf')
+        MODEL.save(MODEL_SAVE_FILE, save_format='tf')
         LOGGER.info("Model saved successfully!")
     except Exception as e:
         LOGGER.error(f"Error saving the model: {e}")
         raise
 
-
-def load_video_metadata(list_path):
-    """
-    Load video metadata from a JSON file.
-
-    Args:
-    - json_path (str): Path to the JSON file containing video metadata.
-
-    Returns:
-    - list: List of dictionaries, each containing video details.
-    """
-
-    LOGGER.trace(f"Entering: load_video_metadata({list_path})")
-    try:
-        with open(list_path, "r") as json_file:
-            file = json.load(json_file)
-        LOGGER.trace(f"load_video_metadata returning: {file}")
-        return file
-    except FileNotFoundError:
-        LOGGER.error(f"Metadata file {list_path} not found.")
-        raise
-    except json.JSONDecodeError:
-        LOGGER.error(f"Error decoding JSON from {list_path}.")
-        raise
-
 def main():
-    global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES, DECAY_STEPS, DECAY_RATE
+    global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES, DECAY_STEPS, DECAY_RATE, MODEL
     # Argument parsing
     parser = argparse.ArgumentParser(description="Train the video compression model.")
     parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
@@ -119,11 +148,10 @@ def main():
     # Load all video metadata
     all_videos = load_video_metadata("test_data/validation/validation.json")
 
-    # Specify the random seed
-    random_seed = 4576  # You can change this to any desired value
+    tf.random.set_seed(RANDOM_SEED)
 
     # Shuffle the data using the specified seed
-    random.shuffle(all_videos, random.seed(random_seed))
+    random.shuffle(all_videos, random.seed(RANDOM_SEED))
 
     # Split into training and validation
     split_index = int(0.6 * len(all_videos))
@@ -136,12 +164,14 @@ def main():
     training_dataset = create_dataset(training_videos, BATCH_SIZE, MAX_FRAMES)
     validation_dataset = create_dataset(validation_videos, BATCH_SIZE, MAX_FRAMES)
 
+
+    tensorboard_callback = TensorBoard(log_dir=LOG_DIR, histogram_freq=1, profile_batch=0, write_graph=True, update_freq='epoch')
 
     if args.continue_training:
-        model = tf.keras.models.load_model(args.continue_training)
+        MODEL = tf.keras.models.load_model(args.continue_training)
     else:
-        model = VideoCompressionModel()
+        MODEL = VideoCompressionModel()
 
 
     # Define exponential decay schedule
@@ -149,13 +179,13 @@ def main():
         initial_learning_rate=LEARNING_RATE,
         decay_steps=DECAY_STEPS,
         decay_rate=DECAY_RATE,
-        staircase=False
+        staircase=True
     )
 
     # Set optimizer and compile the model
     optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
-    model.compile(loss='mse', optimizer=optimizer, metrics=[psnr, ssim, combined])
+    MODEL.compile(loss=combined_loss, optimizer=optimizer, metrics=[psnr, ssim, combined])
 
     # Define checkpoints and early stopping
     checkpoint_callback = ModelCheckpoint(
@@ -167,6 +197,9 @@ def main():
     )
 
     early_stop = EarlyStopping(monitor='val_combined', mode='max', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
+    ImageSnapshots = ImageLoggingCallback(validation_dataset, LOG_DIR)
+
+
     # Custom garbage collection callback
     gc_callback = GarbageCollectorCallback()
 
@@ -174,20 +207,37 @@ def main():
 
     # Train the model
     LOGGER.info("Starting model training.")
-    model.fit(
+    MODEL.fit(
         training_dataset,
         epochs=EPOCHS,
         validation_data=validation_dataset,
-        callbacks=[early_stop, checkpoint_callback, gc_callback]
+        callbacks=[early_stop, checkpoint_callback, gc_callback, tensorboard_callback, ImageSnapshots]
     )
     LOGGER.info("Model training completed.")
 
-    save_model(model)
-
+    save_model()
+
+def preMain():
+    # Delete the existing logs directory and create a new one
+    if os.path.exists(LOG_DIR):
+        shutil.rmtree(LOG_DIR)
+    os.makedirs(LOG_DIR, exist_ok=True)
+
+    # Start TensorBoard as a subprocess
+    LOGGER.info("Running tensorboard at: http://localhost:6006/")
+    tensorboard_process = subprocess.Popen(['tensorboard', '--logdir', './logs'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, preexec_fn=os.setsid)
+    return tensorboard_process
 
 if __name__ == "__main__":
+    clear_screen()
+
+    tensorboard_process = preMain()
+
     try:
         main()
     except Exception as e:
         LOGGER.error(f"Unexpected error during training: {e}")
-        raise
\ No newline at end of file
+        raise
+    finally:
+        # Ensure TensorBoard process is terminated when main script ends
+        os.killpg(os.getpgid(tensorboard_process.pid), signal.SIGTERM)
diff --git a/video_compression_model.py b/video_compression_model.py
index a706dd2..f3fb201 100644
--- a/video_compression_model.py
+++ b/video_compression_model.py
@@ -5,6 +5,7 @@ import os
 import cv2
 import numpy as np
 import tensorflow as tf
+from tensorflow.keras import layers
 
 from featureExtraction import preprocess_frame, scale_crf, scale_speed_preset
 from globalVars import HEIGHT, LOGGER, NUM_COLOUR_CHANNELS, NUM_PRESET_SPEEDS, PRESET_SPEED_CATEGORIES, WIDTH
@@ -28,36 +29,6 @@ def combine_batch(frame, crf, speed, include_controls=True, resize=True):
 
     return np.concatenate(combined, axis=-1)
 
-def process_video(video):
-    base_dir = os.path.dirname("test_data/validation/validation.json")
-
-    cap_compressed = cv2.VideoCapture(os.path.join(base_dir, video["compressed_video_file"]))
-    cap_uncompressed = cv2.VideoCapture(os.path.join(base_dir, video["original_video_file"]))
-
-    compressed_frames = []
-    uncompressed_frames = []
-
-    while True:
-        ret_compressed, compressed_frame = cap_compressed.read()
-        ret_uncompressed, uncompressed_frame = cap_uncompressed.read()
-
-        if not ret_compressed or not ret_uncompressed:
-            break
-
-        CRF = scale_crf(video["crf"])
-        SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(video["preset_speed"]))
-
-        compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
-        uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
-
-        compressed_frames.append(compressed_combined)
-        uncompressed_frames.append(uncompressed_combined)
-
-    cap_compressed.release()
-    cap_uncompressed.release()
-
-    return uncompressed_frames, compressed_frames
-
 def frame_generator(videos, max_frames=None):
     base_dir = "test_data/validation/"
@@ -76,10 +47,10 @@ def frame_generator(videos, max_frames=None):
             CRF = scale_crf(video["crf"])
             SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(video["preset_speed"]))
 
-            compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
-            uncompressed_combined = combine_batch(uncompressed_frame, 10, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
+            validation = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
+            training = combine_batch(uncompressed_frame, 10, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
 
-            yield uncompressed_combined, compressed_combined
+            yield training, validation
 
             frame_count += 1
             if max_frames is not None and frame_count >= max_frames:
@@ -104,7 +75,7 @@ def create_dataset(videos, batch_size, max_frames=None):
         output_signature=output_signature
     )
 
-    dataset = dataset.shuffle(100).batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
+    dataset = dataset.batch(batch_size).shuffle(20).prefetch(1)  #.prefetch(tf.data.experimental.AUTOTUNE)
 
     return dataset
 
@@ -113,29 +84,36 @@ def create_dataset(videos, batch_size, max_frames=None):
 class VideoCompressionModel(tf.keras.Model):
     def __init__(self):
         super(VideoCompressionModel, self).__init__()
-        LOGGER.debug("Initializing VideoCompressionModel.")
-
-        # Input shape (includes channels for CRF and SPEED_PRESET)
-        input_shape_with_histogram = (None, None, NUM_COLOUR_CHANNELS + 2)
-
+        input_shape = (None, None, NUM_COLOUR_CHANNELS + 2)
+
         # Encoder part of the model
         self.encoder = tf.keras.Sequential([
-            tf.keras.layers.InputLayer(input_shape=input_shape_with_histogram),
-            tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
-            tf.keras.layers.MaxPooling2D((2, 2), padding='same'),
-            tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
-            tf.keras.layers.MaxPooling2D((2, 2), padding='same')
+            layers.InputLayer(input_shape=input_shape),
+            layers.Conv2D(64, (3, 3), padding='same'),
+            #layers.BatchNormalization(),
+            layers.LeakyReLU(),
+            layers.MaxPooling2D((2, 2), padding='same'),
+            layers.SeparableConv2D(32, (3, 3), padding='same'),  # Using Separable Convolution
+            #layers.BatchNormalization(),
+            layers.LeakyReLU(),
+            layers.MaxPooling2D((2, 2), padding='same')
         ])
 
         # Decoder part of the model
         self.decoder = tf.keras.Sequential([
-            tf.keras.layers.Conv2DTranspose(32, (3, 3), activation='relu', padding='same'),
-            tf.keras.layers.UpSampling2D((2, 2)),
-            tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'),
-            tf.keras.layers.UpSampling2D((2, 2)),
-            tf.keras.layers.Conv2DTranspose(NUM_COLOUR_CHANNELS, (3, 3), activation='sigmoid', padding='same')
+            layers.Conv2DTranspose(32, (3, 3), padding='same'),
+            #layers.BatchNormalization(),
+            layers.LeakyReLU(),
+            layers.Conv2DTranspose(64, (3, 3), dilation_rate=2, padding='same'),  # Using Dilated Convolution
+            #layers.BatchNormalization(),
+            layers.LeakyReLU(),
+            # Use Sub-Pixel Convolutional Layer
+            layers.Conv2DTranspose(NUM_COLOUR_CHANNELS * 16, (3, 3), padding='same'),  # 16 times the number of color channels
+            layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=4))  # Sub-Pixel Convolutional Layer with block_size=4
         ])
 
     def call(self, inputs):
-        return self.decoder(self.encoder(inputs))
+        encoded = self.encoder(inputs)
+        return self.decoder(encoded)
+