# train_model.py
"""
TODO:
- Add more videos with different parameters to the training set.
- Add different scenes with the same parameters.
"""

import argparse
import os
import random
import shutil
import cv2
import subprocess
import signal

import numpy as np

from featureExtraction import combined, combined_loss, psnr, ssim

# Reduce TensorFlow C++ logging noise. This should be set before TensorFlow is
# first imported for it to take full effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

import gc
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, Callback, TensorBoard
from tensorflow.keras import backend as K
from tensorflow.summary import image as tf_image_summary

# Let the GPU allocate memory on demand instead of grabbing the whole device up front.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Memory growth must be configured before any GPU has been initialized.
        print(e)

from video_compression_model import VideoCompressionModel, create_dataset

from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER, clear_screen, load_video_metadata

# Constants
BATCH_SIZE = 25
EPOCHS = 100
LEARNING_RATE = 0.0001
DECAY_STEPS = 160
DECAY_RATE = 0.9
MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints"
EARLY_STOP = 10
RANDOM_SEED = 4576
MODEL = None
LOG_DIR = './logs'

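# Example invocation overriding some of the constants above (hypothetical values;
# the flags are defined in main() below):
#   python train_model.py -b 8 -e 50 -l 0.0002 -ds 200 -dr 0.95 -c models/model.tf
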
class ImageLoggingCallback(Callback):
    """Logs one random validation batch (input, ground truth, reconstruction) per epoch."""

    def __init__(self, validation_dataset, log_dir):
        super().__init__()
        self.validation_dataset = validation_dataset
        self.log_dir = log_dir
        self.writer = tf.summary.create_file_writer(self.log_dir)

    def convert_images(self, images):
        # OpenCV frames are BGR; TensorBoard expects RGB.
        converted = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images]
        return np.stack(converted, axis=0)

    def on_epoch_end(self, epoch, logs=None):
        random_idx = np.random.randint(0, MAX_FRAMES - 1)

        validation_data = None
        dataset_size = 0  # to keep track of the dataset size

        # Walk the dataset until the chosen index is reached.
        for i, data in enumerate(self.validation_dataset):
            if i == random_idx:
                validation_data = data
                break
            dataset_size += 1

        if validation_data is None:
            print(f"Random index exceeds validation dataset size: {dataset_size}. Using last available batch.")
            validation_data = data  # fall back to the last batch seen in the loop

        batch_input_images, batch_gt_labels = validation_data

        batch_input_images = np.clip(batch_input_images[:, :, :, :3] * 255.0, 0, 255).astype(np.uint8)
        batch_gt_labels = np.clip(batch_gt_labels * 255.0, 0, 255).astype(np.uint8)

        reconstructed_frame = MODEL.predict(validation_data[0])
        reconstructed_frame = np.clip(reconstructed_frame * 255.0, 0, 255).astype(np.uint8)

        # Save the first reconstructed frame of the batch to the log folder as an example.
        reconstructed_path = os.path.join(self.log_dir, f"epoch_{epoch}.png")
        cv2.imwrite(reconstructed_path, reconstructed_frame[0])

        batch_input_images = self.convert_images(batch_input_images)
        batch_gt_labels = self.convert_images(batch_gt_labels)
        reconstructed_frame = self.convert_images(reconstructed_frame)

        # Log images to TensorBoard
        with self.writer.as_default():
            tf.summary.image("Input Images", batch_input_images, step=epoch, max_outputs=1)
            tf.summary.image("Ground Truth Labels", batch_gt_labels, step=epoch, max_outputs=1)
            tf.summary.image("Reconstructed Frame", reconstructed_frame, step=epoch, max_outputs=3)
            self.writer.flush()

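# Note: ImageLoggingCallback.on_epoch_end reads the module-level MODEL global,
# which main() assigns before training starts, so the callback must only run
# after the model has been created or loaded.
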
class GarbageCollectorCallback(Callback):
    """Forces a garbage-collection pass after every epoch to keep memory usage in check."""

    def on_epoch_end(self, epoch, logs=None):
        LOGGER.debug("Running end-of-epoch garbage collection.")
        gc.collect()

def save_model():
    try:
        LOGGER.debug("Attempting to save the model.")
        os.makedirs("models", exist_ok=True)
        MODEL.save(MODEL_SAVE_FILE, save_format='tf')
        LOGGER.info("Model saved successfully!")
    except Exception as e:
        LOGGER.error(f"Error saving the model: {e}")
        raise

def main():
    global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES, DECAY_STEPS, DECAY_RATE, MODEL

    # Argument parsing
    parser = argparse.ArgumentParser(description="Train the video compression model.")
    parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
    parser.add_argument('-e', '--epochs', type=int, default=EPOCHS, help='Number of epochs for training.')
    parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
    parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to MODEL_SAVE_FILE.')
    parser.add_argument('-m', '--max_frames', type=int, default=MAX_FRAMES, help='Maximum number of frames to use per video.')
    parser.add_argument('-ds', '--decay_steps', type=int, default=DECAY_STEPS, help='Decay steps for the learning-rate schedule.')
    parser.add_argument('-dr', '--decay_rate', type=float, default=DECAY_RATE, help='Decay rate for the learning-rate schedule.')
    args = parser.parse_args()

    BATCH_SIZE = args.batch_size
    EPOCHS = args.epochs
    LEARNING_RATE = args.learning_rate
    MAX_FRAMES = args.max_frames
    DECAY_RATE = args.decay_rate
    DECAY_STEPS = args.decay_steps

    # Display training configuration
    LOGGER.info("Starting the training with the given configuration.")
    LOGGER.info("Training configuration:")
    LOGGER.info(f"Batch size: {BATCH_SIZE}")
    LOGGER.info(f"Epochs: {EPOCHS}")
    LOGGER.info(f"Learning rate: {LEARNING_RATE}")
    LOGGER.info(f"Max Frames: {MAX_FRAMES}")
    LOGGER.info(f"Continue training from: {args.continue_training}")
    LOGGER.info(f"Decay Steps: {DECAY_STEPS}")
    LOGGER.info(f"Decay Rate: {DECAY_RATE}")

    LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")

    # Load all video metadata
    all_videos = load_video_metadata("test_data/validation/validation.json")

    tf.random.set_seed(RANDOM_SEED)

    # Shuffle the data deterministically using the specified seed
    random.seed(RANDOM_SEED)
    random.shuffle(all_videos)

    # Split into training and validation
    split_index = int(0.6 * len(all_videos))
    training_videos = all_videos[:split_index]
    validation_videos = all_videos[split_index:]

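    # For example, with 10 videos in the metadata file the first 6 (after
    # shuffling) become the training set and the remaining 4 the validation set.
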
LOGGER.info(f"Training videos: {training_videos}")
|
|
LOGGER.info(f"==================================")
|
|
LOGGER.info(f"Validation videos: {validation_videos}")
|
|
|
|
training_dataset = create_dataset(training_videos, BATCH_SIZE, MAX_FRAMES)
|
|
validation_dataset = create_dataset(validation_videos, BATCH_SIZE, MAX_FRAMES)
|
|
|
|
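    # Assumption: create_dataset (from video_compression_model) yields batches of
    # (input_frames, ground_truth_frames) pairs scaled to [0, 1]; the scaling by
    # 255 in ImageLoggingCallback above relies on that range.
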
    tensorboard_callback = TensorBoard(log_dir=LOG_DIR, histogram_freq=1, profile_batch=0, write_graph=True, update_freq='epoch')

    if args.continue_training:
        MODEL = tf.keras.models.load_model(args.continue_training)
    else:
        MODEL = VideoCompressionModel()

    # Define exponential decay schedule
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=LEARNING_RATE,
        decay_steps=DECAY_STEPS,
        decay_rate=DECAY_RATE,
        staircase=True
    )

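    # With staircase=True the schedule evaluates to
    #   lr(step) = LEARNING_RATE * DECAY_RATE ** floor(step / DECAY_STEPS)
    # i.e. with the defaults above the learning rate drops by 10% every 160 optimizer steps.
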
    # Set optimizer and compile the model
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
    MODEL.compile(loss=combined_loss, optimizer=optimizer, metrics=[psnr, ssim, combined])

    # Define checkpoints and early stopping. The TensorFlow SavedModel format is
    # inferred from the ".tf" filepath; ModelCheckpoint itself does not accept a
    # save_format argument in recent Keras versions.
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch-{epoch:02d}.tf"),
        save_weights_only=False,
        save_best_only=False,
        verbose=1
    )
    early_stop = EarlyStopping(monitor='val_combined', mode='max', patience=EARLY_STOP, verbose=1, restore_best_weights=True)

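    # Keras reports the `combined` metric on the validation set as `val_combined`,
    # so early stopping restores the weights from the epoch with the highest
    # combined score and halts after EARLY_STOP epochs without improvement.
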
    ImageSnapshots = ImageLoggingCallback(validation_dataset, LOG_DIR)

    # Custom garbage collection callback
    gc_callback = GarbageCollectorCallback()

    gc.collect()

    # Train the model
    LOGGER.info("Starting model training.")
    MODEL.fit(
        training_dataset,
        epochs=EPOCHS,
        validation_data=validation_dataset,
        callbacks=[early_stop, checkpoint_callback, gc_callback, tensorboard_callback, ImageSnapshots]
    )
    LOGGER.info("Model training completed.")

    save_model()

def preMain():
    # Delete the existing logs directory and create a new one
    if os.path.exists(LOG_DIR):
        shutil.rmtree(LOG_DIR)
    os.makedirs(LOG_DIR, exist_ok=True)

    # Start TensorBoard as a subprocess in its own process group so it can be
    # cleanly terminated later (os.setsid / os.killpg are POSIX-only).
    LOGGER.info("Running tensorboard at: http://localhost:6006/")
    tensorboard_process = subprocess.Popen(
        ['tensorboard', '--logdir', LOG_DIR],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE,
        preexec_fn=os.setsid
    )
    return tensorboard_process

if __name__ == "__main__":
    clear_screen()

    tensorboard_process = preMain()

    try:
        main()
    except Exception as e:
        LOGGER.error(f"Unexpected error during training: {e}")
        raise
    finally:
        # Ensure the TensorBoard process is terminated when the main script ends
        os.killpg(os.getpgid(tensorboard_process.pid), signal.SIGTERM)