Jordon Brooks 2023-08-16 22:45:16 +01:00
parent 54fa90247a
commit 15d8e57da5
4 changed files with 56 additions and 12 deletions


@@ -1,8 +1,16 @@
# train_model.py
"""
TODO:
- Add more videos with different parameters to the training set.
- Add different scenes with the same parameters
"""
import argparse
import json
import os
import cv2
import numpy as np
from featureExtraction import psnr
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
@@ -16,10 +24,12 @@ from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
# Constants
BATCH_SIZE = 16
EPOCHS = 100
LEARNING_RATE = 0.000001
LEARNING_RATE = 0.001
DECAY_STEPS = 40
DECAY_RATE = 0.9
MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints"
EARLY_STOP = 10
EARLY_STOP = 5
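# EARLY_STOP presumably feeds the patience of an EarlyStopping callback defined
# further down (outside this hunk); a rough sketch of how it is typically wired:
#   early_stop_callback = EarlyStopping(monitor='val_loss', patience=EARLY_STOP,
#                                       restore_best_weights=True)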
def save_model(model):
try:
@@ -58,7 +68,7 @@ def load_video_metadata(list_path):
def main():
global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES
global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES, DECAY_STEPS, DECAY_RATE
# Argument parsing
parser = argparse.ArgumentParser(description="Train the video compression model.")
parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
@@ -66,12 +76,16 @@ def main():
parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
parser.add_argument('-m', '--max_frames', type=int, default=MAX_FRAMES, help='Maximum number of frames to use for training.')
parser.add_argument('-ds', '--decay_steps', type=int, default=DECAY_STEPS, help='Decay steps for the learning-rate schedule.')
parser.add_argument('-dr', '--decay_rate', type=float, default=DECAY_RATE, help='Decay rate for the learning-rate schedule.')
args = parser.parse_args()
BATCH_SIZE = args.batch_size
EPOCHS = args.epochs
LEARNING_RATE = args.learning_rate
MAX_FRAMES = args.max_frames
DECAY_RATE = args.decay_rate
DECAY_STEPS = args.decay_steps
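# Illustrative invocation with the new decay flags (values shown are just the
# defaults; -m is whatever frame cap you want):
#   python train_model.py -b 16 -l 0.001 -ds 40 -dr 0.9 -m 100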
# Display training configuration
LOGGER.info("Starting the training with the given configuration.")
@@ -96,11 +110,20 @@ def main():
model = tf.keras.models.load_model(args.continue_training)
else:
model = VideoCompressionModel()
# Define exponential decay schedule
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate=LEARNING_RATE,
decay_steps=DECAY_STEPS,
decay_rate=DECAY_RATE,
staircase=False
)
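# With staircase=False the decay is smooth:
#   lr(step) = LEARNING_RATE * DECAY_RATE ** (step / DECAY_STEPS)
# so with the values above (0.001, 0.9, 40) the rate falls to roughly
# 0.001 * 0.9 ** 10 ≈ 3.5e-4 after 400 optimizer steps; the schedule can be
# sanity-checked directly, e.g. lr_schedule(400).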
# Set optimizer and compile the model
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(loss='mean_squared_error', optimizer=optimizer)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
model.compile(loss='mse', optimizer=optimizer, metrics=[psnr])
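# featureExtraction.psnr is not shown in this diff; assuming frames are
# normalised to [0, 1], a Keras-compatible PSNR metric would look roughly like:
#   def psnr(y_true, y_pred):
#       return tf.image.psnr(y_true, y_pred, max_val=1.0)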
# Define checkpoints and early stopping
checkpoint_callback = ModelCheckpoint(