update
This commit is contained in:
parent
3ea1568ad3
commit
7787d0584e
4 changed files with 35 additions and 25 deletions
|
@ -13,17 +13,18 @@ from globalVars import HEIGHT, LOGGER, NUM_COLOUR_CHANNELS, NUM_PRESET_SPEEDS, P
|
|||
#policy = Policy('mixed_float16')
|
||||
#tf.keras.mixed_precision.set_global_policy(policy)
|
||||
|
||||
def combine_batch(frame, crf, speed):
|
||||
# Preprocess the compressed frame (target)
|
||||
processed_frame = preprocess_frame(frame)
|
||||
def combine_batch(frame, crf, speed, include_controls=True, resize=True):
|
||||
processed_frame = preprocess_frame(frame, resize)
|
||||
height, width, _ = processed_frame.shape
|
||||
|
||||
crf_array = np.full((HEIGHT, WIDTH, 1), crf) # Note the added dimension
|
||||
speed_array = np.full((HEIGHT, WIDTH, 1), speed) # Note the added dimension
|
||||
combined = [processed_frame]
|
||||
if include_controls:
|
||||
crf_array = np.full((height, width, 1), crf)
|
||||
speed_array = np.full((height, width, 1), speed)
|
||||
combined.extend([crf_array, speed_array])
|
||||
|
||||
# Combine the frames with the CRF and SPEED images
|
||||
combined = np.concatenate([processed_frame, crf_array, speed_array], axis=-1)
|
||||
|
||||
return combined
|
||||
return np.concatenate(combined, axis=-1)
|
||||
|
||||
|
||||
def data_generator(videos, batch_size):
|
||||
# Infinite loop to keep generating batches
|
||||
|
@ -53,8 +54,10 @@ def data_generator(videos, batch_size):
|
|||
if not ret_compressed or not ret_uncompressed:
|
||||
break
|
||||
|
||||
compressed_combined = combine_batch(compressed_frame, CRF, SPEED)
|
||||
# Target data
|
||||
compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
|
||||
|
||||
# Input data
|
||||
uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
|
||||
|
||||
# Append processed frames to batches
|
||||
|
@ -80,8 +83,8 @@ class VideoCompressionModel(tf.keras.Model):
|
|||
super(VideoCompressionModel, self).__init__()
|
||||
LOGGER.debug("Initializing VideoCompressionModel.")
|
||||
|
||||
# Input shape (includes channels for edges and histogram)
|
||||
input_shape_with_histogram = (HEIGHT, WIDTH, NUM_COLOUR_CHANNELS + 2)
|
||||
# Input shape (includes channels for CRF and SPEED_PRESET)
|
||||
input_shape_with_histogram = (None, None, NUM_COLOUR_CHANNELS + 2)
|
||||
|
||||
# Encoder part of the model
|
||||
self.encoder = tf.keras.Sequential([
|
||||
|
|
Reference in a new issue