Added GC
The data set will now process frames from ALL videos
This commit is contained in:
parent
b95e3558ff
commit
ba6c132c67
1 changed files with 37 additions and 44 deletions
|
@ -1,5 +1,6 @@
|
||||||
# video_compression_model.py
|
# video_compression_model.py
|
||||||
|
|
||||||
|
import gc
|
||||||
import os
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -18,6 +19,7 @@ def combine_batch(frame, crf, speed, include_controls=True, resize=True):
|
||||||
height, width, _ = processed_frame.shape
|
height, width, _ = processed_frame.shape
|
||||||
|
|
||||||
combined = [processed_frame]
|
combined = [processed_frame]
|
||||||
|
|
||||||
if include_controls:
|
if include_controls:
|
||||||
crf_array = np.full((height, width, 1), crf)
|
crf_array = np.full((height, width, 1), crf)
|
||||||
speed_array = np.full((height, width, 1), speed)
|
speed_array = np.full((height, width, 1), speed)
|
||||||
|
@ -27,56 +29,52 @@ def combine_batch(frame, crf, speed, include_controls=True, resize=True):
|
||||||
|
|
||||||
|
|
||||||
def data_generator(videos, batch_size):
|
def data_generator(videos, batch_size):
|
||||||
# Infinite loop to keep generating batches
|
base_dir = os.path.dirname("test_data/validation/validation.json")
|
||||||
while True:
|
|
||||||
# Iterate over each video
|
|
||||||
for video_details in videos:
|
|
||||||
# Get the paths for compressed and original (uncompressed) video files
|
|
||||||
base_dir = os.path.dirname("test_data/validation/validation.json")
|
|
||||||
video_path = os.path.join(base_dir, video_details["compressed_video_file"])
|
|
||||||
uncompressed_video_path = os.path.join(base_dir, video_details["original_video_file"])
|
|
||||||
|
|
||||||
CRF = scale_crf(video_details["crf"])
|
|
||||||
SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(video_details["preset_speed"]))
|
|
||||||
|
|
||||||
# Open the video files
|
|
||||||
cap_compressed = cv2.VideoCapture(video_path)
|
|
||||||
cap_uncompressed = cv2.VideoCapture(uncompressed_video_path)
|
|
||||||
|
|
||||||
# Lists to store the processed frames
|
|
||||||
compressed_frame_batch = [] # Input data (Target)
|
|
||||||
uncompressed_frame_batch = [] # Target data (Training)
|
|
||||||
|
|
||||||
# Read and process frames from both videos
|
while True:
|
||||||
while cap_compressed.isOpened() and cap_uncompressed.isOpened():
|
# Lists to store the processed frames
|
||||||
|
compressed_frame_batch = [] # Input data (Target)
|
||||||
|
uncompressed_frame_batch = [] # Target data (Training)
|
||||||
|
|
||||||
|
# Get a list of video capture objects for all videos
|
||||||
|
caps_compressed = [cv2.VideoCapture(os.path.join(base_dir, video["compressed_video_file"])) for video in videos]
|
||||||
|
caps_uncompressed = [cv2.VideoCapture(os.path.join(base_dir, video["original_video_file"])) for video in videos]
|
||||||
|
|
||||||
|
# As long as any video can provide frames, keep running
|
||||||
|
while any(cap.isOpened() for cap in caps_compressed):
|
||||||
|
for idx, (cap_compressed, cap_uncompressed) in enumerate(zip(caps_compressed, caps_uncompressed)):
|
||||||
|
#print(f"(Video Change) Processing video {idx}") # Print statement to indicate video change
|
||||||
|
|
||||||
ret_compressed, compressed_frame = cap_compressed.read()
|
ret_compressed, compressed_frame = cap_compressed.read()
|
||||||
ret_uncompressed, uncompressed_frame = cap_uncompressed.read()
|
ret_uncompressed, uncompressed_frame = cap_uncompressed.read()
|
||||||
|
|
||||||
if not ret_compressed or not ret_uncompressed:
|
if not ret_compressed or not ret_uncompressed:
|
||||||
break
|
continue
|
||||||
|
|
||||||
# Target data
|
CRF = scale_crf(videos[idx]["crf"])
|
||||||
|
SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(videos[idx]["preset_speed"]))
|
||||||
|
|
||||||
compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
|
compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
|
||||||
|
|
||||||
# Input data
|
|
||||||
uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
|
uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
|
||||||
|
|
||||||
# Append processed frames to batches
|
|
||||||
compressed_frame_batch.append(compressed_combined)
|
compressed_frame_batch.append(compressed_combined)
|
||||||
uncompressed_frame_batch.append(uncompressed_combined)
|
uncompressed_frame_batch.append(uncompressed_combined)
|
||||||
|
|
||||||
# If batch is complete, yield it
|
|
||||||
if len(compressed_frame_batch) == batch_size:
|
if len(compressed_frame_batch) == batch_size:
|
||||||
yield (np.array(uncompressed_frame_batch), np.array(compressed_frame_batch)) # Yielding Training and Target data
|
yield (np.array(uncompressed_frame_batch), np.array(compressed_frame_batch))
|
||||||
compressed_frame_batch = []
|
compressed_frame_batch.clear()
|
||||||
uncompressed_frame_batch = []
|
uncompressed_frame_batch.clear()
|
||||||
|
|
||||||
# Release video files
|
# Close all video captures at the end
|
||||||
cap_compressed.release()
|
for cap in caps_compressed + caps_uncompressed:
|
||||||
cap_uncompressed.release()
|
cap.release()
|
||||||
|
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
# If there are frames left that don't fill a whole batch, send them anyway
|
||||||
|
if len(compressed_frame_batch) > 0:
|
||||||
|
yield (np.array(uncompressed_frame_batch), np.array(compressed_frame_batch))
|
||||||
|
|
||||||
# If there are frames left that don't fill a whole batch, send them anyway
|
|
||||||
if len(compressed_frame_batch) > 0:
|
|
||||||
yield (np.array(uncompressed_frame_batch), np.array(compressed_frame_batch))
|
|
||||||
|
|
||||||
class VideoCompressionModel(tf.keras.Model):
|
class VideoCompressionModel(tf.keras.Model):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -105,10 +103,5 @@ class VideoCompressionModel(tf.keras.Model):
|
||||||
])
|
])
|
||||||
|
|
||||||
def call(self, inputs):
|
def call(self, inputs):
|
||||||
#print("Input shape:", inputs.shape)
|
return self.decoder(self.encoder(inputs))
|
||||||
encoded = self.encoder(inputs)
|
|
||||||
#print("Encoded shape:", encoded.shape)
|
|
||||||
decoded = self.decoder(encoded)
|
|
||||||
#print("Decoded shape:", decoded.shape)
|
|
||||||
return decoded
|
|
||||||
|
|
||||||
|
|
Reference in a new issue