diff --git a/DeepEncode.py b/DeepEncode.py
index 9bc605a..e7f3fc6 100644
--- a/DeepEncode.py
+++ b/DeepEncode.py
@@ -7,13 +7,13 @@ import numpy as np
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
 import tensorflow as tf
 
-from featureExtraction import combined, combined_loss, psnr, scale_crf, scale_speed_preset, ssim
+from featureExtraction import combined, combined_loss, combined_loss_weighted_psnr, psnr, scale_crf, scale_speed_preset, ssim
 from globalVars import PRESET_SPEED_CATEGORIES, clear_screen
 from video_compression_model import VideoCompressionModel, combine_batch
 
 # Constants
 COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
-MAX_FRAMES = 200  # Limit the number of frames processed
+MAX_FRAMES = 0  # Limit the number of frames processed (0 = no limit)
 CRF = 10
 SPEED = "ultrafast"
 MODEL_PATH = 'models/model.tf'
@@ -33,7 +33,7 @@ def parse_arguments():
     parser.add_argument('-p', '--model_path', default=MODEL_PATH, help='Path to the trained model')
     parser.add_argument('-i', '--uncompressed_video_file', default=UNCOMPRESSED_VIDEO_FILE, help='Path to the uncompressed video file')
     parser.add_argument('-d', '--display_output', action='store_true', default=DISPLAY_OUTPUT, help='Display real-time output to screen')
-    parser.add_argument('--keep_black_bars', action='store_true', help='Keep black bars from the video', default=False)
+    parser.add_argument('--keep_black_bars', action='store_false', help='Keep black bars from the video', default=True)
 
     args = parser.parse_args()
@@ -95,18 +95,35 @@ def load_frame_from_video(video_file, frame_num):
 
 
 def predict_frame(uncompressed_frame, model, crf, speed):
+    # Scale the CRF and Speed values
     scaled_crf = scale_crf(crf)
     scaled_speed = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(speed))
-    frame = combine_batch(uncompressed_frame, scaled_crf, scaled_speed, resize=False)
-    compressed_frame = model.predict([np.expand_dims(frame, axis=0)])[0]
-    return np.clip(compressed_frame[:, :, :3] * 255.0, 0, 255).astype(np.uint8)
+
+    # Preprocess the frame
+    frame = combine_batch(uncompressed_frame, resize=False)
+
+    # Predict using the model
+    inputs = {
+        'image': np.expand_dims(frame, axis=0),
+        'CRF': np.array([scaled_crf]),
+        'Speed': np.array([scaled_speed])
+    }
+    compressed_frame = model.predict(inputs)[0]
+
+    # Post-process the output frame
+    return np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
+
 
 def main():
-    model = tf.keras.models.load_model(MODEL_PATH, custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr, 'ssim': ssim, 'combined': combined, 'combined_loss': combined_loss})
+    model = tf.keras.models.load_model(MODEL_PATH, custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr, 'ssim': ssim, 'combined': combined, 'combined_loss': combined_loss, 'combined_loss_weighted_psnr': combined_loss_weighted_psnr})
 
     cap = cv2.VideoCapture(UNCOMPRESSED_VIDEO_FILE)
-    total_frames = min(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), MAX_FRAMES)
+    if MAX_FRAMES > 0:
+        total_frames = min(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), MAX_FRAMES)
+    else:
+        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
     height, width, fps = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FPS))
     cap.release()
@@ -127,6 +144,7 @@ def main():
         compressed_frame = predict_frame(uncompressed_frame, model, CRF, SPEED)
 
         compressed_frame = cv2.resize(compressed_frame, (width, height))
+        compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_RGB2BGR)
 
         out.write(compressed_frame)
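A note on the MAX_FRAMES change above: 0 now acts as a sentinel for "process every frame", which is why main() grows an if/else around the frame count. A minimal sketch of the same logic as a single helper, assuming the usual cv2 capture API (frame_budget is a hypothetical name, not part of the patch):

    import cv2

    def frame_budget(cap: cv2.VideoCapture, max_frames: int) -> int:
        # Non-positive max_frames means "no limit": use the full frame count.
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        return min(total, max_frames) if max_frames > 0 else total

With this, `total_frames = frame_budget(cap, MAX_FRAMES)` would replace the new branch while preserving behaviour.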
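The two cv2.cvtColor(..., cv2.COLOR_RGB2BGR) additions (before out.write here, and before cv2.imwrite in train_model.py below) exist because OpenCV's writers expect BGR channel order, while the model's output is presumably RGB. A tiny self-check of that channel swap:

    import cv2
    import numpy as np

    rgb = np.zeros((2, 2, 3), np.uint8)
    rgb[..., 0] = 255                           # pure red in RGB order
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    assert bgr[0, 0, 2] == 255 and bgr[0, 0, 0] == 0  # red moves to the last (R) slot of BGR

Skipping the conversion would not crash; the written video would simply have red and blue swapped.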
diff --git a/train_model.py b/train_model.py
index d652d2b..4abb3fb 100644
--- a/train_model.py
+++ b/train_model.py
@@ -42,8 +42,8 @@ from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER, clear_screen, load_vid
 
 # Constants
 BATCH_SIZE = 25
-EPOCHS = 100
-LEARNING_RATE = 0.0001
+EPOCHS = 1000
+LEARNING_RATE = 0.005
 DECAY_STEPS = 160
 DECAY_RATE = 0.9
 MODEL_SAVE_FILE = "models/model.tf"
@@ -66,8 +66,10 @@ class ImageLoggingCallback(Callback):
         return np.stack(converted, axis=0)
 
     def on_epoch_end(self, epoch, logs=None):
+        # Skip a random number of batches so a different validation batch is logged each epoch
+        skip_batches = np.random.randint(0, 100)
         # Get the first batch from the validation dataset
-        validation_data = next(iter(self.validation_dataset.take(1)))
+        validation_data = next(iter(self.validation_dataset.skip(skip_batches).take(1)))
 
         # Extract the inputs from the batch_input_images dictionary
         actual_images = validation_data[0]['image']
@@ -82,7 +84,7 @@ class ImageLoggingCallback(Callback):
 
         # Save the reconstructed frame to the specified folder
         reconstructed_path = os.path.join(self.log_dir, f"epoch_{epoch}.png")
-        cv2.imwrite(reconstructed_path, reconstructed_frame[0])  # Saving only the first image as an example
+        cv2.imwrite(reconstructed_path, cv2.cvtColor(reconstructed_frame[0], cv2.COLOR_RGB2BGR))  # Saving only the first image as an example
 
         # Log images to TensorBoard
         with self.writer.as_default():
@@ -145,13 +147,13 @@ def main():
 
     # Load all video metadata
     all_videos = load_video_metadata("test_data/validation/validation.json")
 
-    tf.random.set_seed(RANDOM_SEED)
+    #tf.random.set_seed(RANDOM_SEED)
 
     # Shuffle the data using the specified seed
     random.shuffle(all_videos, random.seed(RANDOM_SEED))
 
     # Split into training and validation
-    split_index = int(0.6 * len(all_videos))
+    split_index = int(0.7 * len(all_videos))
     training_videos = all_videos[:split_index]
     validation_videos = all_videos[split_index:]
@@ -166,7 +168,14 @@ def main():
 
     if args.continue_training:
-        MODEL = tf.keras.models.load_model(args.continue_training)
+        MODEL = tf.keras.models.load_model(args.continue_training, custom_objects={
+            'VideoCompressionModel': VideoCompressionModel,
+            'psnr': psnr,
+            'ssim': ssim,
+            'combined': combined,
+            'combined_loss': combined_loss,
+            'combined_loss_weighted_psnr': combined_loss_weighted_psnr
+        })
     else:
         MODEL = VideoCompressionModel()
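The new skip_batches = np.random.randint(0, 100) assumes the validation dataset holds at least 100 batches; on a smaller set, skip(...).take(1) can come back empty and next(...) raises StopIteration. A hedged variant that clamps the skip to the dataset's reported size (sketch only; sample_validation_batch is a hypothetical helper, and Dataset.cardinality() returns tf.data.UNKNOWN_CARDINALITY for generator-backed pipelines like create_dataset, hence the fallback):

    import numpy as np
    import tensorflow as tf

    def sample_validation_batch(dataset: tf.data.Dataset, fallback_bound: int = 100):
        n = int(dataset.cardinality().numpy())  # -2 if unknown, -1 if infinite
        bound = n if n > 0 else fallback_bound
        skip = np.random.randint(0, bound)
        return next(iter(dataset.skip(skip).take(1)))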
diff --git a/video_compression_model.py b/video_compression_model.py
index bda8645..aa30efb 100644
--- a/video_compression_model.py
+++ b/video_compression_model.py
@@ -78,7 +78,7 @@ def create_dataset(videos, batch_size, max_frames=None):
         output_signature=output_signature
     )
 
-    dataset = dataset.batch(batch_size).shuffle(20).prefetch(1)
+    dataset = dataset.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
 
     return dataset
@@ -160,13 +160,16 @@ class VideoCompressionModel(tf.keras.Model):
         # New shape: [batch_size, 1, 1, 128]
         crf_speed_features = tf.reshape(crf_speed_features, [-1, 1, 1, 128])
 
-        # Tile the tensor to match spatial dimensions of encoded tensor
-        # Tiled shape: [batch_size, 90, 160, 128]
-        crf_speed_features = tf.tile(crf_speed_features, [1, 90, 160, 1])
-
         # Pass the image through the encoder
         encoded = self.encoder(image)
+
+        # Dynamically compute the spatial dimensions of the encoded tensor
+        encoded_shape = tf.shape(encoded)
+        height, width = encoded_shape[1], encoded_shape[2]
+
+        # Tile the crf_speed_features tensor to match the spatial dimensions of the encoded tensor
+        crf_speed_features = tf.tile(crf_speed_features, [1, height, width, 1])
 
         # Concatenate the encoded tensor with the crf_speed_features tensor
         combined_features = tf.concat([encoded, crf_speed_features], axis=-1)
@@ -176,3 +179,4 @@
 
         return decoded
 
+
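On the create_dataset change: the old pipeline batched first and then shuffled a buffer of only 20 batches, so samples inside a batch always traveled together; shuffle(1000) before batch(...) mixes individual samples, and prefetch(tf.data.AUTOTUNE) lets the runtime size the buffer. A self-contained sketch of the ordering difference (toy data, not the project's generator):

    import tensorflow as tf

    samples = tf.data.Dataset.range(10)

    # Patch's ordering: shuffle elements, then batch; batches mix samples freely.
    mixed = samples.shuffle(10, seed=0).batch(5).prefetch(tf.data.AUTOTUNE)

    # Old ordering: batch, then shuffle; only whole batches are reordered,
    # so 0..4 and 5..9 always stay together.
    grouped = samples.batch(5).shuffle(2, seed=0)

    print([b.numpy().tolist() for b in mixed])
    print([b.numpy().tolist() for b in grouped])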
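The tiling change is what makes the forward pass resolution-agnostic: instead of hard-coding a 90x160 grid, the spatial dimensions are read from tf.shape(encoded) at run time. Since tf.concat does not broadcast, the 1x1 conditioning map must be physically tiled to the encoder's grid. A standalone sketch of the pattern (attach_conditioning and the example shapes are illustrative):

    import tensorflow as tf

    def attach_conditioning(encoded: tf.Tensor, features: tf.Tensor) -> tf.Tensor:
        """Tile per-sample conditioning features [B, C] over the [B, H, W, C'] encoder grid."""
        grid = tf.reshape(features, [-1, 1, 1, features.shape[-1]])
        shape = tf.shape(encoded)                        # dynamic: works for any resolution
        grid = tf.tile(grid, [1, shape[1], shape[2], 1])
        return tf.concat([encoded, grid], axis=-1)

    # e.g. encoded [2, 90, 160, 256] + features [2, 128] -> [2, 90, 160, 384]
    out = attach_conditioning(tf.zeros([2, 90, 160, 256]), tf.zeros([2, 128]))
    print(out.shape)  # (2, 90, 160, 384)

A broadcast without tiling would only work with element-wise ops such as add or multiply; concatenation, as used here, needs the full tiled tensor.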