update model

This commit is contained in:
Jordon Brooks 2023-08-25 01:54:22 +01:00
parent 98df94b180
commit 9cecaeb9d6
No known key found for this signature in database
GPG key ID: 83964894E5D98D57

View file

@ -89,31 +89,26 @@ class VideoCompressionModel(tf.keras.Model):
# Encoder part of the model
self.encoder = tf.keras.Sequential([
layers.InputLayer(input_shape=input_shape),
layers.Conv2D(64, (3, 3), padding='same'),
#layers.BatchNormalization(),
layers.Conv2D(32, (3, 3), padding='same'),
layers.LeakyReLU(),
layers.MaxPooling2D((2, 2), padding='same'),
layers.SeparableConv2D(32, (3, 3), padding='same'), # Using Separable Convolution
#layers.BatchNormalization(),
layers.Dropout(0.4),
layers.SeparableConv2D(16, (3, 3), padding='same'),
layers.LeakyReLU(),
layers.MaxPooling2D((2, 2), padding='same')
layers.MaxPooling2D((2, 2), padding='same'),
layers.Dropout(0.4),
])
# Decoder part of the model
# Decoder part of the model using Transposed Convolutions for upsampling
self.decoder = tf.keras.Sequential([
layers.Conv2DTranspose(32, (3, 3), padding='same'),
#layers.BatchNormalization(),
layers.Conv2DTranspose(16, (3, 3), padding='same'),
layers.LeakyReLU(),
layers.Conv2DTranspose(64, (3, 3), dilation_rate=2, padding='same'), # Using Dilated Convolution
#layers.BatchNormalization(),
layers.Dropout(0.4),
layers.Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same'),
layers.LeakyReLU(),
# First Sub-Pixel Convolutional Layer
layers.Conv2DTranspose(NUM_COLOUR_CHANNELS * 4, (3, 3), padding='same'), # 4 times the number of color channels for first upscaling by 2
layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=2)), # Sub-Pixel Convolutional Layer with block_size=2
# Second Sub-Pixel Convolutional Layer
layers.Conv2DTranspose(NUM_COLOUR_CHANNELS * 4, (3, 3), padding='same'), # 4 times the number of color channels for second upscaling by 2
layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=2)), # Sub-Pixel Convolutional Layer with block_size=2
layers.Activation('sigmoid')
layers.Dropout(0.4),
layers.UpSampling2D((2, 2)),
layers.Conv2D(NUM_COLOUR_CHANNELS, (3, 3), padding='same', activation='sigmoid')
])