This commit is contained in:
Jordon Brooks 2023-08-13 02:07:13 +01:00
parent 9ae5921e2b
commit 1d98bc84a2

View file

@ -1,3 +1,4 @@
import argparse
import json
import os
import cv2
@ -132,7 +133,6 @@ def main():
parser = argparse.ArgumentParser(description="Train the video compression model.")
parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
parser.add_argument('-e', '--epochs', type=int, default=EPOCHS, help='Number of epochs for training.')
parser.add_argument('-s', '--training_samples', type=int, default=TRAIN_SAMPLES, help='Number of training samples.')
parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
@ -140,7 +140,6 @@ def main():
BATCH_SIZE = args.batch_size
EPOCHS = args.epochs
TRAIN_SAMPLES = args.training_samples
LEARNING_RATE = args.learning_rate
# Display training configuration
@ -148,7 +147,6 @@ def main():
LOGGER.info("Training configuration:")
LOGGER.info(f"Batch size: {BATCH_SIZE}")
LOGGER.info(f"Epochs: {EPOCHS}")
LOGGER.info(f"Training samples: {TRAIN_SAMPLES}")
LOGGER.info(f"Learning rate: {LEARNING_RATE}")
LOGGER.info(f"Continue training from: {MODEL_SAVE_FILE}")