update
This commit is contained in:
parent
4d29fffba1
commit
8df4df7972
3 changed files with 51 additions and 20 deletions
|
@ -78,7 +78,7 @@ def create_dataset(videos, batch_size, max_frames=None):
|
|||
output_signature=output_signature
|
||||
)
|
||||
|
||||
dataset = dataset.batch(batch_size).shuffle(20).prefetch(1)
|
||||
dataset = dataset.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
|
||||
|
||||
return dataset
|
||||
|
||||
|
@ -160,13 +160,16 @@ class VideoCompressionModel(tf.keras.Model):
|
|||
# New shape: [batch_size, 1, 1, 128]
|
||||
crf_speed_features = tf.reshape(crf_speed_features, [-1, 1, 1, 128])
|
||||
|
||||
# Tile the tensor to match spatial dimensions of encoded tensor
|
||||
# Tiled shape: [batch_size, 90, 160, 128]
|
||||
crf_speed_features = tf.tile(crf_speed_features, [1, 90, 160, 1])
|
||||
|
||||
# Pass the image through the encoder
|
||||
encoded = self.encoder(image)
|
||||
|
||||
# Dynamically compute the spatial dimensions of the encoded tensor
|
||||
encoded_shape = tf.shape(encoded)
|
||||
height, width = encoded_shape[1], encoded_shape[2]
|
||||
|
||||
# Tile the crf_speed_features tensor to match the spatial dimensions of the encoded tensor
|
||||
crf_speed_features = tf.tile(crf_speed_features, [1, height, width, 1])
|
||||
|
||||
# Concatenate the encoded tensor with the crf_speed_features tensor
|
||||
combined_features = tf.concat([encoded, crf_speed_features], axis=-1)
|
||||
|
||||
|
@ -176,3 +179,4 @@ class VideoCompressionModel(tf.keras.Model):
|
|||
return decoded
|
||||
|
||||
|
||||
|
||||
|
|
Reference in a new issue