diff --git a/featureExtraction.py b/featureExtraction.py
index e289e66..b2fd788 100644
--- a/featureExtraction.py
+++ b/featureExtraction.py
@@ -23,6 +23,7 @@ def psnr(y_true, y_pred):
     #LOGGER.info(f"[psnr function] y_true: {y_true.shape}, y_pred: {y_pred.shape}")
     max_pixel = 1.0
     mse = K.mean(K.square(y_pred - y_true))
+    mse = tf.cast(mse, tf.float32) # Cast mse to tf.float32
     return 20.0 * K.log(max_pixel / K.sqrt(mse)) / K.log(10.0)
 
 
@@ -37,6 +38,14 @@ def combined(y_true, y_pred):
 
 def combined_loss(y_true, y_pred):
     return -combined(y_true, y_pred) # The goal is to maximize the combined value
 
+# Option 1: Weight more towards PSNR
+def combined_loss_weighted_psnr(y_true, y_pred):
+    return -0.7 * psnr(y_true, y_pred) - 0.3 * ssim(y_true, y_pred)
+
+# Option 2: Weight more towards SSIM
+def combined_loss_weighted_ssim(y_true, y_pred):
+    return -0.3 * psnr(y_true, y_pred) - 0.7 * ssim(y_true, y_pred)
+
 def detect_noise(image, threshold=15):
     # Convert to grayscale if it's a color image
@@ -66,6 +75,8 @@ def preprocess_frame(frame, resize=True, scale=True):
     # Check frame dimensions and resize if necessary
     if resize and frame.shape[:2] != (HEIGHT, WIDTH):
         frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_LINEAR)
+
+    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
 
     if scale:
         # Scale frame to [0, 1]
diff --git a/globalVars.py b/globalVars.py
index b78664a..9077fb0 100644
--- a/globalVars.py
+++ b/globalVars.py
@@ -1,9 +1,13 @@
 # gobalVars.py
 
 import json
+import os
+
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
+
+import tensorflow as tf
 import log
 import platform
-import os
 
 LOGGER = log.Logger(level="TRACE", logfile="training.log", reset_logfile=True)
 
@@ -13,6 +17,7 @@ NUM_COLOUR_CHANNELS = 3
 WIDTH = 640
 HEIGHT = 360
 MAX_FRAMES = 0
+DATATYPE = tf.float16
 
 def clear_screen():
     system_name = platform.system()
diff --git a/train_model.py b/train_model.py
index c58758e..d652d2b 100644
--- a/train_model.py
+++ b/train_model.py
@@ -16,7 +16,7 @@ import signal
 
 import numpy as np
 
-from featureExtraction import combined, combined_loss, psnr, ssim
+from featureExtraction import combined, combined_loss, combined_loss_weighted_psnr, psnr, ssim
 
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
 
@@ -49,7 +49,7 @@ DECAY_RATE = 0.9
 MODEL_SAVE_FILE = "models/model.tf"
 MODEL_CHECKPOINT_DIR = "checkpoints"
 EARLY_STOP = 10
-RANDOM_SEED = 4576
+RANDOM_SEED = 3545
 MODEL = None
 LOG_DIR = './logs'
 
@@ -66,41 +66,27 @@ class ImageLoggingCallback(Callback):
         return np.stack(converted, axis=0)
 
     def on_epoch_end(self, epoch, logs=None):
-        random_idx = np.random.randint(0, MAX_FRAMES - 1)
+        # Get the first batch from the validation dataset
+        validation_data = next(iter(self.validation_dataset.take(1)))
 
-        validation_data = None
-        dataset_size = 0 # to keep track of the dataset size
-
-        # Loop through the dataset until the chosen index
-        for i, data in enumerate(self.validation_dataset):
-            if i == random_idx:
-                validation_data = data
-                break
-            dataset_size += 1
-
-        if validation_data is None:
-            print(f"Random index exceeds validation dataset size: {dataset_size}. Using last available data.")
-            validation_data = data # assigning the last data seen in the loop to validation_data
-
-        batch_input_images, batch_gt_labels = validation_data
-
-        batch_input_images = np.clip(batch_input_images[:, :, :, :3] * 255.0, 0, 255).astype(np.uint8)
+        # Extract the inputs from the batch_input_images dictionary
+        actual_images = validation_data[0]['image']
+        batch_gt_labels = validation_data[1]
+
+        actual_images = np.clip(actual_images[:, :, :, :3] * 255.0, 0, 255).astype(np.uint8)
         batch_gt_labels = np.clip(batch_gt_labels * 255.0, 0, 255).astype(np.uint8)
-
+
+        # Providing all three inputs to the model for prediction
         reconstructed_frame = MODEL.predict(validation_data[0])
         reconstructed_frame = np.clip(reconstructed_frame * 255.0, 0, 255).astype(np.uint8)
 
         # Save the reconstructed frame to the specified folder
         reconstructed_path = os.path.join(self.log_dir, f"epoch_{epoch}.png")
         cv2.imwrite(reconstructed_path, reconstructed_frame[0]) # Saving only the first image as an example
