optimisation
This commit is contained in:
parent
d0f0b21cb5
commit
b97293d7ca
3 changed files with 112 additions and 97 deletions
|
@ -3,30 +3,37 @@ import tensorflow as tf
|
|||
PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
|
||||
NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
|
||||
NUM_FRAMES = 5 # Number of consecutive frames in a sequence
|
||||
NUM_CHANNELS = 3 # Number of color channels in the video frames (RGB images have 3 channels)
|
||||
|
||||
#policy = tf.keras.mixed_precision.Policy('mixed_float16')
|
||||
#tf.keras.mixed_precision.set_global_policy(policy)
|
||||
|
||||
class VideoCompressionModel(tf.keras.Model):
|
||||
def __init__(self, NUM_CHANNELS=3, NUM_FRAMES=5):
|
||||
def __init__(self, NUM_CHANNELS=3, NUM_FRAMES=5, regularization_factor=1e-4):
|
||||
super(VideoCompressionModel, self).__init__()
|
||||
|
||||
|
||||
self.NUM_CHANNELS = NUM_CHANNELS
|
||||
self.NUM_FRAMES = NUM_FRAMES
|
||||
|
||||
# Regularization
|
||||
self.regularizer = tf.keras.regularizers.l2(regularization_factor)
|
||||
|
||||
# Embedding layer for preset_speed
|
||||
self.preset_embedding = tf.keras.layers.Embedding(NUM_PRESET_SPEEDS, 16)
|
||||
self.preset_embedding = tf.keras.layers.Embedding(NUM_PRESET_SPEEDS, 16, embeddings_regularizer=self.regularizer)
|
||||
|
||||
# Encoder layers
|
||||
self.encoder = tf.keras.Sequential([
|
||||
tf.keras.layers.Conv3D(32, (3, 3, 3), activation='relu', padding='same', input_shape=(None, None, None, NUM_CHANNELS + 1 + 16)), # Notice the adjusted channel number
|
||||
tf.keras.layers.Conv3D(32, (3, 3, 3), activation='relu', padding='same', input_shape=(None, None, None, NUM_CHANNELS + 1 + 16), kernel_regularizer=self.regularizer),
|
||||
tf.keras.layers.MaxPooling3D((2, 2, 2)),
|
||||
# Add more encoder layers as needed
|
||||
])
|
||||
|
||||
# Decoder layers
|
||||
self.decoder = tf.keras.Sequential([
|
||||
tf.keras.layers.Conv3DTranspose(32, (3, 3, 3), activation='relu', padding='same'),
|
||||
tf.keras.layers.Conv3DTranspose(32, (3, 3, 3), activation='relu', padding='same', kernel_regularizer=self.regularizer),
|
||||
tf.keras.layers.UpSampling3D((2, 2, 2)),
|
||||
# Add more decoder layers as needed
|
||||
tf.keras.layers.Conv3D(NUM_CHANNELS, (3, 3, 3), activation='sigmoid', padding='same') # Output layer for video frames
|
||||
tf.keras.layers.Conv3D(NUM_CHANNELS, (3, 3, 3), activation='sigmoid', padding='same', kernel_regularizer=self.regularizer) # Output layer for video frames
|
||||
])
|
||||
|
||||
def call(self, inputs):
|
||||
|
|
Reference in a new issue