Optimised the training pipeline

Jordon Brooks 2023-08-16 23:33:16 +01:00
parent 15d8e57da5
commit 6a2449c5cd
2 changed files with 27 additions and 21 deletions
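In outline: preprocess_frame previously rebuilt the constant CRF and speed planes for every single frame. It now only resizes and scales the frame, while data_generator allocates the planes once per video and the new combine_batch helper concatenates them on. A minimal sketch of the hoisting idea, with hypothetical dimensions and values standing in for the globalVars constants:

import numpy as np

HEIGHT, WIDTH = 720, 1280  # stand-ins; the real values come from globalVars

# Before: two full-frame allocations on every single frame
def per_frame(frame, crf_scaled, speed_scaled):
    crf_plane = np.full((HEIGHT, WIDTH, 1), crf_scaled)
    speed_plane = np.full((HEIGHT, WIDTH, 1), speed_scaled)
    return np.concatenate([frame, crf_plane, speed_plane], axis=-1)

# After: the planes are built once per video and reused for every frame
crf_plane = np.full((HEIGHT, WIDTH, 1), 23 / 51)  # hypothetical CRF of 23
speed_plane = np.full((HEIGHT, WIDTH, 1), 8 / 9)  # hypothetical preset index

def per_video_frame(frame):
    return np.concatenate([frame, crf_plane, speed_plane], axis=-1)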

Changed file 1 of 2:

@@ -48,23 +48,14 @@ def psnr(y_true, y_pred):
     return 10.0 * K.log((max_pixel ** 2) / (K.mean(K.square(y_pred - y_true)))) / K.log(10.0)
 
-def preprocess_frame(frame, crf, speed):
+def preprocess_frame(frame):
     #Preprocesses a single frame, cropping it if needed
 
     # Check frame dimensions and resize if necessary
     if frame.shape[:2] != (HEIGHT, WIDTH):
-        frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_NEAREST)
+        frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_LINEAR)
 
     # Scale frame to [0, 1]
     compressed_frame = frame / 255.0
 
-    # Scale CRF and SPEED to [0, 1] (assuming they are within known bounds)
-    crf_scaled = crf / 51
-    speed_scaled = speed / NUM_PRESET_SPEEDS
-
-    # Create images with the CRF and SPEED values, filling extra channels
-    crf_image = np.full((HEIGHT, WIDTH, 1), crf_scaled) # Note the added dimension
-    speed_image = np.full((HEIGHT, WIDTH, 1), speed_scaled) # Note the added dimension
-
-    # Combine the frames with the CRF and SPEED images
-    combined_frame = np.concatenate([compressed_frame, crf_image, speed_image], axis=-1)
-
-    return combined_frame
+    return compressed_frame
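Two notes on this hunk: the resize switches from nearest-neighbour to bilinear interpolation (INTER_LINEAR), presumably trading a little speed for smoother resampling, and the function now returns only the scaled colour frame rather than a frame with CRF and speed channels appended. A quick shape check of the new behaviour, with stand-in dimensions and a copy of the function for illustration:

import cv2
import numpy as np

HEIGHT, WIDTH = 720, 1280  # stand-ins for the globalVars values

def preprocess_frame(frame):
    # Resize if needed, then scale to [0, 1]
    if frame.shape[:2] != (HEIGHT, WIDTH):
        frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_LINEAR)
    return frame / 255.0

frame = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
out = preprocess_frame(frame)
assert out.shape == (HEIGHT, WIDTH, 3)   # colour channels only, no extra planes
assert 0.0 <= out.min() <= out.max() <= 1.0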

Changed file 2 of 2:

@@ -13,6 +13,15 @@ from globalVars import HEIGHT, LOGGER, NUM_COLOUR_CHANNELS, NUM_PRESET_SPEEDS, P
 #policy = Policy('mixed_float16')
 #tf.keras.mixed_precision.set_global_policy(policy)
 
+def combine_batch(frame, crf_array, speed_array):
+    # Preprocess the compressed frame (target)
+    processed_frame = preprocess_frame(frame)
+
+    # Combine the frames with the CRF and SPEED images
+    combined = np.concatenate([processed_frame, crf_array, speed_array], axis=-1)
+
+    return combined
+
 def data_generator(videos, batch_size):
     # Infinite loop to keep generating batches
     while True:
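Since combine_batch concatenates along the last axis, a NUM_COLOUR_CHANNELS frame plus the two single-channel planes comes out at NUM_COLOUR_CHANNELS + 2 channels. A minimal check mirroring the helper, assuming 3 colour channels:

import numpy as np

HEIGHT, WIDTH = 720, 1280  # stand-ins

processed_frame = np.zeros((HEIGHT, WIDTH, 3))    # what preprocess_frame returns
crf_array = np.full((HEIGHT, WIDTH, 1), 23 / 51)  # hypothetical values
speed_array = np.full((HEIGHT, WIDTH, 1), 8 / 9)

combined = np.concatenate([processed_frame, crf_array, speed_array], axis=-1)
assert combined.shape == (HEIGHT, WIDTH, 5)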
@@ -24,7 +33,15 @@ def data_generator(videos, batch_size):
             uncompressed_video_path = os.path.join(base_dir, video_details["original_video_file"])
 
             CRF = video_details["crf"] / 51
-            SPEED = PRESET_SPEED_CATEGORIES.index(video_details["preset_speed"])
+            SPEED = PRESET_SPEED_CATEGORIES.index(video_details["preset_speed"]) / NUM_PRESET_SPEEDS
+
+            # Create images with the CRF and SPEED values, filling extra channels
+            compressed_crf_array = np.full((HEIGHT, WIDTH, 1), CRF) # Note the added dimension
+            compressed_speed_array = np.full((HEIGHT, WIDTH, 1), SPEED) # Note the added dimension
+
+            # Create images with the CRF and SPEED values, filling extra channels
+            uncompressed_crf_array = np.full((HEIGHT, WIDTH, 1), 0) # Note the added dimension
+            uncompressed_speed_array = np.full((HEIGHT, WIDTH, 1), PRESET_SPEED_CATEGORIES.index("veryslow") / NUM_PRESET_SPEEDS) # Note the added dimension
 
             # Open the video files
             cap_compressed = cv2.VideoCapture(video_path)
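A worked example of the scaling above, assuming PRESET_SPEED_CATEGORIES matches the usual nine x264 presets (an assumption; the real list lives in globalVars). The compressed input gets its video's CRF and preset, while the uncompressed target is labelled as CRF 0 at "veryslow":

PRESET_SPEED_CATEGORIES = [  # assumed ordering, for illustration only
    "ultrafast", "superfast", "veryfast", "faster", "fast",
    "medium", "slow", "slower", "veryslow",
]
NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)  # 9 under this assumption

CRF = 23 / 51                                                      # e.g. crf 23 -> ~0.451
SPEED = PRESET_SPEED_CATEGORIES.index("slow") / NUM_PRESET_SPEEDS  # 6 / 9 -> ~0.667

# Target-side constants, as in the hunk above
target_crf = 0
target_speed = PRESET_SPEED_CATEGORIES.index("veryslow") / NUM_PRESET_SPEEDS  # 8 / 9 -> ~0.889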
@@ -41,15 +58,13 @@ def data_generator(videos, batch_size):
                 if not ret_compressed or not ret_uncompressed:
                     break
 
                 # Preprocess the compressed frame (target)
-                compressed_frame = preprocess_frame(compressed_frame, CRF, SPEED)
+                compressed_combined = combine_batch(compressed_frame, compressed_crf_array, compressed_speed_array)
 
                 # Preprocess the uncompressed frame (input)
-                uncompressed_frame = preprocess_frame(uncompressed_frame, 0, PRESET_SPEED_CATEGORIES.index("veryslow")) # Modify if different preprocessing is needed for target frames
+                uncompressed_combined = combine_batch(uncompressed_frame, uncompressed_crf_array, uncompressed_speed_array)
 
                 # Append processed frames to batches
-                compressed_frame_batch.append(compressed_frame)
-                uncompressed_frame_batch.append(uncompressed_frame)
+                compressed_frame_batch.append(compressed_combined)
+                uncompressed_frame_batch.append(uncompressed_combined)
 
                 # If batch is complete, yield it
                 if len(compressed_frame_batch) == batch_size:
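The yield itself falls outside this diff, so how the batches leave the generator is not shown here; a hypothetical consumer, assuming it yields (input, target) arrays the way tf.keras expects:

# Hypothetical usage; data_generator's yield statement is not part of this diff.
gen = data_generator(videos, batch_size=8)  # videos as loaded elsewhere in the project
inputs, targets = next(gen)

# With 3 colour channels plus the CRF and speed planes:
# inputs.shape == targets.shape == (8, HEIGHT, WIDTH, 5)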