# train_model.py
"""
TODO:
- Add more different videos with different parateters into the training set.
- Add different scenes with the same parameters
"""
import argparse
import os
import random
import shutil
import cv2
import subprocess
import signal
import numpy as np
from featureExtraction import combined, combined_loss, psnr, ssim
# Reduce TensorFlow C++ log noise; must be set before tensorflow is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
import gc
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, Callback, TensorBoard
from tensorflow.keras import backend as K
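# Allow GPU memory to grow on demand instead of reserving it all up front;
# this must be configured before any tensors are placed on the device.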
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)
from video_compression_model import VideoCompressionModel, create_dataset
from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER, clear_screen, load_video_metadata
# Constants
BATCH_SIZE = 25
EPOCHS = 100
LEARNING_RATE = 0.0001
DECAY_STEPS = 160
DECAY_RATE = 0.9
MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints"
EARLY_STOP = 10
RANDOM_SEED = 4576
MODEL = None
LOG_DIR = './logs'
class ImageLoggingCallback(Callback):
    def __init__(self, validation_dataset, log_dir):
        super().__init__()
        self.validation_dataset = validation_dataset
        self.log_dir = log_dir
        self.writer = tf.summary.create_file_writer(self.log_dir)

    def convert_images(self, images):
        # OpenCV delivers frames in BGR order; TensorBoard expects RGB.
        converted = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images]
        return np.stack(converted, axis=0)

    def on_epoch_end(self, epoch, logs=None):
        # Walk the validation dataset to a randomly chosen batch; fall back to
        # the last batch seen if the dataset has fewer batches than the index.
        random_idx = np.random.randint(0, BATCH_SIZE)
        validation_data = None
        for i, data in enumerate(self.validation_dataset):
            validation_data = data
            if i == random_idx:
                break
        if validation_data is None:
            return
        batch_input_images, batch_gt_labels = validation_data
        # Frames are normalised to [0, 1]; rescale to uint8 for display.
        batch_input_images = np.clip(batch_input_images[:, :, :, :3] * 255.0, 0, 255).astype(np.uint8)
        batch_gt_labels = np.clip(batch_gt_labels * 255.0, 0, 255).astype(np.uint8)
        reconstructed_frame = MODEL.predict(validation_data[0])
        reconstructed_frame = np.clip(reconstructed_frame * 255.0, 0, 255).astype(np.uint8)
        batch_input_images = self.convert_images(batch_input_images)
        batch_gt_labels = self.convert_images(batch_gt_labels)
        reconstructed_frame = self.convert_images(reconstructed_frame)
        # Log images to TensorBoard
        with self.writer.as_default():
            tf.summary.image("Input Images", batch_input_images, step=epoch, max_outputs=1)
            tf.summary.image("Ground Truth Labels", batch_gt_labels, step=epoch, max_outputs=1)
            tf.summary.image("Reconstructed Frame", reconstructed_frame, step=epoch, max_outputs=3)
            self.writer.flush()
class GarbageCollectorCallback(Callback):
    """Force a garbage-collection pass after every epoch to keep host memory in check."""

    def on_epoch_end(self, epoch, logs=None):
        LOGGER.debug("Running garbage collection after epoch.")
        gc.collect()
def save_model():
    try:
        LOGGER.debug("Attempting to save the model.")
        os.makedirs("models", exist_ok=True)
        MODEL.save(MODEL_SAVE_FILE, save_format='tf')
        LOGGER.info("Model saved successfully!")
    except Exception as e:
        LOGGER.error(f"Error saving the model: {e}")
        raise
def main():
    global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES, DECAY_STEPS, DECAY_RATE, MODEL
    # Argument parsing
    parser = argparse.ArgumentParser(description="Train the video compression model.")
    parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
    parser.add_argument('-e', '--epochs', type=int, default=EPOCHS, help='Number of epochs for training.')
    parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
    parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
    parser.add_argument('-m', '--max_frames', type=int, default=MAX_FRAMES, help='Maximum number of frames to use per video.')
    parser.add_argument('-ds', '--decay_steps', type=int, default=DECAY_STEPS, help='Decay steps for the learning-rate schedule.')
    parser.add_argument('-dr', '--decay_rate', type=float, default=DECAY_RATE, help='Decay rate for the learning-rate schedule.')
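    # Example invocations (hypothetical values; adjust to your data and hardware):
    #   python train_model.py -b 16 -e 50 -l 0.0001 -m 100 -ds 200 -dr 0.95
    #   python train_model.py --continue_training models/model.tf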
    args = parser.parse_args()
    BATCH_SIZE = args.batch_size
    EPOCHS = args.epochs
    LEARNING_RATE = args.learning_rate
    MAX_FRAMES = args.max_frames
    DECAY_RATE = args.decay_rate
    DECAY_STEPS = args.decay_steps
    # Display training configuration
    LOGGER.info("Starting the training with the given configuration.")
    LOGGER.info("Training configuration:")
    LOGGER.info(f"Batch size: {BATCH_SIZE}")
    LOGGER.info(f"Epochs: {EPOCHS}")
    LOGGER.info(f"Learning rate: {LEARNING_RATE}")
    LOGGER.info(f"Max Frames: {MAX_FRAMES}")
    LOGGER.info(f"Continue training from: {args.continue_training}")
    LOGGER.info(f"Decay Steps: {DECAY_STEPS}")
    LOGGER.info(f"Decay Rate: {DECAY_RATE}")
    LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")
    # Load all video metadata
    all_videos = load_video_metadata("test_data/validation/validation.json")
    tf.random.set_seed(RANDOM_SEED)
    # Shuffle the data deterministically with the specified seed
    random.seed(RANDOM_SEED)
    random.shuffle(all_videos)
    # Split into training (60%) and validation (40%)
    split_index = int(0.6 * len(all_videos))
    training_videos = all_videos[:split_index]
    validation_videos = all_videos[split_index:]
    LOGGER.info(f"Training videos: {training_videos}")
    LOGGER.info("==================================")
    LOGGER.info(f"Validation videos: {validation_videos}")
    training_dataset = create_dataset(training_videos, BATCH_SIZE, MAX_FRAMES)
    validation_dataset = create_dataset(validation_videos, BATCH_SIZE, MAX_FRAMES)
    tensorboard_callback = TensorBoard(log_dir=LOG_DIR, histogram_freq=1, profile_batch=0, write_graph=True, update_freq='epoch')
    if args.continue_training:
        MODEL = tf.keras.models.load_model(args.continue_training)
    else:
        MODEL = VideoCompressionModel()
    # Define exponential decay schedule
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=LEARNING_RATE,
        decay_steps=DECAY_STEPS,
        decay_rate=DECAY_RATE,
        staircase=True
    )
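    # With staircase=True the learning rate drops in discrete steps:
    #   lr(step) = LEARNING_RATE * DECAY_RATE ** floor(step / DECAY_STEPS)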
    # Set optimizer and compile the model
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
    MODEL.compile(loss=combined_loss, optimizer=optimizer, metrics=[psnr, ssim, combined])
    # Define checkpoints and early stopping. ModelCheckpoint takes no save_format
    # argument; the SavedModel format is inferred from the ".tf" filepath.
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch-{epoch:02d}.tf"),
        save_weights_only=False,
        save_best_only=False,
        verbose=1
    )
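    # EarlyStopping watches the validation value of the combined metric; mode='max'
    # assumes combined is a quality score where higher is better.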
    early_stop = EarlyStopping(monitor='val_combined', mode='max', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
    image_logging_callback = ImageLoggingCallback(validation_dataset, LOG_DIR)
    # Custom garbage collection callback
    gc_callback = GarbageCollectorCallback()
    gc.collect()
    # Train the model
    LOGGER.info("Starting model training.")
    MODEL.fit(
        training_dataset,
        epochs=EPOCHS,
        validation_data=validation_dataset,
        callbacks=[early_stop, checkpoint_callback, gc_callback, tensorboard_callback, image_logging_callback]
    )
    LOGGER.info("Model training completed.")
    save_model()
def preMain():
    # Delete the existing logs directory and create a new one
    if os.path.exists(LOG_DIR):
        shutil.rmtree(LOG_DIR)
    os.makedirs(LOG_DIR, exist_ok=True)
    # Start TensorBoard as a subprocess
    LOGGER.info("Running tensorboard at: http://localhost:6006/")
    tensorboard_process = subprocess.Popen(['tensorboard', '--logdir', LOG_DIR], stdout=subprocess.PIPE, stderr=subprocess.PIPE, preexec_fn=os.setsid)
    return tensorboard_process
if __name__ == "__main__":
clear_screen()
tensorboard_process = preMain()
try:
main()
except Exception as e:
LOGGER.error(f"Unexpected error during training: {e}")
raise
finally:
# Ensure TensorBoard process is terminated when main script ends
os.killpg(os.getpgid(tensorboard_process.pid), signal.SIGTERM)