diff --git a/DeepEncode.py b/DeepEncode.py
index 283da48..a488562 100644
--- a/DeepEncode.py
+++ b/DeepEncode.py
@@ -4,82 +4,88 @@ import cv2
 from video_compression_model import NUM_FRAMES, PRESET_SPEED_CATEGORIES, VideoCompressionModel
 
 # Constants
-NUM_CHANNELS = 3
+MAX_FRAMES = 24
+CHUNK_SIZE = 24  # Adjust based on available memory and video resolution
+COMPRESSED_VIDEO_FILE = 'compressed_video.mkv'
+
 # Step 2: Load the trained model
-model = tf.keras.models.load_model('models/model.keras', custom_objects={'VideoCompressionModel': VideoCompressionModel})
+model = tf.keras.models.load_model('models/model_differencing.keras', custom_objects={'VideoCompressionModel': VideoCompressionModel})
 
 # Step 3: Load the uncompressed video
 UNCOMPRESSED_VIDEO_FILE = 'test_data/training_video.mkv'
 
-def load_frames_from_video(video_file, num_frames = 0):
-    print("Extracting video frames...")
+def load_frames_from_video(video_file, start_frame=0, num_frames=CHUNK_SIZE):
     cap = cv2.VideoCapture(video_file)
     frames = []
-    count = 0
-    while True:
+    cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
+
+    for _ in range(num_frames):
         ret, frame = cap.read()
         if not ret:
-            print("Max frames from file reached")
             break
-        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0  # Normalize and convert to float32
        frames.append(frame)
-        count += 1
-        if num_frames == 0 or count >= num_frames:
-            print("Max Frames wanted reached: ", num_frames)
-            break
     cap.release()
-    print("Extraction Complete")
     return frames
 
-uncompressed_frames = load_frames_from_video(UNCOMPRESSED_VIDEO_FILE, 100)
-if not uncompressed_frames:
-    print("IO ERROR!")
-    exit()
+def predict_in_chunks(uncompressed_frames, model, crf_values, preset_speed_values):
+    num_sequences = len(uncompressed_frames) - NUM_FRAMES + 1
+    compressed_frames = []
 
-uncompressed_frames = np.array(uncompressed_frames) / 255.0
+    for frame in uncompressed_frames:
+        cv2.imshow("frame", frame)
+        cv2.waitKey(50)
 
-# Generate sequences of frames for prediction
-uncompressed_frame_sequences = []
-for i in range(len(uncompressed_frames) - NUM_FRAMES + 1):
-    sequence = uncompressed_frames[i:i+NUM_FRAMES]
-    uncompressed_frame_sequences.append(sequence)
-uncompressed_frame_sequences = np.array(uncompressed_frame_sequences)
+    for start in range(0, num_sequences, CHUNK_SIZE):
+        end = min(start + CHUNK_SIZE, num_sequences)
+        frame_chunk = uncompressed_frames[start:end + NUM_FRAMES - 1]
+        crf_chunk = crf_values[start:end]
+        speed_chunk = preset_speed_values[start:end]
 
-#for frame in uncompressed_frames:
-#    cv2.imshow('Frame', frame)
-#    cv2.waitKey(50)  # Display each frame for 1 second
+        frame_sequences = []
+        for i in range(len(frame_chunk) - NUM_FRAMES + 1):
+            sequence = frame_chunk[i:i + NUM_FRAMES]
+            frame_sequences.append(sequence)
+
+        frame_sequences = np.array(frame_sequences)
+        compressed_chunk = model.predict({"frames": frame_sequences, "crf": crf_chunk, "preset_speed": speed_chunk})
+        compressed_frames.extend(compressed_chunk)
+
+    return compressed_frames
 
