This commit is contained in:
Jordon Brooks 2023-08-19 01:45:58 +01:00
parent 9f34cf8074
commit 2b4664fcbb
5 changed files with 34 additions and 11 deletions

View file

@ -9,8 +9,9 @@ TODO:
import argparse
import json
import os
import random
from featureExtraction import psnr
from featureExtraction import combined, psnr, ssim
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
@ -110,17 +111,29 @@ def main():
LOGGER.info(f"Learning rate: {LEARNING_RATE}")
LOGGER.info(f"Max Frames: {MAX_FRAMES}")
LOGGER.info(f"Continue training from: {MODEL_SAVE_FILE}")
LOGGER.info(f"Decay Steps: {DECAY_STEPS}")
LOGGER.info(f"Decay Rate: {DECAY_RATE}")
LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")
# Load all video metadata
all_videos = load_video_metadata("test_data/validation/validation.json")
# Specify the random seed
random_seed = 4576 # You can change this to any desired value
# Shuffle the data using the specified seed
random.shuffle(all_videos, random.seed(random_seed))
# Split into training and validation
split_index = int(0.8 * len(all_videos))
split_index = int(0.6 * len(all_videos))
training_videos = all_videos[:split_index]
validation_videos = all_videos[split_index:]
LOGGER.info(f"Training videos: {training_videos}")
LOGGER.info(f"==================================")
LOGGER.info(f"Validation videos: {validation_videos}")
training_dataset = create_dataset(training_videos, BATCH_SIZE, MAX_FRAMES)
validation_dataset = create_dataset(validation_videos, BATCH_SIZE, MAX_FRAMES)
@ -142,7 +155,7 @@ def main():
# Set optimizer and compile the model
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
model.compile(loss='mse', optimizer=optimizer, metrics=[psnr])
model.compile(loss='mse', optimizer=optimizer, metrics=[psnr, ssim, combined])
# Define checkpoints and early stopping
checkpoint_callback = ModelCheckpoint(
@ -152,7 +165,7 @@ def main():
verbose=1,
save_format="tf"
)
early_stop = EarlyStopping(monitor='val_psnr', mode='max', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
early_stop = EarlyStopping(monitor='val_combined', mode='max', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
# Custom garbage collection callback
gc_callback = GarbageCollectorCallback()