update

parent 4d29fffba1
commit 8df4df7972
3 changed files with 51 additions and 20 deletions
@@ -42,8 +42,8 @@ from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER, clear_screen, load_video_metadata
 
 # Constants
 BATCH_SIZE = 25
-EPOCHS = 100
-LEARNING_RATE = 0.0001
+EPOCHS = 1000
+LEARNING_RATE = 0.005
 DECAY_STEPS = 160
 DECAY_RATE = 0.9
 MODEL_SAVE_FILE = "models/model.tf"
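The bumped EPOCHS and raised LEARNING_RATE sit next to DECAY_STEPS and DECAY_RATE, which suggests an exponentially decaying learning-rate schedule. A minimal sketch of how these constants could be wired into an optimizer follows; the optimizer setup itself is not part of this diff, so this wiring is an assumption, not this repository's code.

import tensorflow as tf

# Assumed use of the constants above; illustrative only.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.005,  # LEARNING_RATE
    decay_steps=160,              # DECAY_STEPS
    decay_rate=0.9,               # DECAY_RATE
)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)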
@@ -66,8 +66,10 @@ class ImageLoggingCallback(Callback):
         return np.stack(converted, axis=0)
 
     def on_epoch_end(self, epoch, logs=None):
+        # where total_batches is the number of batches in the validation dataset
+        skip_batches = np.random.randint(0, 100)
         # Get the first batch from the validation dataset
-        validation_data = next(iter(self.validation_dataset.take(1)))
+        validation_data = next(iter(self.validation_dataset.skip(skip_batches).take(1)))
 
         # Extract the inputs from the batch_input_images dictionary
         actual_images = validation_data[0]['image']
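The new skip_batches draw assumes the validation dataset holds at least 100 batches; if it holds fewer, skip() yields an empty dataset and next() raises StopIteration. A minimal bounds-checked sketch, following the total_batches idea from the comment above (the helper name is hypothetical):

import numpy as np
import tensorflow as tf

def random_validation_batch(validation_dataset):
    # Number of batches, if tf.data can determine it (negative means unknown/infinite).
    total_batches = int(tf.data.experimental.cardinality(validation_dataset).numpy())
    if total_batches <= 0:
        total_batches = 1
    skip_batches = np.random.randint(0, total_batches)
    return next(iter(validation_dataset.skip(skip_batches).take(1)))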
@@ -82,7 +84,7 @@ class ImageLoggingCallback(Callback):
 
         # Save the reconstructed frame to the specified folder
         reconstructed_path = os.path.join(self.log_dir, f"epoch_{epoch}.png")
-        cv2.imwrite(reconstructed_path, reconstructed_frame[0]) # Saving only the first image as an example
+        cv2.imwrite(reconstructed_path, cv2.cvtColor(reconstructed_frame[0], cv2.COLOR_RGB2BGR)) # Saving only the first image as an example
 
         # Log images to TensorBoard
         with self.writer.as_default():
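The added cv2.cvtColor call matters because OpenCV's imwrite interprets the last axis as BGR, while the reconstructed frame is presumably RGB; writing the raw array swaps red and blue in the saved preview. A small self-contained illustration (file names are arbitrary):

import cv2
import numpy as np

rgb_frame = np.zeros((64, 64, 3), dtype=np.uint8)
rgb_frame[..., 0] = 255  # pure red in RGB ordering

cv2.imwrite("direct.png", rgb_frame)  # interpreted as BGR, so it comes out blue
cv2.imwrite("converted.png", cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2BGR))  # comes out red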
@@ -145,13 +147,13 @@ def main():
     # Load all video metadata
     all_videos = load_video_metadata("test_data/validation/validation.json")
 
-    tf.random.set_seed(RANDOM_SEED)
+    #tf.random.set_seed(RANDOM_SEED)
 
     # Shuffle the data using the specified seed
     random.shuffle(all_videos, random.seed(RANDOM_SEED))
 
     # Split into training and validation
-    split_index = int(0.6 * len(all_videos))
+    split_index = int(0.7 * len(all_videos))
     training_videos = all_videos[:split_index]
     validation_videos = all_videos[split_index:]
 
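One caveat in the unchanged context above: random.shuffle()'s second argument expects a zero-argument function returning floats and was removed in Python 3.11, while random.seed() returns None, so the call only stays deterministic through the side effect of seeding the global RNG. A hedged sketch of an equivalent seeded 70/30 split that avoids that form (the function name is illustrative):

import random

def split_videos(all_videos, seed, train_fraction=0.7):
    # Deterministic shuffle without touching the global RNG state.
    rng = random.Random(seed)
    shuffled = list(all_videos)
    rng.shuffle(shuffled)
    split_index = int(train_fraction * len(shuffled))
    return shuffled[:split_index], shuffled[split_index:]

# Usage mirroring the lines above:
# training_videos, validation_videos = split_videos(all_videos, RANDOM_SEED)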
@@ -166,7 +168,14 @@ def main():
 
 
     if args.continue_training:
-        MODEL = tf.keras.models.load_model(args.continue_training)
+        MODEL = tf.keras.models.load_model(args.continue_training, custom_objects={
+            'VideoCompressionModel': VideoCompressionModel,
+            'psnr': psnr,
+            'ssim': ssim,
+            'combined': combined,
+            'combined_loss': combined_loss,
+            'combined_loss_weighted_psnr': combined_loss_weighted_psnr
+        })
     else:
         MODEL = VideoCompressionModel()
 
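custom_objects is needed here because Keras cannot deserialize a saved model that references a custom subclass or custom losses/metrics it does not know by name. A hedged alternative is to register each custom symbol with tf.keras.utils.register_keras_serializable at import time so load_model can resolve them without an explicit mapping; the bodies below are placeholders, not this repository's implementations.

import tensorflow as tf

@tf.keras.utils.register_keras_serializable(package="video_compression")
class VideoCompressionModel(tf.keras.Model):
    def call(self, inputs):
        return inputs  # placeholder forward pass, not the real architecture

@tf.keras.utils.register_keras_serializable(package="video_compression")
def combined_loss(y_true, y_pred):
    return tf.reduce_mean(tf.square(y_true - y_pred))  # placeholder loss

# With registrations like these, the call in this diff could drop custom_objects:
# MODEL = tf.keras.models.load_model(args.continue_training)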