Optimised the training pipeline
This commit is contained in:
parent
15d8e57da5
commit
6a2449c5cd
2 changed files with 27 additions and 21 deletions
|
@ -48,23 +48,14 @@ def psnr(y_true, y_pred):
|
|||
return 10.0 * K.log((max_pixel ** 2) / (K.mean(K.square(y_pred - y_true)))) / K.log(10.0)
|
||||
|
||||
|
||||
def preprocess_frame(frame, crf, speed):
|
||||
def preprocess_frame(frame):
|
||||
# Preprocesses a single frame, resizing it to (HEIGHT, WIDTH) if needed
|
||||
|
||||
# Check frame dimensions and resize if necessary
|
||||
if frame.shape[:2] != (HEIGHT, WIDTH):
|
||||
frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_NEAREST)
|
||||
frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
# Scale frame to [0, 1]
|
||||
compressed_frame = frame / 255.0
|
||||
|
||||
# Scale CRF and SPEED to [0, 1] (assuming they are within known bounds)
|
||||
crf_scaled = crf / 51
|
||||
speed_scaled = speed / NUM_PRESET_SPEEDS
|
||||
|
||||
# Create images with the CRF and SPEED values, filling extra channels
|
||||
crf_image = np.full((HEIGHT, WIDTH, 1), crf_scaled) # Note the added dimension
|
||||
speed_image = np.full((HEIGHT, WIDTH, 1), speed_scaled) # Note the added dimension
|
||||
|
||||
# Combine the frames with the CRF and SPEED images
|
||||
combined_frame = np.concatenate([compressed_frame, crf_image, speed_image], axis=-1)
|
||||
|
||||
return combined_frame
|
||||
return compressed_frame
|
||||
|
|
|
@ -13,6 +13,15 @@ from globalVars import HEIGHT, LOGGER, NUM_COLOUR_CHANNELS, NUM_PRESET_SPEEDS, P
|
|||
#policy = Policy('mixed_float16')
|
||||
#tf.keras.mixed_precision.set_global_policy(policy)
|
||||
|
||||
def combine_batch(frame, crf_array, speed_array):
    """Preprocess ``frame`` and append the CRF and SPEED planes to it.

    The frame is run through ``preprocess_frame`` and the result is joined
    with ``crf_array`` and ``speed_array`` along the last axis, in that
    order.

    Args:
        frame: Raw frame handed unchanged to ``preprocess_frame``.
        crf_array: Array appended directly after the preprocessed frame.
        speed_array: Array appended last.

    Returns:
        The concatenation of the three arrays along ``axis=-1``.
    """
    # NOTE(review): np.concatenate requires the three inputs to agree on every
    # axis except the last; presumably all are (HEIGHT, WIDTH, C) — confirm
    # against the callers.
    channel_stack = (preprocess_frame(frame), crf_array, speed_array)
    return np.concatenate(channel_stack, axis=-1)
|
||||
|
||||
def data_generator(videos, batch_size):
|
||||
# Infinite loop to keep generating batches
|
||||
while True:
|
||||
|
@ -24,7 +33,15 @@ def data_generator(videos, batch_size):
|
|||
uncompressed_video_path = os.path.join(base_dir, video_details["original_video_file"])
|
||||
|
||||
CRF = video_details["crf"] / 51
|
||||
SPEED = PRESET_SPEED_CATEGORIES.index(video_details["preset_speed"])
|
||||
SPEED = PRESET_SPEED_CATEGORIES.index(video_details["preset_speed"]) / NUM_PRESET_SPEEDS
|
||||
|
||||
# Create images with the CRF and SPEED values, filling extra channels
|
||||
compressed_crf_array = np.full((HEIGHT, WIDTH, 1), CRF) # Note the added dimension
|
||||
compressed_speed_array = np.full((HEIGHT, WIDTH, 1), SPEED) # Note the added dimension
|
||||
|
||||
# Create constant CRF and SPEED planes for the uncompressed frames (CRF 0, "veryslow" preset)
|
||||
uncompressed_crf_array = np.full((HEIGHT, WIDTH, 1), 0) # Note the added dimension
|
||||
uncompressed_speed_array = np.full((HEIGHT, WIDTH, 1), PRESET_SPEED_CATEGORIES.index("veryslow") / NUM_PRESET_SPEEDS) # Note the added dimension
|
||||
|
||||
# Open the video files
|
||||
cap_compressed = cv2.VideoCapture(video_path)
|
||||
|
@ -41,15 +58,13 @@ def data_generator(videos, batch_size):
|
|||
if not ret_compressed or not ret_uncompressed:
|
||||
break
|
||||
|
||||
# Preprocess the compressed frame (target)
|
||||
compressed_frame = preprocess_frame(compressed_frame, CRF, SPEED)
|
||||
compressed_combined = combine_batch(compressed_frame, compressed_crf_array, compressed_speed_array)
|
||||
|
||||
# Preprocess the uncompressed frame (input)
|
||||
uncompressed_frame = preprocess_frame(uncompressed_frame, 0, PRESET_SPEED_CATEGORIES.index("veryslow")) # Modify if different preprocessing is needed for target frames
|
||||
uncompressed_combined = combine_batch(uncompressed_frame, uncompressed_crf_array, uncompressed_speed_array)
|
||||
|
||||
# Append processed frames to batches
|
||||
compressed_frame_batch.append(compressed_frame)
|
||||
uncompressed_frame_batch.append(uncompressed_frame)
|
||||
compressed_frame_batch.append(compressed_combined)
|
||||
uncompressed_frame_batch.append(uncompressed_combined)
|
||||
|
||||
# If batch is complete, yield it
|
||||
if len(compressed_frame_batch) == batch_size:
|
||||
|
|
Reference in a new issue