-# Step 4: Compress the video frames using the loaded model
-crf_values = np.full((len(uncompressed_frame_sequences), 1), 25, dtype=np.float32)  # Added dtype argument
-
-preset_speed_index = PRESET_SPEED_CATEGORIES.index("fast")
-preset_speed_values = np.full((len(uncompressed_frame_sequences), 1), preset_speed_index, dtype=np.float32)  # Added dtype argument
-
-compressed_frame_sequences = model.predict({"frames": uncompressed_frame_sequences, "crf": crf_values, "preset_speed": preset_speed_values})
-
-# We'll use the last frame of each sequence as the compressed frame
-#compressed_frames = compressed_frame_sequences[:, -1]
-
-#for frame in compressed_frame_sequences:
-#    cv2.imshow('Compressed Frame', frame)
-#    cv2.waitKey(50)
-
-
-# Step 5: Save the compressed video frames
-COMPRESSED_VIDEO_FILE = 'compressed_video.mkv'
-
-def save_frames_as_video(frames, video_file):
-    print("Saving video frames...")
-    height, width = frames[0].shape[:2]
-    fourcc = cv2.VideoWriter_fourcc(*'XVID')
-    out = cv2.VideoWriter(video_file, fourcc, 24.0, (width, height))
+def save_frames_chunk(frames, video_writer):
     for frame in frames:
         frame = np.clip(frame * 255.0, 0, 255).astype(np.uint8)
         frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
-        out.write(frame)
-    out.release()
+        video_writer.write(frame)
 
-save_frames_as_video(compressed_frame_sequences, COMPRESSED_VIDEO_FILE)
+
+
+cap = cv2.VideoCapture(UNCOMPRESSED_VIDEO_FILE)
+total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+if MAX_FRAMES != 0 and total_frames > MAX_FRAMES:
+    total_frames = MAX_FRAMES
+
+cap.release()
+
+crf_values = np.full((CHUNK_SIZE + NUM_FRAMES - 1, 1), 25, dtype=np.float32)  # Chunk size + look-ahead frames
+preset_speed_index = PRESET_SPEED_CATEGORIES.index("fast")
+preset_speed_values = np.full((CHUNK_SIZE + NUM_FRAMES - 1, 1), preset_speed_index, dtype=np.float32)
+
+out = None  # Video writer instance
+for i in range(0, total_frames, CHUNK_SIZE):
+    uncompressed_frames_chunk = load_frames_from_video(UNCOMPRESSED_VIDEO_FILE, start_frame=i)
+    compressed_frames_chunk = predict_in_chunks(uncompressed_frames_chunk, model, crf_values, preset_speed_values)
+
+    # Initialize video writer if it's the first chunk
+    if out is None:
+        height, width = compressed_frames_chunk[0].shape[:2]
+        fourcc = cv2.VideoWriter_fourcc(*'XVID')
+        out = cv2.VideoWriter(COMPRESSED_VIDEO_FILE, fourcc, 24.0, (width, height))
+
+    save_frames_chunk(compressed_frames_chunk, out)
+
+out.release()
 print("Compression completed.")
diff --git a/train_model.py b/train_model.py
index d410e1c..ce0d2a5 100644
--- a/train_model.py
+++ b/train_model.py
@@ -3,22 +3,19 @@ import json
 import tensorflow as tf
 import numpy as np
 import cv2
-from video_compression_model import NUM_FRAMES, VideoCompressionModel, PRESET_SPEED_CATEGORIES
+from video_compression_model import NUM_CHANNELS, NUM_FRAMES, VideoCompressionModel, PRESET_SPEED_CATEGORIES
+from tensorflow.keras.callbacks import EarlyStopping
 
 # Constants
-NUM_CHANNELS = 3  # Number of color channels in the video frames (RGB images have 3 channels)
-BATCH_SIZE = 16  # Batch size used during training
-EPOCHS = 1  # Number of training epochs
-TRAIN_SAMPLES = 1  # number of frames to extract
-
-# Step 1: Data Preparation
+BATCH_SIZE = 16
+EPOCHS = 1
+TRAIN_SAMPLES = 1
 
 def load_list(list_path):
     with open(list_path, "r") as json_file:
         video_details_list = json.load(json_file)
     return video_details_list
 
-# Update load_frames_from_video function to resize frames
 def load_frames_from_video(video_file, num_frames):
     print("Extracting video frames...")
     cap = cv2.VideoCapture(video_file)
