Update
This commit is contained in:
parent
9cecaeb9d6
commit
4d29fffba1
4 changed files with 126 additions and 67 deletions
|
@ -16,7 +16,7 @@ import signal
|
|||
|
||||
import numpy as np
|
||||
|
||||
from featureExtraction import combined, combined_loss, psnr, ssim
|
||||
from featureExtraction import combined, combined_loss, combined_loss_weighted_psnr, psnr, ssim
|
||||
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
||||
|
||||
|
@ -49,7 +49,7 @@ DECAY_RATE = 0.9
|
|||
MODEL_SAVE_FILE = "models/model.tf"
|
||||
MODEL_CHECKPOINT_DIR = "checkpoints"
|
||||
EARLY_STOP = 10
|
||||
RANDOM_SEED = 4576
|
||||
RANDOM_SEED = 3545
|
||||
MODEL = None
|
||||
LOG_DIR = './logs'
|
||||
|
||||
|
@ -66,41 +66,27 @@ class ImageLoggingCallback(Callback):
|
|||
return np.stack(converted, axis=0)
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
random_idx = np.random.randint(0, MAX_FRAMES - 1)
|
||||
# Get the first batch from the validation dataset
|
||||
validation_data = next(iter(self.validation_dataset.take(1)))
|
||||
|
||||
validation_data = None
|
||||
dataset_size = 0 # to keep track of the dataset size
|
||||
|
||||
# Loop through the dataset until the chosen index
|
||||
for i, data in enumerate(self.validation_dataset):
|
||||
if i == random_idx:
|
||||
validation_data = data
|
||||
break
|
||||
dataset_size += 1
|
||||
|
||||
if validation_data is None:
|
||||
print(f"Random index exceeds validation dataset size: {dataset_size}. Using last available data.")
|
||||
validation_data = data # assigning the last data seen in the loop to validation_data
|
||||
|
||||
batch_input_images, batch_gt_labels = validation_data
|
||||
|
||||
batch_input_images = np.clip(batch_input_images[:, :, :, :3] * 255.0, 0, 255).astype(np.uint8)
|
||||
# Extract the inputs from the batch_input_images dictionary
|
||||
actual_images = validation_data[0]['image']
|
||||
batch_gt_labels = validation_data[1]
|
||||
|
||||
actual_images = np.clip(actual_images[:, :, :, :3] * 255.0, 0, 255).astype(np.uint8)
|
||||
batch_gt_labels = np.clip(batch_gt_labels * 255.0, 0, 255).astype(np.uint8)
|
||||
|
||||
|
||||
# Providing all three inputs to the model for prediction
|
||||
reconstructed_frame = MODEL.predict(validation_data[0])
|
||||
reconstructed_frame = np.clip(reconstructed_frame * 255.0, 0, 255).astype(np.uint8)
|
||||
|
||||
# Save the reconstructed frame to the specified folder
|
||||
reconstructed_path = os.path.join(self.log_dir, f"epoch_{epoch}.png")
|
||||
cv2.imwrite(reconstructed_path, reconstructed_frame[0]) # Saving only the first image as an example
|
||||
|
||||
batch_input_images = self.convert_images(batch_input_images)
|
||||
batch_gt_labels = self.convert_images(batch_gt_labels)
|
||||
reconstructed_frame = self.convert_images(reconstructed_frame)
|
||||
|
||||
# Log images to TensorBoard
|
||||
with self.writer.as_default():
|
||||
tf.summary.image("Input Images", batch_input_images, step=epoch, max_outputs=1)
|
||||
tf.summary.image("Input Images", actual_images, step=epoch, max_outputs=1)
|
||||
tf.summary.image("Ground Truth Labels", batch_gt_labels, step=epoch, max_outputs=1)
|
||||
tf.summary.image("Reconstructed Frame", reconstructed_frame, step=epoch, max_outputs=3)
|
||||
self.writer.flush()
|
||||
|
@ -196,7 +182,7 @@ def main():
|
|||
|
||||
# Set optimizer and compile the model
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
|
||||
MODEL.compile(loss=combined_loss, optimizer=optimizer, metrics=[psnr, ssim, combined])
|
||||
MODEL.compile(loss=combined_loss_weighted_psnr, optimizer=optimizer, metrics=[psnr, ssim, combined])
|
||||
|
||||
# Define checkpoints and early stopping
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
|
|
Reference in a new issue