-
-        batch_input_images = self.convert_images(batch_input_images)
-        batch_gt_labels = self.convert_images(batch_gt_labels)
-        reconstructed_frame = self.convert_images(reconstructed_frame)
 
         # Log images to TensorBoard
         with self.writer.as_default():
-            tf.summary.image("Input Images", batch_input_images, step=epoch, max_outputs=1)
+            tf.summary.image("Input Images", actual_images, step=epoch, max_outputs=1)
             tf.summary.image("Ground Truth Labels", batch_gt_labels, step=epoch, max_outputs=1)
             tf.summary.image("Reconstructed Frame", reconstructed_frame, step=epoch, max_outputs=3)
             self.writer.flush()
@@ -196,7 +182,7 @@ def main():
 
     # Set optimizer and compile the model
     optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
-    MODEL.compile(loss=combined_loss, optimizer=optimizer, metrics=[psnr, ssim, combined])
+    MODEL.compile(loss=combined_loss_weighted_psnr, optimizer=optimizer, metrics=[psnr, ssim, combined])
 
     # Define checkpoints and early stopping
     checkpoint_callback = ModelCheckpoint(
diff --git a/video_compression_model.py b/video_compression_model.py
index 2247aa9..bda8645 100644
--- a/video_compression_model.py
+++ b/video_compression_model.py
@@ -7,28 +7,23 @@ import numpy as np
 import tensorflow as tf
 from tensorflow.keras import layers
 from featureExtraction import preprocess_frame, scale_crf, scale_speed_preset
-from globalVars import LOGGER, NUM_COLOUR_CHANNELS, PRESET_SPEED_CATEGORIES
+from globalVars import DATATYPE, LOGGER, NUM_COLOUR_CHANNELS, PRESET_SPEED_CATEGORIES
 
+if DATATYPE == tf.float16:
+    from tensorflow.keras.mixed_precision import Policy
 
-#from tensorflow.keras.mixed_precision import Policy
-
-#policy = Policy('mixed_float16')
-#tf.keras.mixed_precision.set_global_policy(policy)
-
-def combine_batch(frame, crf, speed, include_controls=True, resize=True):
-    processed_frame = preprocess_frame(frame, resize)
-    height, width, _ = processed_frame.shape
+    policy = Policy('mixed_float16')
+    tf.keras.mixed_precision.set_global_policy(policy)
 
-    combined = [processed_frame]
-    if include_controls:
-        crf_array = np.full((height, width, 1), crf)
-        speed_array = np.full((height, width, 1), speed)
-        combined.extend([crf_array, speed_array])
-
-    return np.concatenate(combined, axis=-1)
+def is_black(frame, threshold=10):
+    """Check if a frame is mostly black."""
+    return np.mean(frame) < threshold
 
+def combine_batch(frame, resize=True):
+    return preprocess_frame(frame, resize)
+
 def frame_generator(videos, max_frames=None):
     base_dir = "test_data/validation/"
 
@@ -44,13 +39,17 @@ def frame_generator(videos, max_frames=None):
             if not ret_compressed or not ret_uncompressed:
                 break
 
+            # Skip black frames
+            if is_black(compressed_frame) or is_black(uncompressed_frame):
+                continue
+
             CRF = scale_crf(video["crf"])
             SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(video["preset_speed"]))
 
