Max Frames now works

Jordon Brooks 2023-08-13 15:56:51 +01:00
parent 185e3fac9a
commit 491aaf402e


@@ -11,12 +11,12 @@ from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
 from video_compression_model import VideoCompressionModel, data_generator
-from globalVars import HEIGHT, WIDTH, LOGGER
+from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
 # Constants
 BATCH_SIZE = 16
-EPOCHS = 5
-LEARNING_RATE = 0.01
+EPOCHS = 100
+LEARNING_RATE = 0.001
 MODEL_SAVE_FILE = "models/model.tf"
 MODEL_CHECKPOINT_DIR = "checkpoints"
 EARLY_STOP = 10
@@ -58,18 +58,20 @@ def load_video_metadata(list_path):
 def main():
-    global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE
+    global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES
     # Argument parsing
     parser = argparse.ArgumentParser(description="Train the video compression model.")
     parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
     parser.add_argument('-e', '--epochs', type=int, default=EPOCHS, help='Number of epochs for training.')
     parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
     parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
+    parser.add_argument('-m', '--max_frames', type=int, default=MAX_FRAMES, help='Maximum number of frames to use per video.')
     args = parser.parse_args()
     BATCH_SIZE = args.batch_size
     EPOCHS = args.epochs
     LEARNING_RATE = args.learning_rate
+    MAX_FRAMES = args.max_frames
     # Display training configuration
     LOGGER.info("Starting the training with the given configuration.")
@@ -77,10 +79,10 @@ def main():
     LOGGER.info(f"Batch size: {BATCH_SIZE}")
     LOGGER.info(f"Epochs: {EPOCHS}")
     LOGGER.info(f"Learning rate: {LEARNING_RATE}")
+    LOGGER.info(f"Max Frames: {MAX_FRAMES}")
     LOGGER.info(f"Continue training from: {MODEL_SAVE_FILE}")
     LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")
-    LOGGER.trace("Hello, World!")
     # Load all video metadata
     all_videos = load_video_metadata("test_data/validation/validation.json")
@@ -111,7 +113,11 @@ def main():
     early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
     # Calculate steps per epoch for training and validation
-    average_frames_per_video = 2880  # Given 2 minutes @ 24 fps
+    if MAX_FRAMES <= 0:
+        average_frames_per_video = 2880  # Given 2 minutes @ 24 fps
+    else:
+        average_frames_per_video = max(MAX_FRAMES, 0)
     total_frames_train = average_frames_per_video * len(training_videos)
     total_frames_validation = average_frames_per_video * len(validation_videos)
     steps_per_epoch_train = total_frames_train // BATCH_SIZE
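The last hunk is where the flag takes effect: when MAX_FRAMES is positive it replaces the fixed 2880-frame estimate, so steps per epoch shrink accordingly. A rough worked example of that arithmetic, using illustrative video counts that are not taken from the commit:

# Illustrative numbers only; the real counts come from the loaded video metadata.
MAX_FRAMES = 500
BATCH_SIZE = 16
num_training_videos = 40
num_validation_videos = 10

if MAX_FRAMES <= 0:
    average_frames_per_video = 2880   # fallback: 2 minutes @ 24 fps
else:
    average_frames_per_video = max(MAX_FRAMES, 0)

steps_per_epoch_train = (average_frames_per_video * num_training_videos) // BATCH_SIZE
steps_per_epoch_validation = (average_frames_per_video * num_validation_videos) // BATCH_SIZE
print(steps_per_epoch_train, steps_per_epoch_validation)   # -> 1250 312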