Added model_summary function

This commit is contained in:
Jordon Brooks 2023-07-30 12:57:40 +01:00
parent dea59068fb
commit 5bca78e687

View file

@ -41,6 +41,14 @@ class VideoCompressionModel(tf.keras.Model):
tf.keras.layers.Conv2D(NUM_CHANNELS, (3, 3), activation='sigmoid', padding='same') # Output layer for video frames
])
def model_summary(self):
x1 = tf.keras.Input(shape=(None, None, NUM_CHANNELS), name='uncompressed_frame')
x2 = tf.keras.Input(shape=(None, None, NUM_CHANNELS), name='compressed_frame')
x3 = tf.keras.Input(shape=(1,), name='crf')
x4 = tf.keras.Input(shape=(1,), name='preset_speed')
return tf.keras.Model(inputs=[x1, x2, x3, x4], outputs=self.call({'uncompressed_frame': x1, 'compressed_frame': x2, 'crf': x3, 'preset_speed': x4})).summary()
def call(self, inputs):
uncompressed_frame, compressed_frame, crf, preset_speed = inputs['uncompressed_frame'], inputs['compressed_frame'], inputs['crf'], inputs['preset_speed']