@@ -29,7 +26,6 @@ def load_frames_from_video(video_file, num_frames):
         if not ret:
             break
         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-        #frame = cv2.resize(frame, (target_width, target_height))
         frames.append(frame)
         count += 1
         if count >= num_frames:
@@ -46,7 +42,6 @@ def load_frames_from_video(video_file, num_frames):
     model.save(os.path.join("models/", file))
     print("Model saved successfully!")
 
-# Update load_video_from_list function to provide target_width and target_height
 def load_video_from_list(list_path):
     details_list = load_list(list_path)
     all_frames = []
@@ -57,8 +52,6 @@ def load_video_from_list(list_path):
         PRESET_SPEED = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed'])
         video_details['preset_speed'] = PRESET_SPEED
 
-        # Update load_frames_from_video calls with target_width and target_height
-        #train_frames, w, h = load_frames_from_video(os.path.join("test_data/", VIDEO_FILE), TRAIN_SAMPLES, target_width, target_height)
         train_frames, w, h = load_frames_from_video(os.path.join("test_data/", VIDEO_FILE), NUM_FRAMES * TRAIN_SAMPLES)
         all_frames.extend(train_frames)
         all_details.append({
@@ -72,52 +65,61 @@ def load_video_from_list(list_path):
     return all_details
 
 def generate_frame_sequences(frames):
-    # Generate sequences of frames for the model
     sequences = []
     labels = []
-    for i in range(len(frames) - NUM_FRAMES + 1):
-        sequence = frames[i:i+NUM_FRAMES]
+    for i in range(len(frames) - NUM_FRAMES + 2):
+        sequence = frames[i:i+NUM_FRAMES-1]
         sequences.append(sequence)
-        # Use the last frame of the sequence as the label
         labels.append(sequence[-1])
     return np.array(sequences), np.array(labels)
 
+def frame_difference(frames):
+    differences = []
+    for i in range(1, len(frames)):
+        differences.append(cv2.absdiff(frames[i], frames[i-1]))
+    return differences
 
 def main():
-    #target_width = 640  # Choose a fixed width for the frames
-    #target_height = 360  # Choose a fixed height for the frames
-
-    all_video_details = load_video_from_list("test_data/training.json")
+    all_video_details_train = load_video_from_list("test_data/training.json")
+    all_video_details_val = load_video_from_list("test_data/validation.json")
 
     model = VideoCompressionModel(NUM_CHANNELS, NUM_FRAMES)
     model.compile(loss='mean_squared_error', optimizer='adam')
 
-    for video_details in all_video_details:
-        train_frames = video_details["frames"]
-        val_frames = train_frames.copy()  # For simplicity, using the same frames for validation
+    early_stop = EarlyStopping(monitor='val_loss', patience=3, verbose=1, restore_best_weights=True)
 
-        train_frames = preprocess(train_frames)
-        val_frames = preprocess(val_frames)
+    for video_details_train, video_details_val in zip(all_video_details_train, all_video_details_val):
+        train_frames = video_details_train["frames"]
+        val_frames = video_details_val["frames"]
 
-        train_sequences, train_labels = generate_frame_sequences(train_frames)
-        val_sequences, val_labels = generate_frame_sequences(val_frames)
+        train_differences = frame_difference(preprocess(train_frames))
+        val_differences = frame_difference(preprocess(val_frames))
 
-        num_sequences = len(train_sequences)
-        crf_array = np.full((num_sequences, 1), video_details['crf'])
-        preset_speed_array = np.full((num_sequences, 1), video_details['preset_speed'])
+        train_sequences, train_labels = generate_frame_sequences(train_differences)
+        val_sequences, val_labels = generate_frame_sequences(val_differences)
 
-        print("\nTraining the model for video:", video_details["video_file"])
+        num_sequences_train = len(train_sequences)
+        num_sequences_val = len(val_sequences)
+        crf_array_train = np.full((num_sequences_train, 1), video_details_train['crf'])
+        crf_array_val = np.full((num_sequences_val, 1), video_details_val['crf'])
+        preset_speed_array_train = np.full((num_sequences_train, 1), video_details_train['preset_speed'])
+        preset_speed_array_val = np.full((num_sequences_val, 1), video_details_val['preset_speed'])
+
+        print(len(train_sequences))
+        print(len(val_sequences))
+
+        print("\nTraining the model for video:", video_details_train["video_file"])
         model.fit(
-            {"frames": train_sequences, "crf": crf_array, "preset_speed": preset_speed_array},
-            train_labels,  # Use train_labels as the ground truth
+            {"frames": train_sequences, "crf": crf_array_train, "preset_speed": preset_speed_array_train},
+            train_labels,
             batch_size=BATCH_SIZE,
            epochs=EPOCHS,
-            validation_data=({"frames": val_sequences, "crf": crf_array, "preset_speed": preset_speed_array},
-                             val_labels)  # Use val_labels as the ground truth for validation
+            validation_data=({"frames": val_sequences, "crf": crf_array_val, "preset_speed": preset_speed_array_val}, val_labels),
+            callbacks=[early_stop]
         )
-        print("\nTraining completed for video:", video_details["video_file"])
+        print("\nTraining completed for video:", video_details_train["video_file"])
 
-    save_model(model, 'model.keras')
+    save_model(model, 'model_differencing.keras')
 
 if __name__ == "__main__":
     main()
diff --git a/video_compression_model.py b/video_compression_model.py
index 47cc0b8..a74753b 100644
--- a/video_compression_model.py
+++ b/video_compression_model.py
@@ -3,30 +3,37 @@ import tensorflow as tf
 PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
 NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
 NUM_FRAMES = 5  # Number of consecutive frames in a sequence
+NUM_CHANNELS = 3  # Number of color channels in the video frames (RGB images have 3 channels)
+
+#policy = tf.keras.mixed_precision.Policy('mixed_float16')
+#tf.keras.mixed_precision.set_global_policy(policy)
 
 class VideoCompressionModel(tf.keras.Model):
-    def __init__(self, NUM_CHANNELS=3, NUM_FRAMES=5):
+    def __init__(self, NUM_CHANNELS=3, NUM_FRAMES=5, regularization_factor=1e-4):
         super(VideoCompressionModel, self).__init__()
-
+        self.NUM_CHANNELS = NUM_CHANNELS
         self.NUM_FRAMES = NUM_FRAMES
 
+        # Regularization
+        self.regularizer = tf.keras.regularizers.l2(regularization_factor)
+
         # Embedding layer for preset_speed
-        self.preset_embedding = tf.keras.layers.Embedding(NUM_PRESET_SPEEDS, 16)
+        self.preset_embedding = tf.keras.layers.Embedding(NUM_PRESET_SPEEDS, 16, embeddings_regularizer=self.regularizer)
 
         # Encoder layers
         self.encoder = tf.keras.Sequential([
-            tf.keras.layers.Conv3D(32, (3, 3, 3), activation='relu', padding='same', input_shape=(None, None, None, NUM_CHANNELS + 1 + 16)),  # Notice the adjusted channel number
+            tf.keras.layers.Conv3D(32, (3, 3, 3), activation='relu', padding='same', input_shape=(None, None, None, NUM_CHANNELS + 1 + 16), kernel_regularizer=self.regularizer),
             tf.keras.layers.MaxPooling3D((2, 2, 2)),
             # Add more encoder layers as needed
         ])
 
         # Decoder layers
         self.decoder = tf.keras.Sequential([
-            tf.keras.layers.Conv3DTranspose(32, (3, 3, 3), activation='relu', padding='same'),
+            tf.keras.layers.Conv3DTranspose(32, (3, 3, 3), activation='relu', padding='same', kernel_regularizer=self.regularizer),
             tf.keras.layers.UpSampling3D((2, 2, 2)),
             # Add more decoder layers as needed
-            tf.keras.layers.Conv3D(NUM_CHANNELS, (3, 3, 3), activation='sigmoid', padding='same')  # Output layer for video frames
+            tf.keras.layers.Conv3D(NUM_CHANNELS, (3, 3, 3), activation='sigmoid', padding='same', kernel_regularizer=self.regularizer)  # Output layer for video frames
         ])
 
     def call(self, inputs):