This commit is contained in:
Jordon Brooks 2023-09-10 19:05:52 +01:00
parent 4d29fffba1
commit 8df4df7972
No known key found for this signature in database
GPG key ID: 83964894E5D98D57
3 changed files with 51 additions and 20 deletions

View file

@ -42,8 +42,8 @@ from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER, clear_screen, load_vid
# Constants
BATCH_SIZE = 25
EPOCHS = 100
LEARNING_RATE = 0.0001
EPOCHS = 1000
LEARNING_RATE = 0.005
DECAY_STEPS = 160
DECAY_RATE = 0.9
MODEL_SAVE_FILE = "models/model.tf"
@ -66,8 +66,10 @@ class ImageLoggingCallback(Callback):
return np.stack(converted, axis=0)
def on_epoch_end(self, epoch, logs=None):
# where total_batches is the number of batches in the validation dataset
skip_batches = np.random.randint(0, 100)
# Get the first batch from the validation dataset
validation_data = next(iter(self.validation_dataset.take(1)))
validation_data = next(iter(self.validation_dataset.skip(skip_batches).take(1)))
# Extract the inputs from the batch_input_images dictionary
actual_images = validation_data[0]['image']
@ -82,7 +84,7 @@ class ImageLoggingCallback(Callback):
# Save the reconstructed frame to the specified folder
reconstructed_path = os.path.join(self.log_dir, f"epoch_{epoch}.png")
cv2.imwrite(reconstructed_path, reconstructed_frame[0]) # Saving only the first image as an example
cv2.imwrite(reconstructed_path, cv2.cvtColor(reconstructed_frame[0], cv2.COLOR_RGB2BGR)) # Saving only the first image as an example
# Log images to TensorBoard
with self.writer.as_default():
@ -145,13 +147,13 @@ def main():
# Load all video metadata
all_videos = load_video_metadata("test_data/validation/validation.json")
tf.random.set_seed(RANDOM_SEED)
#tf.random.set_seed(RANDOM_SEED)
# Shuffle the data using the specified seed
random.shuffle(all_videos, random.seed(RANDOM_SEED))
# Split into training and validation
split_index = int(0.6 * len(all_videos))
split_index = int(0.7 * len(all_videos))
training_videos = all_videos[:split_index]
validation_videos = all_videos[split_index:]
@ -166,7 +168,14 @@ def main():
if args.continue_training:
MODEL = tf.keras.models.load_model(args.continue_training)
MODEL = tf.keras.models.load_model(args.continue_training, custom_objects={
'VideoCompressionModel': VideoCompressionModel,
'psnr': psnr,
'ssim': ssim,
'combined': combined,
'combined_loss': combined_loss,
'combined_loss_weighted_psnr': combined_loss_weighted_psnr
})
else:
MODEL = VideoCompressionModel()