This commit is contained in:
Jordon Brooks 2023-09-10 19:05:52 +01:00
parent 4d29fffba1
commit 8df4df7972
No known key found for this signature in database
GPG key ID: 83964894E5D98D57
3 changed files with 51 additions and 20 deletions

View file

@@ -78,7 +78,7 @@ def create_dataset(videos, batch_size, max_frames=None):
output_signature=output_signature
)
dataset = dataset.batch(batch_size).shuffle(20).prefetch(1)
dataset = dataset.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
return dataset
@@ -160,13 +160,16 @@ class VideoCompressionModel(tf.keras.Model):
# New shape: [batch_size, 1, 1, 128]
crf_speed_features = tf.reshape(crf_speed_features, [-1, 1, 1, 128])
# Tile the tensor to match spatial dimensions of encoded tensor
# Tiled shape: [batch_size, 90, 160, 128]
crf_speed_features = tf.tile(crf_speed_features, [1, 90, 160, 1])
# Pass the image through the encoder
encoded = self.encoder(image)
# Dynamically compute the spatial dimensions of the encoded tensor
encoded_shape = tf.shape(encoded)
height, width = encoded_shape[1], encoded_shape[2]
# Tile the crf_speed_features tensor to match the spatial dimensions of the encoded tensor
crf_speed_features = tf.tile(crf_speed_features, [1, height, width, 1])
# Concatenate the encoded tensor with the crf_speed_features tensor
combined_features = tf.concat([encoded, crf_speed_features], axis=-1)
@@ -176,3 +179,4 @@ class VideoCompressionModel(tf.keras.Model):
return decoded