# train_model.py

import math
import os

# Quiet TensorFlow's C++ INFO logging. This has to be set before TensorFlow is
# imported, which is why the assignment sits between the import groups.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

import json
import argparse

import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

from video_compression_model import WIDTH, HEIGHT, VideoCompressionModel, PRESET_SPEED_CATEGORIES, VideoDataGenerator
from global_train import LOGGER

# Constants; BATCH_SIZE, EPOCHS, TRAIN_SAMPLES and LEARNING_RATE are defaults
# that can be overridden from the command line in main().
BATCH_SIZE = 4
EPOCHS = 100
LEARNING_RATE = 0.01
TRAIN_SAMPLES = 100
MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints"
EARLY_STOP = 10  # EarlyStopping patience, in epochs


def load_video_metadata(list_path):
    """Load the list of video metadata entries from a JSON file."""
    LOGGER.trace(f"Entering: load_video_metadata({list_path})")
    try:
        with open(list_path, "r") as json_file:
            metadata = json.load(json_file)
        LOGGER.trace(f"load_video_metadata returning: {metadata}")
        return metadata
    except FileNotFoundError:
        LOGGER.error(f"Metadata file {list_path} not found.")
        raise
    except json.JSONDecodeError:
        LOGGER.error(f"Error decoding JSON from {list_path}.")
        raise
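
# For reference, the metadata file is expected to be a JSON array of objects
# carrying the keys consumed in load_video_samples() below. The exact shape is
# an assumption inferred from those keys; paths and values are illustrative:
#
#   [
#       {
#           "compressed_video_file": "compressed/video0.mkv",
#           "original_video_file": "original/video0.mkv",
#           "crf": 23,
#           "preset_speed": "medium"
#       }
#   ]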


def load_video_samples(list_path, samples=TRAIN_SAMPLES):
    """Build a list of roughly `samples` frame descriptors, spread evenly
    across the videos listed in the metadata file at `list_path`."""
    LOGGER.trace(f"Entering: load_video_samples({list_path}, {samples})")
    details_list = load_video_metadata(list_path)
    if not details_list:
        # Guard against an empty metadata list (avoids a division by zero below).
        raise ValueError(f"No video metadata entries found in {list_path}.")
    all_samples = []
    num_videos = len(details_list)
    # Rounding up means the total sample count may slightly exceed `samples`.
    frames_per_video = math.ceil(samples / num_videos)
    LOGGER.info(f"Loading {frames_per_video} frames from {num_videos} videos")
    for video_details in details_list:
        compressed_video_file = video_details["compressed_video_file"]
        original_video_file = video_details["original_video_file"]
        crf = video_details['crf'] / 51  # normalise CRF to [0, 1]; 51 is the H.264 maximum
        preset_speed = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed'])
        video_details['preset_speed'] = preset_speed
        # Store video details without loading frames
        all_samples.extend({
            "frames_per_video": frames_per_video,
            "crf": crf,
            "preset_speed": preset_speed,
            "compressed_video_file": os.path.join(os.path.dirname(list_path), compressed_video_file),
            "original_video_file": os.path.join(os.path.dirname(list_path), original_video_file)
        } for _ in range(frames_per_video))
    return all_samples
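
# Each sample above is a lightweight frame descriptor, e.g. (values
# illustrative, following the example metadata shown earlier):
#
#   {
#       "frames_per_video": 100,
#       "crf": 23 / 51,
#       "preset_speed": PRESET_SPEED_CATEGORIES.index("medium"),
#       "compressed_video_file": "test_data/training/compressed/video0.mkv",
#       "original_video_file": "test_data/training/original/video0.mkv",
#   }
#
# The actual frame pixels are loaded later, presumably by VideoDataGenerator
# when it assembles batches.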


def save_model(model):
    """Save the trained model to MODEL_SAVE_FILE in the TensorFlow SavedModel format."""
    try:
        LOGGER.debug("Attempting to save the model.")
        os.makedirs("models", exist_ok=True)
        model.save(MODEL_SAVE_FILE, save_format='tf')
        LOGGER.info("Model saved successfully!")
    except Exception as e:
        LOGGER.error(f"Error saving the model: {e}")
        raise


def main():
    global BATCH_SIZE, EPOCHS, TRAIN_SAMPLES, LEARNING_RATE

    # Argument parsing
    parser = argparse.ArgumentParser(description="Train the video compression model.")
    parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
    parser.add_argument('-e', '--epochs', type=int, default=EPOCHS, help='Number of epochs for training.')
    parser.add_argument('-s', '--training_samples', type=int, default=TRAIN_SAMPLES, help='Number of training samples.')
    parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
    parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None,
                        help='Path to the saved model to continue training. If used without a value, defaults to MODEL_SAVE_FILE.')
    args = parser.parse_args()

    BATCH_SIZE = args.batch_size
    EPOCHS = args.epochs
    TRAIN_SAMPLES = args.training_samples
    LEARNING_RATE = args.learning_rate

    # Display training configuration
    LOGGER.info("Starting the training with the given configuration.")
    LOGGER.info("Training configuration:")
    LOGGER.info(f"Batch size: {BATCH_SIZE}")
    LOGGER.info(f"Epochs: {EPOCHS}")
    LOGGER.info(f"Training samples: {TRAIN_SAMPLES}")
    LOGGER.info(f"Learning rate: {LEARNING_RATE}")
    LOGGER.info(f"Continue training from: {args.continue_training}")
    LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")

    # Load training and validation samples; validation uses a tenth of the
    # training sample count.
    LOGGER.debug("Loading training and validation samples.")
    training_samples = load_video_samples("test_data/training/training.json", TRAIN_SAMPLES)
    validation_samples = load_video_samples("test_data/validation/validation.json", math.ceil(TRAIN_SAMPLES / 10))

    train_generator = VideoDataGenerator(training_samples, BATCH_SIZE)
    val_generator = VideoDataGenerator(validation_samples, BATCH_SIZE)

    # Load or initialize model
    if args.continue_training:
        model = tf.keras.models.load_model(args.continue_training)
    else:
        model = VideoCompressionModel()

    # Set optimizer and compile the model
    optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
    model.compile(loss='mean_squared_error', optimizer=optimizer)

    # Define checkpoints and early stopping. ModelCheckpoint takes no
    # save_format argument; the SavedModel format is inferred from the
    # non-.h5 filepath.
    os.makedirs(MODEL_CHECKPOINT_DIR, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch-{epoch:02d}.tf"),
        save_weights_only=False,
        save_best_only=False,
        verbose=1
    )
    early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True)

    # Train the model
    LOGGER.info("Starting model training.")
    model.fit(
        train_generator,
        steps_per_epoch=len(train_generator),
        epochs=EPOCHS,
        validation_data=val_generator,
        validation_steps=len(val_generator),
        callbacks=[early_stop, checkpoint_callback]
    )
    LOGGER.info("Model training completed.")

    save_model(model)


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        LOGGER.error(f"Unexpected error during training: {e}")
        raise
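
# Example invocations (flags as defined in main(); values illustrative):
#
#   python train_model.py                        # train with the defaults
#   python train_model.py -b 8 -e 50 -l 0.001    # override batch size, epochs, LR
#   python train_model.py -c                     # resume from MODEL_SAVE_FILE
#   python train_model.py -c checkpoints/epoch-05.tf   # resume from a checkpoint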