semi-working

This commit is contained in:
Jordon Brooks 2023-08-13 20:48:00 +01:00
parent e7af02cb4f
commit 54fa90247a
4 changed files with 38 additions and 11 deletions

View file

@ -19,8 +19,9 @@ def data_generator(videos, batch_size):
# Iterate over each video
for video_details in videos:
# Get the paths for compressed and original (uncompressed) video files
video_path = os.path.join(os.path.dirname("test_data/validation/validation.json"), video_details["compressed_video_file"])
uncompressed_video_path = os.path.join(os.path.dirname("test_data/validation/validation.json"), video_details["original_video_file"])
base_dir = os.path.dirname("test_data/validation/validation.json")
video_path = os.path.join(base_dir, video_details["compressed_video_file"])
uncompressed_video_path = os.path.join(base_dir, video_details["original_video_file"])
CRF = video_details["crf"] / 51
SPEED = PRESET_SPEED_CATEGORIES.index(video_details["preset_speed"])
@ -87,12 +88,14 @@ class VideoCompressionModel(tf.keras.Model):
tf.keras.layers.UpSampling2D((2, 2)),
tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'),
tf.keras.layers.UpSampling2D((2, 2)),
tf.keras.layers.Conv2DTranspose(1, (3, 3), activation='sigmoid', padding='same')
tf.keras.layers.Conv2DTranspose(NUM_COLOUR_CHANNELS + 2, (3, 3), activation='sigmoid', padding='same')
])
def call(self, inputs):
# Encode the input
#print("Input shape:", inputs.shape)
encoded = self.encoder(inputs)
# Decode the encoded representation
#print("Encoded shape:", encoded.shape)
decoded = self.decoder(encoded)
#print("Decoded shape:", decoded.shape)
return decoded