This repository has been archived on 2025-05-04. You can view files and clone it, but you cannot make any changes to its state, such as pushing and creating new issues, pull requests or comments.
DeepEncode/train_model.py
2023-08-17 01:57:53 +01:00

187 lines
No EOL
6.4 KiB
Python

# train_model.py
"""
TODO:
- Add more videos with different parameters to the training set.
- Add different scenes with the same parameters
"""
import argparse
import json
import os
from featureExtraction import psnr
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
import gc
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, Callback
# Enable memory growth on every visible GPU so TensorFlow allocates GPU memory
# as needed rather than claiming it all up front. TensorFlow only allows this
# before the GPUs have been initialized; otherwise set_memory_growth raises
# RuntimeError, which is printed and otherwise ignored here (training still
# runs, just without memory growth).
gpus = tf.config.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    # LOGGER is not imported yet at this point of the module, hence print().
    print(e)
from video_compression_model import VideoCompressionModel, data_generator
from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
# ---------------------------------------------------------------------------
# Default training configuration. main() lets the command line override the
# batch size, epoch count, learning rate, frame cap and decay parameters.
# ---------------------------------------------------------------------------
BATCH_SIZE = 25                       # batch size handed to the data generators
EPOCHS = 100                          # upper bound; EarlyStopping may end sooner
LEARNING_RATE = 1e-4                  # initial rate of the ExponentialDecay schedule
DECAY_STEPS = 160                     # optimizer steps per decay period
DECAY_RATE = 0.9                      # decay factor applied per DECAY_STEPS
MODEL_SAVE_FILE = "models/model.tf"   # where save_model() writes the final model
MODEL_CHECKPOINT_DIR = "checkpoints"  # per-epoch ModelCheckpoint output folder
EARLY_STOP = 5                        # EarlyStopping patience (epochs on val_loss)
class GarbageCollectorCallback(Callback):
    """
    Keras callback that runs Python's garbage collector after every epoch.
    Calling gc.collect() explicitly reclaims unreachable Python objects
    (including reference cycles) between epochs instead of waiting for the
    collector's own allocation thresholds to trigger.
    """
    def on_epoch_end(self, epoch, logs=None):
        """
        Called by Keras at the end of each training epoch.
        Args:
        - epoch (int): Index of the epoch that just finished (unused).
        - logs (dict, optional): Epoch metrics supplied by Keras (unused).
        """
        # The message has no placeholders, so a plain string suffices.
        LOGGER.debug("Collecting garbage")
        gc.collect()
def save_model(model, path=None):
    """
    Save a Keras model in TensorFlow SavedModel format.
    Args:
    - model: The Keras model to save.
    - path (str, optional): Destination path. Defaults to MODEL_SAVE_FILE.
    Raises:
    - Exception: Any error raised while saving is logged and then re-raised.
    """
    if path is None:
        path = MODEL_SAVE_FILE
    try:
        LOGGER.debug("Attempting to save the model.")
        # Create the parent directory of the real target path, rather than a
        # hard-coded "models" folder, so a changed save path still works.
        # dirname() is "" for a bare filename; makedirs("") would fail.
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        model.save(path, save_format='tf')
        LOGGER.info("Model saved successfully!")
    except Exception as e:
        LOGGER.error(f"Error saving the model: {e}")
        raise
def load_video_metadata(list_path):
    """
    Load video metadata from a JSON file.
    Args:
    - list_path (str): Path to the JSON file containing video metadata.
    Returns:
    - list: List of dictionaries, each containing video details. The
      structure is whatever the JSON file holds; it is not validated here.
    Raises:
    - FileNotFoundError: If the file does not exist (logged, then re-raised).
    - json.JSONDecodeError: If the file is not valid JSON (logged, then re-raised).
    """
    LOGGER.trace(f"Entering: load_video_metadata({list_path})")
    try:
        # JSON text is UTF-8 by specification; don't depend on the locale.
        with open(list_path, "r", encoding="utf-8") as json_file:
            metadata = json.load(json_file)
    except FileNotFoundError:
        LOGGER.error(f"Metadata file {list_path} not found.")
        raise
    except json.JSONDecodeError:
        LOGGER.error(f"Error decoding JSON from {list_path}.")
        raise
    LOGGER.trace(f"load_video_metadata returning: {metadata}")
    return metadata
# Default metadata file; overridable with --metadata.
_DEFAULT_METADATA_FILE = "test_data/validation/validation.json"


