85 lines
3 KiB
Python
85 lines
3 KiB
Python
import tensorflow as tf
|
|
import numpy as np
|
|
import cv2
|
|
from video_compression_model import NUM_FRAMES, PRESET_SPEED_CATEGORIES, VideoCompressionModel
|
|
|
|
# --- Configuration ----------------------------------------------------------

# Channel count of a decoded frame (frames are converted to RGB below).
NUM_CHANNELS = 3

# Source clip that is fed through the compression model.
UNCOMPRESSED_VIDEO_FILE = 'test_data/training_video.mkv'

# Restore the trained network; the custom model class has to be registered so
# Keras can deserialize it from the saved file.
model = tf.keras.models.load_model(
    'models/model.keras',
    custom_objects={'VideoCompressionModel': VideoCompressionModel},
)
|
|
|
|
def load_frames_from_video(video_file, num_frames=0):
    """Decode frames from *video_file* and return them in RGB channel order.

    Args:
        video_file: Source handed to ``cv2.VideoCapture`` (e.g. a file path).
        num_frames: Maximum number of frames to read. ``0`` (the default) or
            any non-positive value means "no limit": read until the capture
            stops returning frames.

    Returns:
        A list of frames, each converted from OpenCV's BGR order to RGB.
        The list is empty when no frame could be read (for instance when the
        file could not be opened), which callers use as an error signal.
    """
    print("Extracting video frames...")

    cap = cv2.VideoCapture(video_file)
    frames = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("Max frames from file reached")
                break

            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

            # Only a positive limit caps the count; 0 must not stop the loop
            # after the very first frame.
            if num_frames > 0 and len(frames) >= num_frames:
                print("Max Frames wanted reached: ", num_frames)
                break
    finally:
        # Release the capture handle even if reading/conversion raises.
        cap.release()

    print("Extraction Complete")
    return frames
|
|
|
|
# Step 3 (cont.): decode up to 100 frames and scale pixel values to [0, 1].
uncompressed_frames = load_frames_from_video(UNCOMPRESSED_VIDEO_FILE, 100)

if not uncompressed_frames:
    print("IO ERROR!")
    # `exit()` is an interactive-shell helper from the `site` module and
    # reports success (status 0); signal the failure with a non-zero status.
    raise SystemExit(1)

# Normalise in float32 rather than NumPy's default float64: it halves memory
# for a tensor that is duplicated NUM_FRAMES times below, and the model
# presumably computes in float32 anyway (Keras default floatx) -- confirm.
uncompressed_frames = np.asarray(uncompressed_frames, dtype=np.float32) / 255.0

# Build overlapping windows of NUM_FRAMES consecutive frames (stride 1) as
# the model's "frames" input.
if len(uncompressed_frames) < NUM_FRAMES:
    # Without this guard the window list is empty and model.predict fails
    # later with an unrelated-looking shape error.
    print(f"Need at least {NUM_FRAMES} frames to build one sequence, "
          f"got {len(uncompressed_frames)}")
    raise SystemExit(1)

uncompressed_frame_sequences = np.array([
    uncompressed_frames[start:start + NUM_FRAMES]
    for start in range(len(uncompressed_frames) - NUM_FRAMES + 1)
])

# Optional preview of the source frames (debug only). Note: the frames are
# RGB floats here, while cv2.imshow expects BGR, so colours appear swapped.
# for frame in uncompressed_frames:
#     cv2.imshow('Frame', frame)
#     cv2.waitKey(50)  # show each frame for 50 ms
|
|
|
|
|
|
# Step 4: run every frame window through the model with fixed encoder
# settings (CRF 25, "fast" preset), one settings row per window.
preset_speed_index = PRESET_SPEED_CATEGORIES.index("fast")

crf_values = np.full(
    shape=(len(uncompressed_frame_sequences), 1),
    fill_value=25,
    dtype=np.float32,
)
preset_speed_values = np.full(
    shape=(len(uncompressed_frame_sequences), 1),
    fill_value=preset_speed_index,
    dtype=np.float32,
)

compressed_frame_sequences = model.predict(
    {
        "frames": uncompressed_frame_sequences,
        "crf": crf_values,
        "preset_speed": preset_speed_values,
    }
)

# NOTE(review): the output shape is defined by VideoCompressionModel, which
# is not visible here. If it returns whole windows (N, NUM_FRAMES, H, W, C)
# rather than single frames, keep only the last frame of each window before
# saving -- confirm against the model:
# compressed_frames = compressed_frame_sequences[:, -1]
#
# Optional preview of the model output (debug only):
# for frame in compressed_frame_sequences:
#     cv2.imshow('Compressed Frame', frame)
#     cv2.waitKey(50)

# Step 5: destination for the re-encoded frames.
COMPRESSED_VIDEO_FILE = 'compressed_video.mkv'
|
|
|
|
def save_frames_as_video(frames, video_file, fps=24.0, fourcc='XVID'):
    """Encode RGB *frames* into the video file *video_file*.

    Args:
        frames: Indexable sequence (list or array) of RGB frames whose pixel
            values are expected in [0, 1]; they are scaled by 255, rounded,
            and clipped to [0, 255]. Every frame must have the size of the
            first one, which fixes the writer's frame size.
        video_file: Output path handed to ``cv2.VideoWriter``.
        fps: Frame rate stored in the output (default 24.0, as before).
        fourcc: Four-character codec code (default ``'XVID'``, as before).

    Raises:
        ValueError: If *frames* is empty.
        OSError: If OpenCV cannot open a writer for *video_file*.
    """
    print("Saving video frames...")

    # len() rather than truthiness: *frames* may be a NumPy array.
    if len(frames) == 0:
        raise ValueError("save_frames_as_video: no frames to write")

    height, width = frames[0].shape[:2]
    writer = cv2.VideoWriter(
        video_file, cv2.VideoWriter_fourcc(*fourcc), fps, (width, height)
    )
    # Without this check a codec/container mismatch silently yields no file.
    if not writer.isOpened():
        raise OSError(
            f"could not open video writer for {video_file!r} "
            f"(fourcc={fourcc!r})"
        )

    try:
        for frame in frames:
            # Round before casting: astype() truncates, so float round-off
            # such as 127.9999 would otherwise lose a grey level.
            pixels = np.clip(np.rint(np.asarray(frame) * 255.0), 0, 255)
            pixels = pixels.astype(np.uint8)
            writer.write(cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR))
    finally:
        # Finalise the container even if a frame fails to convert.
        writer.release()
|
|
|
|
# Persist the model output and tell the user the run finished.
save_frames_as_video(
    frames=compressed_frame_sequences,
    video_file=COMPRESSED_VIDEO_FILE,
)
print("Compression completed.")
|