This commit is contained in:
Jordon Brooks 2023-08-17 01:57:53 +01:00
parent 3ea1568ad3
commit 7787d0584e
4 changed files with 35 additions and 25 deletions

View file

@ -13,17 +13,18 @@ from globalVars import HEIGHT, LOGGER, NUM_COLOUR_CHANNELS, NUM_PRESET_SPEEDS, P
#policy = Policy('mixed_float16')
#tf.keras.mixed_precision.set_global_policy(policy)
def combine_batch(frame, crf, speed):
# Preprocess the compressed frame (target)
processed_frame = preprocess_frame(frame)
def combine_batch(frame, crf, speed, include_controls=True, resize=True):
processed_frame = preprocess_frame(frame, resize)
height, width, _ = processed_frame.shape
crf_array = np.full((HEIGHT, WIDTH, 1), crf) # Note the added dimension
speed_array = np.full((HEIGHT, WIDTH, 1), speed) # Note the added dimension
combined = [processed_frame]
if include_controls:
crf_array = np.full((height, width, 1), crf)
speed_array = np.full((height, width, 1), speed)
combined.extend([crf_array, speed_array])
# Combine the frames with the CRF and SPEED images
combined = np.concatenate([processed_frame, crf_array, speed_array], axis=-1)
return combined
return np.concatenate(combined, axis=-1)
def data_generator(videos, batch_size):
# Infinite loop to keep generating batches
@ -53,8 +54,10 @@ def data_generator(videos, batch_size):
if not ret_compressed or not ret_uncompressed:
break
compressed_combined = combine_batch(compressed_frame, CRF, SPEED)
# Target data
compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
# Input data
uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
# Append processed frames to batches
@ -80,8 +83,8 @@ class VideoCompressionModel(tf.keras.Model):
super(VideoCompressionModel, self).__init__()
LOGGER.debug("Initializing VideoCompressionModel.")
# Input shape (includes channels for edges and histogram)
input_shape_with_histogram = (HEIGHT, WIDTH, NUM_COLOUR_CHANNELS + 2)
# Input shape (includes channels for CRF and SPEED_PRESET)
input_shape_with_histogram = (None, None, NUM_COLOUR_CHANNELS + 2)
# Encoder part of the model
self.encoder = tf.keras.Sequential([