diff --git a/globalVars.py b/globalVars.py index b57cc47..b78664a 100644 --- a/globalVars.py +++ b/globalVars.py @@ -32,11 +32,11 @@ def load_video_metadata(list_path): - list: List of dictionaries, each containing video details. """ - LOGGER.trace(f"Entering: load_video_metadata({list_path})") + #LOGGER.trace(f"Entering: load_video_metadata({list_path})") try: with open(list_path, "r") as json_file: file = json.load(json_file) - LOGGER.trace(f"load_video_metadata returning: {file}") + #LOGGER.trace(f"load_video_metadata returning: {file}") return file except FileNotFoundError: LOGGER.error(f"Metadata file {list_path} not found.") diff --git a/train_model.py b/train_model.py index e465752..c58758e 100644 --- a/train_model.py +++ b/train_model.py @@ -66,14 +66,21 @@ class ImageLoggingCallback(Callback): return np.stack(converted, axis=0) def on_epoch_end(self, epoch, logs=None): - itter = iter(self.validation_dataset) - random_idx = np.random.randint(0, BATCH_SIZE) + random_idx = np.random.randint(0, MAX_FRAMES - 1) + + validation_data = None + dataset_size = 0 # to keep track of the dataset size # Loop through the dataset until the chosen index for i, data in enumerate(self.validation_dataset): if i == random_idx: validation_data = data break + dataset_size += 1 + + if validation_data is None: + print(f"Random index exceeds validation dataset size: {dataset_size}. Using last available data.") + validation_data = data # assigning the last data seen in the loop to validation_data batch_input_images, batch_gt_labels = validation_data @@ -83,6 +90,10 @@ class ImageLoggingCallback(Callback): reconstructed_frame = MODEL.predict(validation_data[0]) reconstructed_frame = np.clip(reconstructed_frame * 255.0, 0, 255).astype(np.uint8) + # Save the reconstructed frame to the specified folder + reconstructed_path = os.path.join(self.log_dir, f"epoch_{epoch}.png") + cv2.imwrite(reconstructed_path, reconstructed_frame[0]) # Saving only the first image as an example + batch_input_images = self.convert_images(batch_input_images) batch_gt_labels = self.convert_images(batch_gt_labels) reconstructed_frame = self.convert_images(reconstructed_frame) diff --git a/video_compression_model.py b/video_compression_model.py index f3fb201..6e7c9a2 100644 --- a/video_compression_model.py +++ b/video_compression_model.py @@ -7,7 +7,7 @@ import numpy as np import tensorflow as tf from tensorflow.keras import layers from featureExtraction import preprocess_frame, scale_crf, scale_speed_preset -from globalVars import HEIGHT, LOGGER, NUM_COLOUR_CHANNELS, NUM_PRESET_SPEEDS, PRESET_SPEED_CATEGORIES, WIDTH +from globalVars import LOGGER, NUM_COLOUR_CHANNELS, PRESET_SPEED_CATEGORIES #from tensorflow.keras.mixed_precision import Policy @@ -107,13 +107,20 @@ class VideoCompressionModel(tf.keras.Model): layers.Conv2DTranspose(64, (3, 3), dilation_rate=2, padding='same'), # Using Dilated Convolution #layers.BatchNormalization(), layers.LeakyReLU(), - # Use Sub-Pixel Convolutional Layer - layers.Conv2DTranspose(NUM_COLOUR_CHANNELS * 16, (3, 3), padding='same'), # 16 times the number of color channels - layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=4)) # Sub-Pixel Convolutional Layer with block_size=4 + # First Sub-Pixel Convolutional Layer + layers.Conv2DTranspose(NUM_COLOUR_CHANNELS * 4, (3, 3), padding='same'), # 4 times the number of color channels for first upscaling by 2 + layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=2)), # Sub-Pixel Convolutional Layer with block_size=2 + # Second Sub-Pixel Convolutional Layer + layers.Conv2DTranspose(NUM_COLOUR_CHANNELS * 4, (3, 3), padding='same'), # 4 times the number of color channels for second upscaling by 2 + layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=2)), # Sub-Pixel Convolutional Layer with block_size=2 + layers.Activation('sigmoid') ]) + def call(self, inputs): + #print(f"Input: {inputs.shape}") encoded = self.encoder(inputs) - return self.decoder(encoded) - - + #print(f"encoded: {encoded.shape}") + decoded = self.decoder(encoded) + #print(f"decoded: {decoded.shape}") + return decoded