-            validation = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
-            training = combine_batch(uncompressed_frame, 10, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
+            validation_image = combine_batch(compressed_frame)
+            training_image = combine_batch(uncompressed_frame)
 
-            yield training, validation
+            yield ({'image': training_image, 'CRF': CRF, 'Speed': SPEED}, validation_image)
 
             frame_count += 1
             if max_frames is not None and frame_count >= max_frames:
@@ -60,62 +59,120 @@ def frame_generator(videos, max_frames=None):
         cap_uncompressed.release()
 
-
 def create_dataset(videos, batch_size, max_frames=None):
     # Determine the output signature by processing a single video to obtain its shape
     video_generator_instance = frame_generator(videos, max_frames)
     sample_uncompressed, sample_compressed = next(video_generator_instance)
+
     output_signature = (
-        tf.TensorSpec(shape=tf.shape(sample_uncompressed), dtype=tf.float32),
-        tf.TensorSpec(shape=tf.shape(sample_compressed), dtype=tf.float32)
+        {
+            'image': tf.TensorSpec(shape=tf.shape(sample_uncompressed['image']), dtype=DATATYPE),
+            'CRF': tf.TensorSpec(shape=(), dtype=DATATYPE),
+            'Speed': tf.TensorSpec(shape=(), dtype=DATATYPE),
+        },
+        tf.TensorSpec(shape=tf.shape(sample_compressed), dtype=DATATYPE)
     )
 
     dataset = tf.data.Dataset.from_generator(
-        lambda: frame_generator(videos, max_frames), # Include max_frames argument through lambda
+        lambda: frame_generator(videos, max_frames),
         output_signature=output_signature
     )
 
-    dataset = dataset.batch(batch_size).shuffle(20).prefetch(1) #.prefetch(tf.data.experimental.AUTOTUNE)
+    dataset = dataset.batch(batch_size).shuffle(20).prefetch(1)
 
     return dataset
 
+class SeparableTranspose2D(layers.Layer):
+    def __init__(self, filters, kernel_size, strides=(1, 1), padding='same', **kwargs):
+        super(SeparableTranspose2D, self).__init__(**kwargs)
+        self.filters = filters
+        self.kernel_size = kernel_size
+        self.strides = strides
+        self.padding = padding
+
+        # Use UpSampling2D for resizing
+        self.upsample = layers.UpSampling2D(size=strides)
+
+        # Depthwise convolution
+        self.depthwise_conv = layers.DepthwiseConv2D(kernel_size=kernel_size, padding=padding)
+
+        # Pointwise convolution
+        self.pointwise_conv = layers.Conv2D(filters, kernel_size=(1, 1), padding=padding)
+
+    def call(self, inputs):
+        x = self.upsample(inputs)
+        x = self.depthwise_conv(x)
+        x = self.pointwise_conv(x)
+        return x
+
+
 class VideoCompressionModel(tf.keras.Model):
     def __init__(self):
         super(VideoCompressionModel, self).__init__()
 
-        input_shape = (None, None, NUM_COLOUR_CHANNELS + 2)
+        input_shape = (None, None, NUM_COLOUR_CHANNELS)
 
         # Encoder part of the model
         self.encoder = tf.keras.Sequential([
             layers.InputLayer(input_shape=input_shape),
-            layers.Conv2D(32, (3, 3), padding='same'),
+            layers.SeparableConv2D(64, (3, 3), padding='same'),
             layers.LeakyReLU(),
             layers.MaxPooling2D((2, 2), padding='same'),
-            layers.Dropout(0.4),
-            layers.SeparableConv2D(16, (3, 3), padding='same'),
+            layers.SeparableConv2D(128, (3, 3), padding='same'),
             layers.LeakyReLU(),
             layers.MaxPooling2D((2, 2), padding='same'),
-            layers.Dropout(0.4),
         ])
 
-        # Decoder part of the model using Transposed Convolutions for upsampling
+        # Fully connected layers for processing CRF and Speed
+        self.dense_crf_speed = tf.keras.Sequential([
+            layers.Dense(64, activation='relu'),
+            layers.Dense(128, activation='relu'),
+        ])
+
+        # Decoder part of the model
         self.decoder = tf.keras.Sequential([
-            layers.Conv2DTranspose(16, (3, 3), padding='same'),
+            SeparableTranspose2D(128, (3, 3), padding='same'),
             layers.LeakyReLU(),
-            layers.Dropout(0.4),
-            layers.Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same'),
+            SeparableTranspose2D(64, (3, 3), padding='same'),
             layers.LeakyReLU(),
-            layers.Dropout(0.4),
+            layers.UpSampling2D((2, 2)),
             layers.UpSampling2D((2, 2)),
             layers.Conv2D(NUM_COLOUR_CHANNELS, (3, 3), padding='same', activation='sigmoid')
         ])
 
     def call(self, inputs):
-        #print(f"Input: {inputs.shape}")
-        encoded = self.encoder(inputs)
-        #print(f"encoded: {encoded.shape}")
-        decoded = self.decoder(encoded)
-        #print(f"decoded: {decoded.shape}")
+        # Extract the image, CRF, and Speed values from the inputs dictionary
+        image = inputs['image']
+        crf = inputs['CRF']
+        speed = inputs['Speed']
+
+        # CRF and Speed are 1D tensors with shape [batch_size]
+        # Concatenate them to create a [batch_size, 2] tensor
+        crf_speed_vector = tf.concat([tf.expand_dims(crf, -1), tf.expand_dims(speed, -1)], axis=-1)
+
+        # Process the combined crf_speed_vector through your dense layers
+        # This will produce a tensor with shape [batch_size, 128]
+        crf_speed_features = self.dense_crf_speed(crf_speed_vector)
+
+        # Reshape the tensor to match spatial dimensions
+        # New shape: [batch_size, 1, 1, 128]
+        crf_speed_features = tf.reshape(crf_speed_features, [-1, 1, 1, 128])
+
+        # Tile the tensor to match spatial dimensions of encoded tensor
+        # Tiled shape: [batch_size, 90, 160, 128]
+        crf_speed_features = tf.tile(crf_speed_features, [1, 90, 160, 1])
+
+        # Pass the image through the encoder
+        encoded = self.encoder(image)
+
+        # Concatenate the encoded tensor with the crf_speed_features tensor
+        combined_features = tf.concat([encoded, crf_speed_features], axis=-1)
+
+        # Pass the combined features through the decoder
+        decoded = self.decoder(combined_features)
+
         return decoded
+
+