diff --git a/video_compression_model.py b/video_compression_model.py index 6ea48f0..795be62 100644 --- a/video_compression_model.py +++ b/video_compression_model.py @@ -40,6 +40,14 @@ class VideoCompressionModel(tf.keras.Model): tf.keras.layers.Dropout(0.3), tf.keras.layers.Conv2D(NUM_CHANNELS, (3, 3), activation='sigmoid', padding='same') # Output layer for video frames ]) + + def model_summary(self): + x1 = tf.keras.Input(shape=(None, None, NUM_CHANNELS), name='uncompressed_frame') + x2 = tf.keras.Input(shape=(None, None, NUM_CHANNELS), name='compressed_frame') + x3 = tf.keras.Input(shape=(1,), name='crf') + x4 = tf.keras.Input(shape=(1,), name='preset_speed') + return tf.keras.Model(inputs=[x1, x2, x3, x4], outputs=self.call({'uncompressed_frame': x1, 'compressed_frame': x2, 'crf': x3, 'preset_speed': x4})).summary() + def call(self, inputs): uncompressed_frame, compressed_frame, crf, preset_speed = inputs['uncompressed_frame'], inputs['compressed_frame'], inputs['crf'], inputs['preset_speed']