Optimised the training pipeline
This commit is contained in:
parent
15d8e57da5
commit
6a2449c5cd
2 changed files with 27 additions and 21 deletions
|
@ -13,6 +13,15 @@ from globalVars import HEIGHT, LOGGER, NUM_COLOUR_CHANNELS, NUM_PRESET_SPEEDS, P
|
|||
#policy = Policy('mixed_float16')
|
||||
#tf.keras.mixed_precision.set_global_policy(policy)
|
||||
|
||||
def combine_batch(frame, crf_array, speed_array):
|
||||
# Preprocess the compressed frame (target)
|
||||
processed_frame = preprocess_frame(frame)
|
||||
|
||||
# Combine the frames with the CRF and SPEED images
|
||||
combined = np.concatenate([processed_frame, crf_array, speed_array], axis=-1)
|
||||
|
||||
return combined
|
||||
|
||||
def data_generator(videos, batch_size):
|
||||
# Infinite loop to keep generating batches
|
||||
while True:
|
||||
|
@ -24,7 +33,15 @@ def data_generator(videos, batch_size):
|
|||
uncompressed_video_path = os.path.join(base_dir, video_details["original_video_file"])
|
||||
|
||||
CRF = video_details["crf"] / 51
|
||||
SPEED = PRESET_SPEED_CATEGORIES.index(video_details["preset_speed"])
|
||||
SPEED = PRESET_SPEED_CATEGORIES.index(video_details["preset_speed"]) / NUM_PRESET_SPEEDS
|
||||
|
||||
# Create images with the CRF and SPEED values, filling extra channels
|
||||
compressed_crf_array = np.full((HEIGHT, WIDTH, 1), CRF) # Note the added dimension
|
||||
compressed_speed_array = np.full((HEIGHT, WIDTH, 1), SPEED) # Note the added dimension
|
||||
|
||||
# Create images with the CRF and SPEED values, filling extra channels
|
||||
uncompressed_crf_array = np.full((HEIGHT, WIDTH, 1), 0) # Note the added dimension
|
||||
uncompressed_speed_array = np.full((HEIGHT, WIDTH, 1), PRESET_SPEED_CATEGORIES.index("veryslow") / NUM_PRESET_SPEEDS) # Note the added dimension
|
||||
|
||||
# Open the video files
|
||||
cap_compressed = cv2.VideoCapture(video_path)
|
||||
|
@ -41,15 +58,13 @@ def data_generator(videos, batch_size):
|
|||
if not ret_compressed or not ret_uncompressed:
|
||||
break
|
||||
|
||||
# Preprocess the compressed frame (target)
|
||||
compressed_frame = preprocess_frame(compressed_frame, CRF, SPEED)
|
||||
compressed_combined = combine_batch(compressed_frame, compressed_crf_array, compressed_speed_array)
|
||||
|
||||
# Preprocess the uncompressed frame (input)
|
||||
uncompressed_frame = preprocess_frame(uncompressed_frame, 0, PRESET_SPEED_CATEGORIES.index("veryslow")) # Modify if different preprocessing is needed for target frames
|
||||
uncompressed_combined = combine_batch(uncompressed_frame, uncompressed_crf_array, uncompressed_speed_array)
|
||||
|
||||
# Append processed frames to batches
|
||||
compressed_frame_batch.append(compressed_frame)
|
||||
uncompressed_frame_batch.append(uncompressed_frame)
|
||||
compressed_frame_batch.append(compressed_combined)
|
||||
uncompressed_frame_batch.append(uncompressed_combined)
|
||||
|
||||
# If batch is complete, yield it
|
||||
if len(compressed_frame_batch) == batch_size:
|
||||
|
|
Reference in a new issue