diff --git a/.gitignore b/.gitignore
index a85e841..172e4d8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,5 +12,5 @@
 !video_compression_model.py
 !global_train.py
 !log.py
-!test_data/training.json
-!test_data/validation.json
\ No newline at end of file
+!test_data/training/training.json
+!test_data/validation/validation.json
\ No newline at end of file
diff --git a/DeepEncode.py b/DeepEncode.py
index e529318..3d77d11 100644
--- a/DeepEncode.py
+++ b/DeepEncode.py
@@ -3,14 +3,14 @@ import tensorflow as tf
 import numpy as np
 import cv2
 
-from video_compression_model import VideoCompressionModel
+from video_compression_model import PRESET_SPEED_CATEGORIES, VideoCompressionModel
 
 # Constants
-CHUNK_SIZE = 10 # Adjust based on available memory and video resolution
+CHUNK_SIZE = 24 # Adjust based on available memory and video resolution
 COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
 MAX_FRAMES = 0 # Limit the number of frames processed
-CRF = 25.0 # Example CRF value
-PRESET_SPEED = 4 # Index for "fast" in our defined list
+CRF = 24.0 # Example CRF value
+PRESET_SPEED = "veryslow" # Preset speed name; converted to an index at prediction time
 
 # Load the trained model
 model = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel})
@@ -42,6 +42,7 @@ def predict_frame(uncompressed_frame, model, crf_value, preset_speed_value):
 
     compressed_frame = model.predict({
         "compressed_frame": uncompressed_frame,
+        "uncompressed_frame": uncompressed_frame,
         "crf": crf_array,
         "preset_speed": preset_speed_array
     })
@@ -49,7 +50,7 @@
     display_frame = np.clip(cv2.cvtColor(compressed_frame[0], cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
 
     cv2.imshow("comp", display_frame)
-    cv2.waitKey(10)
+    cv2.waitKey(1)
 
     return compressed_frame[0]
 
@@ -70,7 +71,7 @@ if MAX_FRAMES != 0 and total_frames > MAX_FRAMES:
 
 for i in range(total_frames):
     uncompressed_frame = load_frame_from_video(UNCOMPRESSED_VIDEO_FILE, frame_num=i)
-    compressed_frame = predict_frame(uncompressed_frame, model, CRF, PRESET_SPEED)
+    compressed_frame = predict_frame(uncompressed_frame, model, CRF, PRESET_SPEED_CATEGORIES.index(PRESET_SPEED))
 
     compressed_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
     compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_RGB2BGR)
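DeepEncode.py now stores the preset as a name and converts it to an index at the call site via PRESET_SPEED_CATEGORIES.index(). A minimal sketch of that lookup follows; the list ordering shown here is an assumption, since PRESET_SPEED_CATEGORIES is defined in video_compression_model.py, which is outside this diff — only the x264 preset names themselves are standard.

```python
# Assumed ordering of PRESET_SPEED_CATEGORIES (defined in video_compression_model.py,
# not shown in this diff); the lookup mechanism is what matters here.
PRESET_SPEED_CATEGORIES = [
    "ultrafast", "superfast", "veryfast", "faster", "fast",
    "medium", "slow", "slower", "veryslow",
]

PRESET_SPEED = "veryslow"
preset_index = PRESET_SPEED_CATEGORIES.index(PRESET_SPEED)  # 8 under this assumed ordering

# A typo such as "very-slow" raises ValueError here, failing fast instead of
# silently feeding the model a wrong index.
```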
"video_file": "x264_crf-50_preset-veryslow.mkv", - "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv", "crf": 50, "preset_speed": "veryslow" }, { "video_file": "x264_crf-51_preset-fast.mkv", - "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv", "crf": 51, "preset_speed": "fast" }, { "video_file": "x264_crf-51_preset-faster.mkv", - "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv", "crf": 51, "preset_speed": "faster" }, { "video_file": "x264_crf-51_preset-medium.mkv", - "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv", "crf": 51, "preset_speed": "medium" }, { "video_file": "x264_crf-51_preset-slow.mkv", - "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv", "crf": 51, "preset_speed": "slow" }, { "video_file": "x264_crf-51_preset-slower.mkv", - "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv", "crf": 51, "preset_speed": "slower" }, { "video_file": "x264_crf-51_preset-superfast.mkv", - "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv", "crf": 51, "preset_speed": "superfast" }, { "video_file": "x264_crf-51_preset-veryfast.mkv", - "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", + "uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv", "crf": 51, "preset_speed": "veryfast" } diff --git a/test_data/validation.json b/test_data/validation.json deleted file mode 100644 index b8912d3..0000000 --- a/test_data/validation.json +++ /dev/null @@ -1,8 +0,0 @@ -[ - { - "video_file": "x264_crf-16_preset-veryslow.mkv", - "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv", - "crf": 16, - "preset_speed": "veryslow" - } -] diff --git a/test_data/validation/validation.json b/test_data/validation/validation.json new file mode 100644 index 0000000..7f938f2 --- /dev/null +++ b/test_data/validation/validation.json @@ -0,0 +1,9 @@ +[ + + { + "video_file": "Scene2_x264_crf-51_preset-veryslow.mkv", + "uncompressed_video_file": "Scene2_x264_crf-5_preset-veryslow.mkv", + "crf": 51, + "preset_speed": "veryslow" + } +] diff --git a/train_model.py b/train_model.py index 17dae8f..6a92fd7 100644 --- a/train_model.py +++ b/train_model.py @@ -18,10 +18,12 @@ from global_train import LOGGER BATCH_SIZE = 4 EPOCHS = 100 LEARNING_RATE = 0.000001 -TRAIN_SAMPLES = 50 +TRAIN_SAMPLES = 100 MODEL_SAVE_FILE = "models/model.tf" MODEL_CHECKPOINT_DIR = "checkpoints" EARLY_STOP = 10 +WIDTH = 638 +HEIGHT = 360 def load_video_metadata(list_path): LOGGER.trace(f"Entering: load_video_metadata({list_path})") @@ -67,11 +69,11 @@ def load_video_samples(list_path, samples=TRAIN_SAMPLES): compressed_frames, uncompressed_frames = [], [] try: - cap = cv2.VideoCapture(os.path.join("test_data/", video_file)) - cap_uncompressed = cv2.VideoCapture(os.path.join("test_data/", uncompressed_video_file)) + cap = cv2.VideoCapture(os.path.join(os.path.dirname(list_path), video_file)) + cap_uncompressed = cv2.VideoCapture(os.path.join(os.path.dirname(list_path), uncompressed_video_file)) if not cap.isOpened() or not cap_uncompressed.isOpened(): - raise RuntimeError(f"Could not open video files {video_file} or 
{uncompressed_video_file}") + raise RuntimeError(f"Could not open video files {video_file} or {uncompressed_video_file}, searched under: {os.path.dirname(list_path)}") for _ in range(frames_per_video): ret, frame_compressed = cap.read() @@ -79,6 +81,14 @@ def load_video_samples(list_path, samples=TRAIN_SAMPLES): if not ret or not ret_uncompressed: continue + + # Check frame dimensions and resize if necessary + if frame.shape[:2] != (WIDTH, HEIGHT): + LOGGER.warn(f"Resizing video: {video_file}") + frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA) + if frame_compressed.shape[:2] != (WIDTH, HEIGHT): + LOGGER.warn(f"Resizing video: {uncompressed_video_file}") + frame_compressed = cv2.resize(frame_compressed, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA) frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frame_compressed = cv2.cvtColor(frame_compressed, cv2.COLOR_BGR2RGB) @@ -149,8 +159,8 @@ def main(): # Load training and validation samples LOGGER.debug("Loading training and validation samples.") - training_samples = load_video_samples("test_data/training.json") - validation_samples = load_video_samples("test_data/validation.json", args.training_samples // 2) + training_samples = load_video_samples("test_data/training/training.json") + validation_samples = load_video_samples("test_data/validation/validation.json", args.training_samples // 2) train_generator = VideoDataGenerator(training_samples, args.batch_size) val_generator = VideoDataGenerator(validation_samples, args.batch_size)