sequenced based

This commit is contained in:
Jordon Brooks 2023-07-24 23:56:46 +01:00
parent 80c5f2216d
commit d0f0b21cb5
3 changed files with 150 additions and 61 deletions

View file

@ -1,7 +1,7 @@
import tensorflow as tf
import numpy as np
import cv2
from video_compression_model import VideoCompressionModel
from video_compression_model import NUM_FRAMES, PRESET_SPEED_CATEGORIES, VideoCompressionModel
# Constants
NUM_CHANNELS = 3
@ -10,7 +10,7 @@ NUM_CHANNELS = 3
model = tf.keras.models.load_model('models/model.keras', custom_objects={'VideoCompressionModel': VideoCompressionModel})
# Step 3: Load the uncompressed video
UNCOMPRESSED_VIDEO_FILE = 'test_data/test_video.mkv'
UNCOMPRESSED_VIDEO_FILE = 'test_data/training_video.mkv'
def load_frames_from_video(video_file, num_frames = 0):
print("Extracting video frames...")
@ -32,19 +32,40 @@ def load_frames_from_video(video_file, num_frames = 0):
print("Extraction Complete")
return frames
uncompressed_frames = load_frames_from_video(UNCOMPRESSED_VIDEO_FILE, 200)
if len(uncompressed_frames) == 0 or None:
uncompressed_frames = load_frames_from_video(UNCOMPRESSED_VIDEO_FILE, 100)
if not uncompressed_frames:
print("IO ERROR!")
exit()
uncompressed_frames = np.array(uncompressed_frames) / 255.0
if len(uncompressed_frames) == 0 or None:
print("np.array ERROR!")
exit()
# Generate sequences of frames for prediction
uncompressed_frame_sequences = []
for i in range(len(uncompressed_frames) - NUM_FRAMES + 1):
sequence = uncompressed_frames[i:i+NUM_FRAMES]
uncompressed_frame_sequences.append(sequence)
uncompressed_frame_sequences = np.array(uncompressed_frame_sequences)
#for frame in uncompressed_frames:
# cv2.imshow('Frame', frame)
# cv2.waitKey(50) # Display each frame for 1 second
# Step 4: Compress the video frames using the loaded model
compressed_frames = model.predict(uncompressed_frames)
crf_values = np.full((len(uncompressed_frame_sequences), 1), 25, dtype=np.float32) # Added dtype argument
preset_speed_index = PRESET_SPEED_CATEGORIES.index("fast")
preset_speed_values = np.full((len(uncompressed_frame_sequences), 1), preset_speed_index, dtype=np.float32) # Added dtype argument
compressed_frame_sequences = model.predict({"frames": uncompressed_frame_sequences, "crf": crf_values, "preset_speed": preset_speed_values})
# We'll use the last frame of each sequence as the compressed frame
#compressed_frames = compressed_frame_sequences[:, -1]
#for frame in compressed_frame_sequences:
# cv2.imshow('Compressed Frame', frame)
# cv2.waitKey(50)
# Step 5: Save the compressed video frames
COMPRESSED_VIDEO_FILE = 'compressed_video.mkv'
@ -60,5 +81,5 @@ def save_frames_as_video(frames, video_file):
out.write(frame)
out.release()
save_frames_as_video(compressed_frames, COMPRESSED_VIDEO_FILE)
save_frames_as_video(compressed_frame_sequences, COMPRESSED_VIDEO_FILE)
print("Compression completed.")