Max Frames now works
This commit is contained in:
parent
185e3fac9a
commit
491aaf402e
1 changed files with 12 additions and 6 deletions
|
@ -11,12 +11,12 @@ from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
|
|||
|
||||
from video_compression_model import VideoCompressionModel, data_generator
|
||||
|
||||
from globalVars import HEIGHT, WIDTH, LOGGER
|
||||
from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
|
||||
|
||||
# Constants
|
||||
BATCH_SIZE = 16
|
||||
EPOCHS = 5
|
||||
LEARNING_RATE = 0.01
|
||||
EPOCHS = 100
|
||||
LEARNING_RATE = 0.001
|
||||
MODEL_SAVE_FILE = "models/model.tf"
|
||||
MODEL_CHECKPOINT_DIR = "checkpoints"
|
||||
EARLY_STOP = 10
|
||||
|
@ -58,18 +58,20 @@ def load_video_metadata(list_path):
|
|||
|
||||
|
||||
def main():
|
||||
global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE
|
||||
global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES
|
||||
# Argument parsing
|
||||
parser = argparse.ArgumentParser(description="Train the video compression model.")
|
||||
parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
|
||||
parser.add_argument('-e', '--epochs', type=int, default=EPOCHS, help='Number of epochs for training.')
|
||||
parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
|
||||
parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
|
||||
parser.add_argument('-m', '--max_frames', type=int, default=MAX_FRAMES, help='Batch size for training.')
|
||||
args = parser.parse_args()
|
||||
|
||||
BATCH_SIZE = args.batch_size
|
||||
EPOCHS = args.epochs
|
||||
LEARNING_RATE = args.learning_rate
|
||||
MAX_FRAMES = args.max_frames
|
||||
|
||||
# Display training configuration
|
||||
LOGGER.info("Starting the training with the given configuration.")
|
||||
|
@ -77,10 +79,10 @@ def main():
|
|||
LOGGER.info(f"Batch size: {BATCH_SIZE}")
|
||||
LOGGER.info(f"Epochs: {EPOCHS}")
|
||||
LOGGER.info(f"Learning rate: {LEARNING_RATE}")
|
||||
LOGGER.info(f"Max Frames: {MAX_FRAMES}")
|
||||
LOGGER.info(f"Continue training from: {MODEL_SAVE_FILE}")
|
||||
|
||||
LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")
|
||||
LOGGER.trace("Hello, World!")
|
||||
|
||||
# Load all video metadata
|
||||
all_videos = load_video_metadata("test_data/validation/validation.json")
|
||||
|
@ -111,7 +113,11 @@ def main():
|
|||
early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
|
||||
|
||||
# Calculate steps per epoch for training and validation
|
||||
if MAX_FRAMES <= 0:
|
||||
average_frames_per_video = 2880 # Given 2 minutes @ 24 fps
|
||||
else:
|
||||
average_frames_per_video = max(MAX_FRAMES, 0)
|
||||
|
||||
total_frames_train = average_frames_per_video * len(training_videos)
|
||||
total_frames_validation = average_frames_per_video * len(validation_videos)
|
||||
steps_per_epoch_train = total_frames_train // BATCH_SIZE
|
||||
|
|
Reference in a new issue