From 54fa90247a4996456e03abeef38b715acc186935 Mon Sep 17 00:00:00 2001 From: Jordon Brooks Date: Sun, 13 Aug 2023 20:48:00 +0100 Subject: [PATCH] semi-working --- DeepEncode.py | 10 +++++----- test_data/validation/validation.json | 24 ++++++++++++++++++++++++ train_model.py | 2 +- video_compression_model.py | 13 ++++++++----- 4 files changed, 38 insertions(+), 11 deletions(-) diff --git a/DeepEncode.py b/DeepEncode.py index 7a033e2..2f23073 100644 --- a/DeepEncode.py +++ b/DeepEncode.py @@ -16,7 +16,7 @@ from video_compression_model import VideoCompressionModel COMPRESSED_VIDEO_FILE = 'compressed_video.avi' MAX_FRAMES = 0 # Limit the number of frames processed CRF = 51 -SPEED = PRESET_SPEED_CATEGORIES.index("ultrafast") +SPEED = PRESET_SPEED_CATEGORIES.index("veryslow") # Load the trained model MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel}) @@ -30,21 +30,21 @@ def load_frame_from_video(video_file, frame_num): ret, frame = cap.read() if not ret: return None - #frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) cap.release() return frame def predict_frame(uncompressed_frame): - #display_frame = np.clip(cv2.cvtColor(uncompressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8) - #cv2.imshow("uncomp", uncompressed_frame) + display_frame = np.clip(cv2.cvtColor(uncompressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8) + cv2.imshow("uncomp", uncompressed_frame) frame = preprocess_frame(uncompressed_frame, CRF, SPEED) compressed_frame = MODEL.predict([np.expand_dims(frame, axis=0)])[0] + compressed_frame = compressed_frame[:, :, :3] # Keep only the first 3 channels (BGR) - display_frame = np.clip(cv2.cvtColor(compressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8) + display_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8) cv2.imshow("comp", display_frame) cv2.waitKey(1) diff --git a/test_data/validation/validation.json b/test_data/validation/validation.json index 0ae23ea..bb9b68f 100644 --- a/test_data/validation/validation.json +++ b/test_data/validation/validation.json @@ -22,5 +22,29 @@ "original_video_file": "Scene5.mkv", "crf": 51, "preset_speed": "veryslow" + }, + { + "compressed_video_file": "Scene6_x264_crf-25_preset-medium.mkv", + "original_video_file": "Scene6.mkv", + "crf": 25, + "preset_speed": "medium" + }, + { + "compressed_video_file": "Scene7_x264_crf-34_preset-slower.mkv", + "original_video_file": "Scene7.mkv", + "crf": 34, + "preset_speed": "slower" + }, + { + "compressed_video_file": "Scene8_x264_crf-12_preset-faster.mkv", + "original_video_file": "Scene8.mkv", + "crf": 12, + "preset_speed": "faster" + }, + { + "compressed_video_file": "Scene9_x264_crf-15_preset-slow.mkv", + "original_video_file": "Scene9.mkv", + "crf": 15, + "preset_speed": "slow" } ] diff --git a/train_model.py b/train_model.py index e08b268..b65a966 100644 --- a/train_model.py +++ b/train_model.py @@ -16,7 +16,7 @@ from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER # Constants BATCH_SIZE = 16 EPOCHS = 100 -LEARNING_RATE = 0.001 +LEARNING_RATE = 0.000001 MODEL_SAVE_FILE = "models/model.tf" MODEL_CHECKPOINT_DIR = "checkpoints" EARLY_STOP = 10 diff --git a/video_compression_model.py b/video_compression_model.py index 69f9c0a..7bfe485 100644 --- a/video_compression_model.py +++ b/video_compression_model.py @@ -19,8 +19,9 @@ def data_generator(videos, batch_size): # Iterate over each video for video_details in videos: # Get the paths for compressed and original (uncompressed) video files - video_path = os.path.join(os.path.dirname("test_data/validation/validation.json"), video_details["compressed_video_file"]) - uncompressed_video_path = os.path.join(os.path.dirname("test_data/validation/validation.json"), video_details["original_video_file"]) + base_dir = os.path.dirname("test_data/validation/validation.json") + video_path = os.path.join(base_dir, video_details["compressed_video_file"]) + uncompressed_video_path = os.path.join(base_dir, video_details["original_video_file"]) CRF = video_details["crf"] / 51 SPEED = PRESET_SPEED_CATEGORIES.index(video_details["preset_speed"]) @@ -87,12 +88,14 @@ class VideoCompressionModel(tf.keras.Model): tf.keras.layers.UpSampling2D((2, 2)), tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'), tf.keras.layers.UpSampling2D((2, 2)), - tf.keras.layers.Conv2DTranspose(1, (3, 3), activation='sigmoid', padding='same') + tf.keras.layers.Conv2DTranspose(NUM_COLOUR_CHANNELS + 2, (3, 3), activation='sigmoid', padding='same') ]) def call(self, inputs): - # Encode the input + #print("Input shape:", inputs.shape) encoded = self.encoder(inputs) - # Decode the encoded representation + #print("Encoded shape:", encoded.shape) decoded = self.decoder(encoded) + #print("Decoded shape:", decoded.shape) return decoded +