# Listing metadata from the original paste: 53 lines, 2.4 KiB, Python.
import tensorflow as tf


# Encoder preset names, listed from "ultrafast" to "veryslow". A preset's
# position in this list is the integer index the model's preset embedding
# consumes. (Presumably these mirror the x264/ffmpeg preset names — confirm.)
PRESET_SPEED_CATEGORIES = [
    "ultrafast",
    "superfast",
    "veryfast",
    "faster",
    "fast",
    "medium",
    "slow",
    "slower",
    "veryslow",
]

# Vocabulary size for the preset embedding: one entry per preset name.
NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)

# Number of consecutive frames in a sequence.
NUM_FRAMES = 5

class VideoCompressionModel(tf.keras.Model):
    """3-D convolutional encoder/decoder over a short clip of video frames.

    The clip is conditioned on a CRF value and a preset-speed index. Both are
    broadcast over every frame and pixel and appended as extra channels before
    the encoder runs. The decoder maps the encoded volume back to
    ``NUM_CHANNELS`` channels, and ``call`` returns the last decoded frame.

    Args:
        NUM_CHANNELS: Number of channels in each input/output frame.
        NUM_FRAMES: Stored as ``self.NUM_FRAMES``; not read by ``call``, which
            takes the frame count from the input tensor at run time.
        embedding_dim: Width of the learned preset-speed embedding. Defaults
            to 16, the value that was previously hard-coded.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, NUM_CHANNELS=3, NUM_FRAMES=5, embedding_dim=16, **kwargs):
        super().__init__(**kwargs)

        self.NUM_CHANNELS = NUM_CHANNELS
        self.NUM_FRAMES = NUM_FRAMES
        self.embedding_dim = embedding_dim

        # Embedding layer for preset_speed: one trainable vector per entry of
        # PRESET_SPEED_CATEGORIES.
        self.preset_embedding = tf.keras.layers.Embedding(NUM_PRESET_SPEEDS, embedding_dim)

        # Encoder input channels = frame channels + 1 CRF channel + preset embedding.
        in_channels = NUM_CHANNELS + 1 + embedding_dim

        # Encoder layers
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.Conv3D(32, (3, 3, 3), activation='relu', padding='same',
                                   input_shape=(None, None, None, in_channels)),
            # padding='same' rounds odd sizes up instead of flooring them. With the
            # default 'valid' padding and 5 frames, the pooling windows covered
            # only frames 0-3, and odd heights/widths lost their last row/column.
            # For even sizes both paddings give identical results.
            tf.keras.layers.MaxPooling3D((2, 2, 2), padding='same'),
            # Add more encoder layers as needed
        ])

        # Decoder layers
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.Conv3DTranspose(32, (3, 3, 3), activation='relu', padding='same'),
            tf.keras.layers.UpSampling3D((2, 2, 2)),
            # Add more decoder layers as needed
            tf.keras.layers.Conv3D(NUM_CHANNELS, (3, 3, 3), activation='sigmoid', padding='same'),  # Output layer for video frames
        ])

    def call(self, inputs):
        """Encode and decode a conditioned clip; return its last reconstructed frame.

        Args:
            inputs: Mapping with keys ``"frames"``, ``"crf"`` and ``"preset_speed"``.
                ``frames`` is treated as a 5-D tensor
                (batch, frames, height, width, channels) by the tiling below
                and is assumed to be floating point. ``crf`` and ``preset_speed``
                should hold one value per batch item (presumably shape
                ``(batch,)`` or ``(batch, 1)`` — TODO confirm with callers).
                ``preset_speed`` holds integer indices into
                PRESET_SPEED_CATEGORIES.

        Returns:
            Tensor of shape (batch, height, width, NUM_CHANNELS): the decoded
            frame aligned with the last input frame.
        """
        frames = inputs["frames"]
        frames_shape = tf.shape(frames)
        num_frames = frames_shape[1]
        height = frames_shape[2]
        width = frames_shape[3]

        # Reshape each per-sample conditioning value to (batch, 1, 1, 1, C) so it
        # can be tiled across time and space. Cast to the frame dtype so the
        # concat does not fail when CRF arrives as an integer tensor.
        crf = tf.cast(tf.reshape(inputs["crf"], (-1, 1, 1, 1, 1)), frames.dtype)
        preset_embedding = self.preset_embedding(inputs["preset_speed"])
        preset_embedding = tf.cast(
            tf.reshape(preset_embedding, (-1, 1, 1, 1, self.embedding_dim)), frames.dtype
        )

        # Concatenate crf and preset_embedding to frames as extra channels.
        tile_multiples = [1, num_frames, height, width, 1]
        conditioned = tf.concat(
            [frames, tf.tile(crf, tile_multiples), tf.tile(preset_embedding, tile_multiples)],
            axis=-1,
        )

        # Encoding the video frames
        compressed_representation = self.encoder(conditioned)

        # Decoding to generate compressed video frames
        reconstructed_frames = self.decoder(compressed_representation)

        # 'same' pooling followed by x2 upsampling overshoots odd sizes by one;
        # crop back to the input size so output frames line up with input frames.
        reconstructed_frames = reconstructed_frames[:, :num_frames, :height, :width, :]

        return reconstructed_frames[:, -1, :, :, :]
