Better
This commit is contained in:
parent
9f34cf8074
commit
2b4664fcbb
5 changed files with 34 additions and 11 deletions
|
@ -9,8 +9,9 @@ TODO:
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
from featureExtraction import psnr
|
||||
from featureExtraction import combined, psnr, ssim
|
||||
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
||||
|
||||
|
@ -110,17 +111,29 @@ def main():
|
|||
LOGGER.info(f"Learning rate: {LEARNING_RATE}")
|
||||
LOGGER.info(f"Max Frames: {MAX_FRAMES}")
|
||||
LOGGER.info(f"Continue training from: {MODEL_SAVE_FILE}")
|
||||
LOGGER.info(f"Decay Steps: {DECAY_STEPS}")
|
||||
LOGGER.info(f"Decay Rate: {DECAY_RATE}")
|
||||
|
||||
LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")
|
||||
|
||||
# Load all video metadata
|
||||
all_videos = load_video_metadata("test_data/validation/validation.json")
|
||||
|
||||
# Specify the random seed
|
||||
random_seed = 4576 # You can change this to any desired value
|
||||
|
||||
# Shuffle the data using the specified seed
|
||||
random.shuffle(all_videos, random.seed(random_seed))
|
||||
|
||||
# Split into training and validation
|
||||
split_index = int(0.8 * len(all_videos))
|
||||
split_index = int(0.6 * len(all_videos))
|
||||
training_videos = all_videos[:split_index]
|
||||
validation_videos = all_videos[split_index:]
|
||||
|
||||
LOGGER.info(f"Training videos: {training_videos}")
|
||||
LOGGER.info(f"==================================")
|
||||
LOGGER.info(f"Validation videos: {validation_videos}")
|
||||
|
||||
training_dataset = create_dataset(training_videos, BATCH_SIZE, MAX_FRAMES)
|
||||
validation_dataset = create_dataset(validation_videos, BATCH_SIZE, MAX_FRAMES)
|
||||
|
||||
|
@ -142,7 +155,7 @@ def main():
|
|||
|
||||
# Set optimizer and compile the model
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
|
||||
model.compile(loss='mse', optimizer=optimizer, metrics=[psnr])
|
||||
model.compile(loss='mse', optimizer=optimizer, metrics=[psnr, ssim, combined])
|
||||
|
||||
# Define checkpoints and early stopping
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
|
@ -152,7 +165,7 @@ def main():
|
|||
verbose=1,
|
||||
save_format="tf"
|
||||
)
|
||||
early_stop = EarlyStopping(monitor='val_psnr', mode='max', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
|
||||
early_stop = EarlyStopping(monitor='val_combined', mode='max', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
|
||||
|
||||
# Custom garbage collection callback
|
||||
gc_callback = GarbageCollectorCallback()
|
||||
|
|
Reference in a new issue