Jordon Brooks 2023-09-10 01:20:10 +01:00
parent 9cecaeb9d6
commit 4d29fffba1
No known key found for this signature in database
GPG key ID: 83964894E5D98D57
4 changed files with 126 additions and 67 deletions


@@ -16,7 +16,7 @@ import signal
import numpy as np
from featureExtraction import combined, combined_loss, psnr, ssim
from featureExtraction import combined, combined_loss, combined_loss_weighted_psnr, psnr, ssim
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
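
The new loss is only imported here; its definition lives in featureExtraction.py, which is not part of this commit. Purely as an illustration of the idea (the function name, the psnr_weight parameter and the 50 dB normalisation below are assumptions, not the project's code), a combined loss that weights PSNR more heavily could look like:

import tensorflow as tf

# Hypothetical sketch -- the real combined_loss_weighted_psnr is defined in featureExtraction.py.
def combined_loss_weighted_psnr_sketch(y_true, y_pred, psnr_weight=0.7, max_val=1.0):
    # PSNR is "higher is better"; normalise by ~50 dB and invert it to get a loss term.
    psnr_term = 1.0 - tf.image.psnr(y_true, y_pred, max_val=max_val) / 50.0
    # SSIM is at most 1, so 1 - SSIM is a common structural loss term.
    ssim_term = 1.0 - tf.image.ssim(y_true, y_pred, max_val=max_val)
    # Blend the two, giving the PSNR term the larger weight, and average over the batch.
    return tf.reduce_mean(psnr_weight * psnr_term + (1.0 - psnr_weight) * ssim_term)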
@@ -49,7 +49,7 @@ DECAY_RATE = 0.9
MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints"
EARLY_STOP = 10
RANDOM_SEED = 4576
RANDOM_SEED = 3545
MODEL = None
LOG_DIR = './logs'
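
The seed constant changes from 4576 to 3545; where the seed is applied is outside this hunk. For reference, a generic snippet (not taken from this repository) that makes the usual three RNG sources repeatable looks like this:

import random
import numpy as np
import tensorflow as tf

RANDOM_SEED = 3545

# Generic snippet, not this repository's code: seed Python, NumPy and TensorFlow.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)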
@@ -66,41 +66,27 @@ class ImageLoggingCallback(Callback):
return np.stack(converted, axis=0)
def on_epoch_end(self, epoch, logs=None):
random_idx = np.random.randint(0, MAX_FRAMES - 1)
# Get the first batch from the validation dataset
validation_data = next(iter(self.validation_dataset.take(1)))
validation_data = None
dataset_size = 0 # to keep track of the dataset size
# Loop through the dataset until the chosen index
for i, data in enumerate(self.validation_dataset):
if i == random_idx:
validation_data = data
break
dataset_size += 1
if validation_data is None:
print(f"Random index exceeds validation dataset size: {dataset_size}. Using last available data.")
validation_data = data # assigning the last data seen in the loop to validation_data
batch_input_images, batch_gt_labels = validation_data
batch_input_images = np.clip(batch_input_images[:, :, :, :3] * 255.0, 0, 255).astype(np.uint8)
# Extract the inputs from the batch_input_images dictionary
actual_images = validation_data[0]['image']
batch_gt_labels = validation_data[1]
actual_images = np.clip(actual_images[:, :, :, :3] * 255.0, 0, 255).astype(np.uint8)
batch_gt_labels = np.clip(batch_gt_labels * 255.0, 0, 255).astype(np.uint8)
# Providing all three inputs to the model for prediction
reconstructed_frame = MODEL.predict(validation_data[0])
reconstructed_frame = np.clip(reconstructed_frame * 255.0, 0, 255).astype(np.uint8)
# Save the reconstructed frame to the specified folder
reconstructed_path = os.path.join(self.log_dir, f"epoch_{epoch}.png")
cv2.imwrite(reconstructed_path, reconstructed_frame[0]) # Saving only the first image as an example
batch_input_images = self.convert_images(batch_input_images)
batch_gt_labels = self.convert_images(batch_gt_labels)
reconstructed_frame = self.convert_images(reconstructed_frame)
# Log images to TensorBoard
with self.writer.as_default():
tf.summary.image("Input Images", batch_input_images, step=epoch, max_outputs=1)
tf.summary.image("Input Images", actual_images, step=epoch, max_outputs=1)
tf.summary.image("Ground Truth Labels", batch_gt_labels, step=epoch, max_outputs=1)
tf.summary.image("Reconstructed Frame", reconstructed_frame, step=epoch, max_outputs=3)
self.writer.flush()
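
Because removed and added lines are interleaved above, the new logging path is easier to read in one piece. The sketch below condenses it, reusing the module's globals (os, cv2, np, tf, MODEL); dropping the convert_images call is a simplification, not the file's literal code:

def on_epoch_end(self, epoch, logs=None):
    # Condensed reading of the new code path; simplified, not the literal file contents.
    # Take the first batch from the validation dataset.
    validation_data = next(iter(self.validation_dataset.take(1)))

    # Inputs are now a dict keyed by name; ground-truth labels are the second element.
    actual_images = validation_data[0]['image']
    batch_gt_labels = validation_data[1]

    # Scale both to 8-bit for logging.
    actual_images = np.clip(actual_images[:, :, :, :3] * 255.0, 0, 255).astype(np.uint8)
    batch_gt_labels = np.clip(batch_gt_labels * 255.0, 0, 255).astype(np.uint8)

    # The whole input dict (all three inputs) is passed to the model.
    reconstructed_frame = MODEL.predict(validation_data[0])
    reconstructed_frame = np.clip(reconstructed_frame * 255.0, 0, 255).astype(np.uint8)

    # Save the first reconstruction to disk, then log everything to TensorBoard.
    cv2.imwrite(os.path.join(self.log_dir, f"epoch_{epoch}.png"), reconstructed_frame[0])
    with self.writer.as_default():
        tf.summary.image("Input Images", actual_images, step=epoch, max_outputs=1)
        tf.summary.image("Ground Truth Labels", batch_gt_labels, step=epoch, max_outputs=1)
        tf.summary.image("Reconstructed Frame", reconstructed_frame, step=epoch, max_outputs=3)
    self.writer.flush()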
@@ -196,7 +182,7 @@ def main():
# Set optimizer and compile the model
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
MODEL.compile(loss=combined_loss, optimizer=optimizer, metrics=[psnr, ssim, combined])
MODEL.compile(loss=combined_loss_weighted_psnr, optimizer=optimizer, metrics=[psnr, ssim, combined])
# Define checkpoints and early stopping
checkpoint_callback = ModelCheckpoint(
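
The ModelCheckpoint call above is cut off by the hunk's context, so its arguments are not visible in this commit. As a generic illustration only (the filepath pattern and monitored metric are assumptions), a checkpoint plus early-stopping pair built from the constants defined earlier might be wired up like this:

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# Assumed wiring -- the real arguments are cut off by the diff context.
checkpoint_callback = ModelCheckpoint(
    filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch_{epoch:02d}.tf"),
    monitor="val_loss",
    save_best_only=True,
)
early_stop_callback = EarlyStopping(monitor="val_loss", patience=EARLY_STOP)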