test
This commit is contained in:
parent
8c5001166d
commit
5085c87300
3 changed files with 96 additions and 173 deletions
|
@ -2,7 +2,6 @@ import tensorflow as tf
|
|||
|
||||
PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
|
||||
NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
|
||||
NUM_FRAMES = 5 # Number of consecutive frames in a sequence
|
||||
NUM_CHANNELS = 3 # Number of color channels in the video frames (RGB images have 3 channels)
|
||||
|
||||
#policy = tf.keras.mixed_precision.Policy('mixed_float16')
|
||||
|
@ -13,7 +12,6 @@ class VideoCompressionModel(tf.keras.Model):
|
|||
super(VideoCompressionModel, self).__init__()
|
||||
|
||||
self.NUM_CHANNELS = NUM_CHANNELS
|
||||
self.NUM_FRAMES = NUM_FRAMES
|
||||
|
||||
# Regularization
|
||||
self.regularizer = tf.keras.regularizers.l2(regularization_factor)
|
||||
|
@ -23,21 +21,24 @@ class VideoCompressionModel(tf.keras.Model):
|
|||
|
||||
# Encoder layers
|
||||
self.encoder = tf.keras.Sequential([
|
||||
tf.keras.layers.Conv3D(32, (3, 3, 3), activation='relu', padding='same', input_shape=(None, None, None, NUM_CHANNELS + 1 + 16), kernel_regularizer=self.regularizer),
|
||||
tf.keras.layers.MaxPooling3D((2, 2, 2)),
|
||||
tf.keras.layers.ZeroPadding2D(padding=((1, 1), (1, 1))), # Padding to preserve spatial dimensions
|
||||
tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same', kernel_regularizer=self.regularizer),
|
||||
tf.keras.layers.MaxPooling2D((2, 2)),
|
||||
# Add more encoder layers as needed
|
||||
])
|
||||
|
||||
# Decoder layers
|
||||
self.decoder = tf.keras.Sequential([
|
||||
tf.keras.layers.Conv3DTranspose(32, (3, 3, 3), activation='relu', padding='same', kernel_regularizer=self.regularizer),
|
||||
tf.keras.layers.UpSampling3D((2, 2, 2)),
|
||||
tf.keras.layers.Conv2DTranspose(32, (3, 3), activation='relu', padding='same', kernel_regularizer=self.regularizer),
|
||||
tf.keras.layers.UpSampling2D((2, 2)),
|
||||
# Add more decoder layers as needed
|
||||
tf.keras.layers.Conv3D(NUM_CHANNELS, (3, 3, 3), activation='sigmoid', padding='same', kernel_regularizer=self.regularizer) # Output layer for video frames
|
||||
tf.keras.layers.Conv2D(NUM_CHANNELS, (3, 3), activation='sigmoid', padding='same', kernel_regularizer=self.regularizer), # Output layer for video frames
|
||||
tf.keras.layers.Cropping2D(cropping=((1, 1), (1, 1))) # Adjust cropping to ensure dimensions match
|
||||
|
||||
])
|
||||
|
||||
def call(self, inputs):
|
||||
frames = inputs["frames"]
|
||||
frame = inputs["frame"]
|
||||
crf = tf.expand_dims(inputs["crf"], -1)
|
||||
preset_speed = inputs["preset_speed"]
|
||||
|
||||
|
@ -46,15 +47,15 @@ class VideoCompressionModel(tf.keras.Model):
|
|||
preset_embedding = tf.keras.layers.Flatten()(preset_embedding)
|
||||
|
||||
# Concatenate crf and preset_embedding to frames
|
||||
frames_shape = tf.shape(frames)
|
||||
repeated_crf = tf.tile(tf.reshape(crf, (-1, 1, 1, 1, 1)), [1, frames_shape[1], frames_shape[2], frames_shape[3], 1])
|
||||
repeated_preset = tf.tile(tf.reshape(preset_embedding, (-1, 1, 1, 1, 16)), [1, frames_shape[1], frames_shape[2], frames_shape[3], 1])
|
||||
frame_shape = tf.shape(frame)
|
||||
repeated_crf = tf.tile(tf.reshape(crf, (-1, 1, 1, 1)), [1, frame_shape[1], frame_shape[2], 1])
|
||||
repeated_preset = tf.tile(tf.reshape(preset_embedding, (-1, 1, 1, 16)), [1, frame_shape[1], frame_shape[2], 1])
|
||||
|
||||
frames = tf.concat([frames, repeated_crf, repeated_preset], axis=-1)
|
||||
frame = tf.concat([tf.cast(frame, tf.float32), repeated_crf, repeated_preset], axis=-1)
|
||||
|
||||
# Encoding the video frames
|
||||
compressed_representation = self.encoder(frames)
|
||||
# Encoding the frame
|
||||
compressed_representation = self.encoder(frame)
|
||||
|
||||
# Decoding to generate compressed video frames
|
||||
reconstructed_frames = self.decoder(compressed_representation)
|
||||
return reconstructed_frames[:,-1,:,:,:]
|
||||
# Decoding to generate compressed frame
|
||||
reconstructed_frame = self.decoder(compressed_representation)
|
||||
return reconstructed_frame
|
||||
|
|
Reference in a new issue