updated
This commit is contained in:
parent
db43239b3d
commit
98df94b180
3 changed files with 29 additions and 11 deletions
|
@ -66,14 +66,21 @@ class ImageLoggingCallback(Callback):
|
|||
return np.stack(converted, axis=0)
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
itter = iter(self.validation_dataset)
|
||||
random_idx = np.random.randint(0, BATCH_SIZE)
|
||||
random_idx = np.random.randint(0, MAX_FRAMES - 1)
|
||||
|
||||
validation_data = None
|
||||
dataset_size = 0 # to keep track of the dataset size
|
||||
|
||||
# Loop through the dataset until the chosen index
|
||||
for i, data in enumerate(self.validation_dataset):
|
||||
if i == random_idx:
|
||||
validation_data = data
|
||||
break
|
||||
dataset_size += 1
|
||||
|
||||
if validation_data is None:
|
||||
print(f"Random index exceeds validation dataset size: {dataset_size}. Using last available data.")
|
||||
validation_data = data # assigning the last data seen in the loop to validation_data
|
||||
|
||||
batch_input_images, batch_gt_labels = validation_data
|
||||
|
||||
|
@ -83,6 +90,10 @@ class ImageLoggingCallback(Callback):
|
|||
reconstructed_frame = MODEL.predict(validation_data[0])
|
||||
reconstructed_frame = np.clip(reconstructed_frame * 255.0, 0, 255).astype(np.uint8)
|
||||
|
||||
# Save the reconstructed frame to the specified folder
|
||||
reconstructed_path = os.path.join(self.log_dir, f"epoch_{epoch}.png")
|
||||
cv2.imwrite(reconstructed_path, reconstructed_frame[0]) # Saving only the first image as an example
|
||||
|
||||
batch_input_images = self.convert_images(batch_input_images)
|
||||
batch_gt_labels = self.convert_images(batch_gt_labels)
|
||||
reconstructed_frame = self.convert_images(reconstructed_frame)
|
||||
|
|
Reference in a new issue