def _parse_args():
    """
    Parse the training command-line options.
    Defaults are read from the module-level constants at call time.
    Returns:
    - argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(description="Train the video compression model.")
    parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
    parser.add_argument('-e', '--epochs', type=int, default=EPOCHS, help='Number of epochs for training.')
    parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
    parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
    parser.add_argument('-m', '--max_frames', type=int, default=MAX_FRAMES, help='Maximum frames per video, used to estimate steps per epoch (<= 0 assumes 2880 frames per video).')
    parser.add_argument('-ds', '--decay_steps', type=int, default=DECAY_STEPS, help='Number of optimizer steps per learning-rate decay period.')
    parser.add_argument('-dr', '--decay_rate', type=float, default=DECAY_RATE, help='Decay rate for training.')
    parser.add_argument('--metadata', type=str, default=_DEFAULT_METADATA_FILE, help='JSON file listing the videos to train and validate on.')
    return parser.parse_args()


def _build_callbacks():
    """
    Create the training callbacks.
    Returns:
    - list: EarlyStopping on val_loss, a ModelCheckpoint that saves the full
      model after every epoch into MODEL_CHECKPOINT_DIR, and a
      GarbageCollectorCallback (in that order).
    """
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch-{epoch:02d}.tf"),
        save_weights_only=False,
        save_best_only=False,
        verbose=1,
        save_format="tf"
    )
    early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
    return [early_stop, checkpoint_callback, GarbageCollectorCallback()]


def _steps_per_epoch(num_videos, max_frames, batch_size):
    """
    Estimate how many batches make up one pass over `num_videos` videos.
    The per-video frame count is estimated as `max_frames`, or as 2880
    (2 minutes @ 24 fps) when `max_frames` <= 0.
    Returns:
    - int: The step count, never less than 1 (Keras cannot run 0 steps).
    """
    frames_per_video = max_frames if max_frames > 0 else 2880
    return max(1, (frames_per_video * num_videos) // batch_size)


def main():
    """
    Parse command-line options, build (or resume) the model, train and save it.
    Options that override module defaults are written back to the
    module-level globals declared below.
    Raises:
    - ValueError: If the 80/20 split of the metadata leaves no training or
      no validation videos.
    """
    global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES, DECAY_STEPS, DECAY_RATE
    args = _parse_args()
    BATCH_SIZE = args.batch_size
    EPOCHS = args.epochs
    LEARNING_RATE = args.learning_rate
    MAX_FRAMES = args.max_frames
    DECAY_RATE = args.decay_rate
    DECAY_STEPS = args.decay_steps

    # Display training configuration
    LOGGER.info("Starting the training with the given configuration.")
    LOGGER.info("Training configuration:")
    LOGGER.info(f"Batch size: {BATCH_SIZE}")
    LOGGER.info(f"Epochs: {EPOCHS}")
    LOGGER.info(f"Learning rate: {LEARNING_RATE}")
    LOGGER.info(f"Decay steps: {DECAY_STEPS}")
    LOGGER.info(f"Decay rate: {DECAY_RATE}")
    LOGGER.info(f"Max Frames: {MAX_FRAMES}")
    # Report the model actually being resumed, not the default save path.
    resume_from = args.continue_training or "none (training from scratch)"
    LOGGER.info(f"Continue training from: {resume_from}")
    LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")

    # Load all video metadata and split it 80/20 into training and validation.
    all_videos = load_video_metadata(args.metadata)
    split_index = int(0.8 * len(all_videos))
    training_videos = all_videos[:split_index]
    validation_videos = all_videos[split_index:]
    # With 0 or 1 videos the split leaves training empty; fail clearly here
    # instead of deep inside Keras.
    if not training_videos or not validation_videos:
        raise ValueError(
            f"Need at least one training and one validation video; got "
            f"{len(training_videos)} training and {len(validation_videos)} "
            f"validation videos from {len(all_videos)} in {args.metadata}."
        )

    if args.continue_training:
        # compile=False: the model is recompiled below with a fresh optimizer
        # anyway, and compiling on load may fail to deserialize the custom
        # `psnr` metric stored with the saved model.
        model = tf.keras.models.load_model(args.continue_training, compile=False)
    else:
        model = VideoCompressionModel()

    # Exponential learning-rate decay (continuous, since staircase=False).
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=LEARNING_RATE,
        decay_steps=DECAY_STEPS,
        decay_rate=DECAY_RATE,
        staircase=False
    )
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
    model.compile(loss='mse', optimizer=optimizer, metrics=[psnr])

    # NOTE(review): MAX_FRAMES only feeds this step estimate; it is not passed
    # to data_generator. Rebinding it here also does not change
    # globalVars.MAX_FRAMES, so confirm the generator applies the same cap.
    steps_per_epoch_train = _steps_per_epoch(len(training_videos), MAX_FRAMES, BATCH_SIZE)
    steps_per_epoch_validation = _steps_per_epoch(len(validation_videos), MAX_FRAMES, BATCH_SIZE)

    LOGGER.info("Starting model training.")
    model.fit(
        data_generator(training_videos, BATCH_SIZE),
        epochs=EPOCHS,
        steps_per_epoch=steps_per_epoch_train,
        validation_data=data_generator(validation_videos, BATCH_SIZE),
        validation_steps=steps_per_epoch_validation,
        callbacks=_build_callbacks()
    )
    LOGGER.info("Model training completed.")
    save_model(model)
if __name__ == "__main__":
    # Script entry point: any failure is first recorded through the project
    # logger, then re-raised so the traceback is kept and the interpreter
    # exits with a non-zero status.
    try:
        main()
    except Exception as err:
        LOGGER.error("Unexpected error during training: {}".format(err))
        raise