model now uses tensorflow dataset generator
This commit is contained in:
parent
ba6c132c67
commit
f06d3ae504
2 changed files with 84 additions and 48 deletions
|
@ -28,52 +28,86 @@ def combine_batch(frame, crf, speed, include_controls=True, resize=True):
|
|||
return np.concatenate(combined, axis=-1)
|
||||
|
||||
|
||||
def data_generator(videos, batch_size):
    """Yield training batches of (uncompressed, compressed) frames, drawing
    frames round-robin across all videos so batches mix content.

    Args:
        videos: list of dicts with keys "compressed_video_file",
            "original_video_file", "crf" and "preset_speed"; file paths are
            resolved relative to the validation JSON's directory.
        batch_size: number of frame pairs per yielded batch.

    Yields:
        (uncompressed_batch, compressed_batch) numpy arrays. The uncompressed
        side is the training target; the compressed side is the model input.
    """
    base_dir = os.path.dirname("test_data/validation/validation.json")

    # Accumulators for the current in-progress batch.
    compressed_frame_batch = []    # Input data (Target)
    uncompressed_frame_batch = []  # Target data (Training)

    # One capture object per video, for both encodings, opened up front.
    caps_compressed = [cv2.VideoCapture(os.path.join(base_dir, video["compressed_video_file"])) for video in videos]
    caps_uncompressed = [cv2.VideoCapture(os.path.join(base_dir, video["original_video_file"])) for video in videos]

    # As long as any video can still provide frames, keep running.
    while any(cap.isOpened() for cap in caps_compressed):
        for idx, (cap_compressed, cap_uncompressed) in enumerate(zip(caps_compressed, caps_uncompressed)):
            ret_compressed, compressed_frame = cap_compressed.read()
            ret_uncompressed, uncompressed_frame = cap_uncompressed.read()

            if not ret_compressed or not ret_uncompressed:
                # This pair is exhausted. Release both captures so isOpened()
                # turns False; otherwise isOpened() stays True after EOF and
                # the enclosing while-loop would never terminate.
                cap_compressed.release()
                cap_uncompressed.release()
                continue

            CRF = scale_crf(videos[idx]["crf"])
            SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(videos[idx]["preset_speed"]))

            # The uncompressed target is tagged with CRF 0 / "veryslow",
            # i.e. the highest-quality encode settings.
            compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
            uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))

            compressed_frame_batch.append(compressed_combined)
            uncompressed_frame_batch.append(uncompressed_combined)

            if len(compressed_frame_batch) == batch_size:
                yield (np.array(uncompressed_frame_batch), np.array(compressed_frame_batch))
                compressed_frame_batch.clear()
                uncompressed_frame_batch.clear()

    # Close all video captures at the end (release() is a no-op if already
    # released above).
    for cap in caps_compressed + caps_uncompressed:
        cap.release()

    # If there are frames left that don't fill a whole batch, send them anyway
    if len(compressed_frame_batch) > 0:
        yield (np.array(uncompressed_frame_batch), np.array(compressed_frame_batch))
def frame_generator(videos, max_frames=None):
    """Yield (uncompressed, compressed) frame pairs, one video at a time.

    Unlike data_generator, frames are produced sequentially per video and
    unbatched — intended as the source for tf.data.Dataset.from_generator.

    Args:
        videos: list of dicts with keys "compressed_video_file",
            "original_video_file", "crf" and "preset_speed"; file paths are
            resolved relative to base_dir.
        max_frames: optional cap on frames taken from EACH video; None means
            read every frame until the shorter of the two captures ends.

    Yields:
        (uncompressed_combined, compressed_combined) as produced by
        combine_batch for a single frame.
    """
    base_dir = "test_data/validation/"
    for video in videos:
        cap_compressed = cv2.VideoCapture(os.path.join(base_dir, video["compressed_video_file"]))
        cap_uncompressed = cv2.VideoCapture(os.path.join(base_dir, video["original_video_file"]))

        # NOTE(review): destroys any OpenCV GUI windows between videos; a
        # no-op in headless runs — confirm whether this is still needed.
        cv2.destroyAllWindows()

        frame_count = 0
        while True:
            ret_compressed, compressed_frame = cap_compressed.read()
            ret_uncompressed, uncompressed_frame = cap_uncompressed.read()

            # Stop at the end of the shorter stream.
            if not ret_compressed or not ret_uncompressed:
                break

            CRF = scale_crf(video["crf"])
            SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(video["preset_speed"]))

            # Uncompressed target is tagged with CRF 0 / "veryslow"
            # (highest-quality encode settings).
            compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
            uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))

            yield uncompressed_combined, compressed_combined

            frame_count += 1
            if max_frames is not None and frame_count >= max_frames:
                break

        cap_compressed.release()
        cap_uncompressed.release()
def create_dataset(videos, batch_size, max_frames=None):
    """Wrap frame_generator in a shuffled, batched, prefetching tf.data.Dataset.

    Args:
        videos: list of video-description dicts (see frame_generator).
        batch_size: batch dimension applied by Dataset.batch.
        max_frames: optional per-video frame cap, forwarded to frame_generator.

    Returns:
        A tf.data.Dataset yielding (uncompressed_batch, compressed_batch)
        float32 tensors.
    """
    # Determine the output signature by pulling a single sample to learn its
    # per-frame shape. The samples are numpy arrays, so .shape gives the
    # static shape TensorSpec expects (tf.shape would return a runtime
    # Tensor, which is not a valid TensorSpec shape).
    video_generator_instance = frame_generator(videos, max_frames)
    sample_uncompressed, sample_compressed = next(video_generator_instance)
    output_signature = (
        tf.TensorSpec(shape=sample_uncompressed.shape, dtype=tf.float32),
        tf.TensorSpec(shape=sample_compressed.shape, dtype=tf.float32),
    )

    # The lambda re-creates the generator on each epoch, so the sample frame
    # consumed above is not lost from training data.
    dataset = tf.data.Dataset.from_generator(
        lambda: frame_generator(videos, max_frames),
        output_signature=output_signature,
    )

    dataset = dataset.shuffle(100).batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

    return dataset
||||
class VideoCompressionModel(tf.keras.Model):
|
||||
|
|
Reference in a new issue