updated
This commit is contained in:
parent
db43239b3d
commit
98df94b180
3 changed files with 29 additions and 11 deletions
|
@ -7,7 +7,7 @@ import numpy as np
|
|||
import tensorflow as tf
|
||||
from tensorflow.keras import layers
|
||||
from featureExtraction import preprocess_frame, scale_crf, scale_speed_preset
|
||||
from globalVars import HEIGHT, LOGGER, NUM_COLOUR_CHANNELS, NUM_PRESET_SPEEDS, PRESET_SPEED_CATEGORIES, WIDTH
|
||||
from globalVars import LOGGER, NUM_COLOUR_CHANNELS, PRESET_SPEED_CATEGORIES
|
||||
|
||||
|
||||
#from tensorflow.keras.mixed_precision import Policy
|
||||
|
@ -107,13 +107,20 @@ class VideoCompressionModel(tf.keras.Model):
|
|||
layers.Conv2DTranspose(64, (3, 3), dilation_rate=2, padding='same'), # Using Dilated Convolution
|
||||
#layers.BatchNormalization(),
|
||||
layers.LeakyReLU(),
|
||||
# Use Sub-Pixel Convolutional Layer
|
||||
layers.Conv2DTranspose(NUM_COLOUR_CHANNELS * 16, (3, 3), padding='same'), # 16 times the number of color channels
|
||||
layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=4)) # Sub-Pixel Convolutional Layer with block_size=4
|
||||
# First Sub-Pixel Convolutional Layer
|
||||
layers.Conv2DTranspose(NUM_COLOUR_CHANNELS * 4, (3, 3), padding='same'), # 4 times the number of color channels for first upscaling by 2
|
||||
layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=2)), # Sub-Pixel Convolutional Layer with block_size=2
|
||||
# Second Sub-Pixel Convolutional Layer
|
||||
layers.Conv2DTranspose(NUM_COLOUR_CHANNELS * 4, (3, 3), padding='same'), # 4 times the number of color channels for second upscaling by 2
|
||||
layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=2)), # Sub-Pixel Convolutional Layer with block_size=2
|
||||
layers.Activation('sigmoid')
|
||||
])
|
||||
|
||||
|
||||
def call(self, inputs):
|
||||
#print(f"Input: {inputs.shape}")
|
||||
encoded = self.encoder(inputs)
|
||||
return self.decoder(encoded)
|
||||
|
||||
|
||||
#print(f"encoded: {encoded.shape}")
|
||||
decoded = self.decoder(encoded)
|
||||
#print(f"decoded: {decoded.shape}")
|
||||
return decoded
|
||||
|
|
Reference in a new issue