Max Frames now works
This commit is contained in:
parent
185e3fac9a
commit
491aaf402e
1 changed files with 12 additions and 6 deletions
|
@ -11,12 +11,12 @@ from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
|
||||||
|
|
||||||
from video_compression_model import VideoCompressionModel, data_generator
|
from video_compression_model import VideoCompressionModel, data_generator
|
||||||
|
|
||||||
from globalVars import HEIGHT, WIDTH, LOGGER
|
from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
BATCH_SIZE = 16
|
BATCH_SIZE = 16
|
||||||
EPOCHS = 5
|
EPOCHS = 100
|
||||||
LEARNING_RATE = 0.01
|
LEARNING_RATE = 0.001
|
||||||
MODEL_SAVE_FILE = "models/model.tf"
|
MODEL_SAVE_FILE = "models/model.tf"
|
||||||
MODEL_CHECKPOINT_DIR = "checkpoints"
|
MODEL_CHECKPOINT_DIR = "checkpoints"
|
||||||
EARLY_STOP = 10
|
EARLY_STOP = 10
|
||||||
|
@ -58,18 +58,20 @@ def load_video_metadata(list_path):
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE
|
global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES
|
||||||
# Argument parsing
|
# Argument parsing
|
||||||
parser = argparse.ArgumentParser(description="Train the video compression model.")
|
parser = argparse.ArgumentParser(description="Train the video compression model.")
|
||||||
parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
|
parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
|
||||||
parser.add_argument('-e', '--epochs', type=int, default=EPOCHS, help='Number of epochs for training.')
|
parser.add_argument('-e', '--epochs', type=int, default=EPOCHS, help='Number of epochs for training.')
|
||||||
parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
|
parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
|
||||||
parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
|
parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
|
||||||
|
parser.add_argument('-m', '--max_frames', type=int, default=MAX_FRAMES, help='Batch size for training.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
BATCH_SIZE = args.batch_size
|
BATCH_SIZE = args.batch_size
|
||||||
EPOCHS = args.epochs
|
EPOCHS = args.epochs
|
||||||
LEARNING_RATE = args.learning_rate
|
LEARNING_RATE = args.learning_rate
|
||||||
|
MAX_FRAMES = args.max_frames
|
||||||
|
|
||||||
# Display training configuration
|
# Display training configuration
|
||||||
LOGGER.info("Starting the training with the given configuration.")
|
LOGGER.info("Starting the training with the given configuration.")
|
||||||
|
@ -77,10 +79,10 @@ def main():
|
||||||
LOGGER.info(f"Batch size: {BATCH_SIZE}")
|
LOGGER.info(f"Batch size: {BATCH_SIZE}")
|
||||||
LOGGER.info(f"Epochs: {EPOCHS}")
|
LOGGER.info(f"Epochs: {EPOCHS}")
|
||||||
LOGGER.info(f"Learning rate: {LEARNING_RATE}")
|
LOGGER.info(f"Learning rate: {LEARNING_RATE}")
|
||||||
|
LOGGER.info(f"Max Frames: {MAX_FRAMES}")
|
||||||
LOGGER.info(f"Continue training from: {MODEL_SAVE_FILE}")
|
LOGGER.info(f"Continue training from: {MODEL_SAVE_FILE}")
|
||||||
|
|
||||||
LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")
|
LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")
|
||||||
LOGGER.trace("Hello, World!")
|
|
||||||
|
|
||||||
# Load all video metadata
|
# Load all video metadata
|
||||||
all_videos = load_video_metadata("test_data/validation/validation.json")
|
all_videos = load_video_metadata("test_data/validation/validation.json")
|
||||||
|
@ -111,7 +113,11 @@ def main():
|
||||||
early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
|
early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
|
||||||
|
|
||||||
# Calculate steps per epoch for training and validation
|
# Calculate steps per epoch for training and validation
|
||||||
average_frames_per_video = 2880 # Given 2 minutes @ 24 fps
|
if MAX_FRAMES <= 0:
|
||||||
|
average_frames_per_video = 2880 # Given 2 minutes @ 24 fps
|
||||||
|
else:
|
||||||
|
average_frames_per_video = max(MAX_FRAMES, 0)
|
||||||
|
|
||||||
total_frames_train = average_frames_per_video * len(training_videos)
|
total_frames_train = average_frames_per_video * len(training_videos)
|
||||||
total_frames_validation = average_frames_per_video * len(validation_videos)
|
total_frames_validation = average_frames_per_video * len(validation_videos)
|
||||||
steps_per_epoch_train = total_frames_train // BATCH_SIZE
|
steps_per_epoch_train = total_frames_train // BATCH_SIZE
|
||||||
|
|
Reference in a new issue