Optimised the training pipeline

This commit is contained in:
Jordon Brooks 2023-08-16 23:33:16 +01:00
parent 15d8e57da5
commit 6a2449c5cd
2 changed files with 27 additions and 21 deletions

View file

@ -13,6 +13,15 @@ from globalVars import HEIGHT, LOGGER, NUM_COLOUR_CHANNELS, NUM_PRESET_SPEEDS, P
#policy = Policy('mixed_float16')
#tf.keras.mixed_precision.set_global_policy(policy)
def combine_batch(frame, crf_array, speed_array):
# Preprocess the compressed frame (target)
processed_frame = preprocess_frame(frame)
# Combine the frames with the CRF and SPEED images
combined = np.concatenate([processed_frame, crf_array, speed_array], axis=-1)
return combined
def data_generator(videos, batch_size):
# Infinite loop to keep generating batches
while True:
@ -24,7 +33,15 @@ def data_generator(videos, batch_size):
uncompressed_video_path = os.path.join(base_dir, video_details["original_video_file"])
CRF = video_details["crf"] / 51
SPEED = PRESET_SPEED_CATEGORIES.index(video_details["preset_speed"])
SPEED = PRESET_SPEED_CATEGORIES.index(video_details["preset_speed"]) / NUM_PRESET_SPEEDS
# Create images with the CRF and SPEED values, filling extra channels
compressed_crf_array = np.full((HEIGHT, WIDTH, 1), CRF) # Note the added dimension
compressed_speed_array = np.full((HEIGHT, WIDTH, 1), SPEED) # Note the added dimension
# Create images with the CRF and SPEED values, filling extra channels
uncompressed_crf_array = np.full((HEIGHT, WIDTH, 1), 0) # Note the added dimension
uncompressed_speed_array = np.full((HEIGHT, WIDTH, 1), PRESET_SPEED_CATEGORIES.index("veryslow") / NUM_PRESET_SPEEDS) # Note the added dimension
# Open the video files
cap_compressed = cv2.VideoCapture(video_path)
@ -41,15 +58,13 @@ def data_generator(videos, batch_size):
if not ret_compressed or not ret_uncompressed:
break
# Preprocess the compressed frame (target)
compressed_frame = preprocess_frame(compressed_frame, CRF, SPEED)
compressed_combined = combine_batch(compressed_frame, compressed_crf_array, compressed_speed_array)
# Preprocess the uncompressed frame (input)
uncompressed_frame = preprocess_frame(uncompressed_frame, 0, PRESET_SPEED_CATEGORIES.index("veryslow")) # Modify if different preprocessing is needed for target frames
uncompressed_combined = combine_batch(uncompressed_frame, uncompressed_crf_array, uncompressed_speed_array)
# Append processed frames to batches
compressed_frame_batch.append(compressed_frame)
uncompressed_frame_batch.append(uncompressed_frame)
compressed_frame_batch.append(compressed_combined)
uncompressed_frame_batch.append(uncompressed_combined)
# If batch is complete, yield it
if len(compressed_frame_batch) == batch_size: