# train_model.py
"""Train the video compression model.

TODO:
    - Add more videos with different parameters to the training set.
    - Add different scenes with the same parameters.
"""

import argparse
import gc
import os
import random
import shutil
import signal
import subprocess

# Must be set before TensorFlow is first imported (directly below, or -
# presumably - via featureExtraction); once TF is loaded the C++ log level is
# fixed and changing this variable has no effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, Callback, TensorBoard
from tensorflow.summary import image as tf_image_summary

from featureExtraction import combined, combined_loss, combined_loss_weighted_psnr, psnr, ssim

# Let TensorFlow grow GPU memory on demand instead of reserving it all up
# front. Kept ahead of the project imports below in case they touch devices.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Raised e.g. when the GPUs have already been initialised.
        print(e)

from video_compression_model import VideoCompressionModel, create_dataset
from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER, clear_screen, load_video_metadata

# Constants (several are overridden from the command line in main()).
BATCH_SIZE = 25
EPOCHS = 1000
LEARNING_RATE = 0.005
DECAY_STEPS = 160
DECAY_RATE = 0.9
MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints"
EARLY_STOP = 10  # EarlyStopping patience, in epochs.
RANDOM_SEED = 3545
MODEL = None  # Set by main(); read by save_model().
LOG_DIR = './logs'


class ImageLoggingCallback(Callback):
    """Log a sample input / ground truth / reconstruction at the end of each epoch.

    Images are written to TensorBoard through a dedicated summary writer on
    ``log_dir``; the first reconstructed frame is additionally saved as
    ``log_dir/epoch_<epoch>.png``.
    """

    # Exclusive upper bound on how many validation batches are skipped before
    # sampling, so that a different batch is shown from epoch to epoch.
    MAX_SKIP_BATCHES = 100

    def __init__(self, validation_dataset, log_dir):
        super().__init__()
        self.validation_dataset = validation_dataset
        self.log_dir = log_dir
        self.writer = tf.summary.create_file_writer(self.log_dir)

    def convert_images(self, images):
        """Convert a sequence of BGR images to RGB and stack them into one array."""
        converted = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images]
        return np.stack(converted, axis=0)

    def _sample_validation_batch(self):
        """Return a randomly chosen validation batch, or ``None`` if the dataset is empty."""
        num_batches = int(self.validation_dataset.cardinality())
        # cardinality() is negative when the size is unknown or infinite; keep
        # the default range then and rely on the fallback below. When the size
        # is known, never skip past the end (that would yield no batch and make
        # a bare next(iter(...)) raise StopIteration).
        if num_batches > 0:
            upper = min(self.MAX_SKIP_BATCHES, num_batches)
        else:
            upper = self.MAX_SKIP_BATCHES
        skip_batches = np.random.randint(0, upper)

        # Try the randomly offset batch first, then fall back to the first one.
        for candidate in (self.validation_dataset.skip(skip_batches), self.validation_dataset):
            for batch in candidate.take(1):
                return batch
        return None

    def on_epoch_end(self, epoch, logs=None):
        """Sample one validation batch and log its inputs, labels and reconstructions."""
        validation_data = self._sample_validation_batch()
        if validation_data is None:
            LOGGER.info("Validation dataset is empty; skipping image logging.")
            return

        # Batches are (inputs_dict, labels) pairs; the inputs dict holds an 'image' entry.
        inputs, labels = validation_data[0], validation_data[1]

        # Values are scaled by 255 and clipped, so they are presumably
        # normalised to [0, 1]. Only the first three channels of 'image' are shown.
        actual_images = np.clip(inputs['image'][:, :, :, :3] * 255.0, 0, 255).astype(np.uint8)
        batch_gt_labels = np.clip(labels * 255.0, 0, 255).astype(np.uint8)

        # Use the model Keras attached to this callback rather than the global.
        reconstructed_frame = self.model.predict(inputs, verbose=0)
        reconstructed_frame = np.clip(reconstructed_frame * 255.0, 0, 255).astype(np.uint8)

        # Save only the first reconstructed frame as an example. It is converted
        # RGB -> BGR because cv2.imwrite expects BGR channel order.
        reconstructed_path = os.path.join(self.log_dir, f"epoch_{epoch}.png")
        cv2.imwrite(reconstructed_path, cv2.cvtColor(reconstructed_frame[0], cv2.COLOR_RGB2BGR))

        with self.writer.as_default():
            tf.summary.image("Input Images", actual_images, step=epoch, max_outputs=1)
            tf.summary.image("Ground Truth Labels", batch_gt_labels, step=epoch, max_outputs=1)
            tf.summary.image("Reconstructed Frame", reconstructed_frame, step=epoch, max_outputs=3)
        self.writer.flush()


class GarbageCollectorCallback(Callback):
    """Run a full Python garbage collection after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        LOGGER.debug("GC")
        gc.collect()


def save_model():
    """Save the global ``MODEL`` to ``MODEL_SAVE_FILE`` in TensorFlow SavedModel format.

    Raises:
        Exception: any error from saving is logged and re-raised.
    """
    try:
        LOGGER.debug("Attempting to save the model.")
        # Create the directory that actually contains MODEL_SAVE_FILE.
        os.makedirs(os.path.dirname(MODEL_SAVE_FILE) or ".", exist_ok=True)
        MODEL.save(MODEL_SAVE_FILE, save_format='tf')
        LOGGER.info("Model saved successfully!")
    except Exception as e:
        LOGGER.error(f"Error saving the model: {e}")
        raise


def _parse_args():
    """Parse the command line; defaults come from the current module-level constants."""
    parser = argparse.ArgumentParser(description="Train the video compression model.")
    parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE,
                        help='Batch size for training.')
    parser.add_argument('-e', '--epochs', type=int, default=EPOCHS,
                        help='Number of epochs for training.')
    parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE,
                        help='Learning rate for training.')
    parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None,
                        help='Path to the saved model to continue training. '
                             'If used without a value, defaults to the MODEL_SAVE_FILE.')
    parser.add_argument('-m', '--max_frames', type=int, default=MAX_FRAMES,
                        help='Maximum number of frames passed to the dataset builder.')
    parser.add_argument('-ds', '--decay_steps', type=int, default=DECAY_STEPS,
                        help='Number of optimizer steps between learning-rate decays.')
    parser.add_argument('-dr', '--decay_rate', type=float, default=DECAY_RATE,
                        help='Decay rate for training.')
    return parser.parse_args()


def _split_videos(videos, seed=RANDOM_SEED, train_fraction=0.7):
    """Shuffle ``videos`` in place with a fixed seed and split it into (train, validation).

    ``random.shuffle`` lost its second (``random``) parameter in Python 3.11,
    so the global generator is seeded explicitly; this yields the same
    permutation as seeding inline did on older Python versions.
    """
    random.seed(seed)
    random.shuffle(videos)
    split_index = int(train_fraction * len(videos))
    return videos[:split_index], videos[split_index:]


def _load_or_create_model(checkpoint_path):
    """Load a saved model from ``checkpoint_path``, or build a new one if it is falsy."""
    if not checkpoint_path:
        return VideoCompressionModel()
    return tf.keras.models.load_model(checkpoint_path, custom_objects={
        'VideoCompressionModel': VideoCompressionModel,
        'psnr': psnr,
        'ssim': ssim,
        'combined': combined,
        'combined_loss': combined_loss,
        'combined_loss_weighted_psnr': combined_loss_weighted_psnr,
    })


def main():
    """Parse arguments, build datasets and model, train, then save the model.

    Command-line values overwrite the module-level hyper-parameter globals, and
    the trained model is stored in the global ``MODEL`` for ``save_model()``.
    """
    global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES, DECAY_STEPS, DECAY_RATE, MODEL

    args = _parse_args()
    BATCH_SIZE = args.batch_size
    EPOCHS = args.epochs
    LEARNING_RATE = args.learning_rate
    MAX_FRAMES = args.max_frames
    DECAY_RATE = args.decay_rate
    DECAY_STEPS = args.decay_steps

    # Display training configuration.
    LOGGER.info("Starting the training with the given configuration.")
    LOGGER.info("Training configuration:")
    LOGGER.info(f"Batch size: {BATCH_SIZE}")
    LOGGER.info(f"Epochs: {EPOCHS}")
    LOGGER.info(f"Learning rate: {LEARNING_RATE}")
    LOGGER.info(f"Max Frames: {MAX_FRAMES}")
    # args.continue_training is None for a fresh run, so report what is actually resumed.
    LOGGER.info(f"Continue training from: {args.continue_training or 'None (training from scratch)'}")
    LOGGER.info(f"Decay Steps: {DECAY_STEPS}")
    LOGGER.info(f"Decay Rate: {DECAY_RATE}")
    LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")

    # Load all video metadata and split it reproducibly (70% train / 30% validation).
    all_videos = load_video_metadata("test_data/validation/validation.json")
    training_videos, validation_videos = _split_videos(all_videos, RANDOM_SEED)
    LOGGER.info(f"Training videos: {training_videos}")
    LOGGER.info("==================================")
    LOGGER.info(f"Validation videos: {validation_videos}")

    training_dataset = create_dataset(training_videos, BATCH_SIZE, MAX_FRAMES)
    validation_dataset = create_dataset(validation_videos, BATCH_SIZE, MAX_FRAMES)

    MODEL = _load_or_create_model(args.continue_training)

    # Step-wise exponential decay: the learning rate is multiplied by
    # DECAY_RATE every DECAY_STEPS optimizer steps.
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=LEARNING_RATE,
        decay_steps=DECAY_STEPS,
        decay_rate=DECAY_RATE,
        staircase=True,
    )
    # A fresh Adam optimizer is always used, so optimizer state saved with a
    # resumed model is not reused.
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
    MODEL.compile(loss=combined_loss_weighted_psnr, optimizer=optimizer, metrics=[psnr, ssim, combined])

    tensorboard_callback = TensorBoard(log_dir=LOG_DIR, histogram_freq=1, profile_batch=0,
                                       write_graph=True, update_freq='epoch')
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch-{epoch:02d}.tf"),
        save_weights_only=False,
        save_best_only=False,
        verbose=1,
        save_format="tf",
    )
    early_stop = EarlyStopping(monitor='val_combined', mode='max', patience=EARLY_STOP,
                               verbose=1, restore_best_weights=True)
    image_snapshots = ImageLoggingCallback(validation_dataset, LOG_DIR)
    gc_callback = GarbageCollectorCallback()

    gc.collect()

    LOGGER.info("Starting model training.")
    MODEL.fit(
        training_dataset,
        epochs=EPOCHS,
        validation_data=validation_dataset,
        callbacks=[early_stop, checkpoint_callback, gc_callback, tensorboard_callback, image_snapshots],
    )
    LOGGER.info("Model training completed.")
    save_model()


def preMain():
    """Reset ``LOG_DIR`` and launch TensorBoard on it in a new session/process group.

    Returns:
        The TensorBoard ``subprocess.Popen`` handle, or ``None`` if the
        ``tensorboard`` executable could not be found.
    """
    # Delete the existing logs directory and create a fresh one.
    if os.path.exists(LOG_DIR):
        shutil.rmtree(LOG_DIR)
    os.makedirs(LOG_DIR, exist_ok=True)

    LOGGER.info("Running tensorboard at: http://localhost:6006/")
    try:
        # start_new_session replaces preexec_fn=os.setsid, which is unsafe once
        # threads exist (TensorFlow starts threads). Output goes to DEVNULL:
        # unread PIPEs eventually fill up and block the child process.
        return subprocess.Popen(
            ['tensorboard', '--logdir', LOG_DIR],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            start_new_session=True,
        )
    except FileNotFoundError:
        LOGGER.error("tensorboard executable not found; continuing without TensorBoard.")
        return None


def _stop_tensorboard(process):
    """Terminate the TensorBoard process group started by ``preMain``, if it is still alive."""
    if process is None:
        return
    try:
        os.killpg(os.getpgid(process.pid), signal.SIGTERM)
    except ProcessLookupError:
        # TensorBoard already exited; don't mask an exception raised by main().
        pass


if __name__ == "__main__":
    clear_screen()
    tensorboard_process = preMain()
    try:
        main()
    except Exception as e:
        LOGGER.error(f"Unexpected error during training: {e}")
        raise
    finally:
        # Ensure TensorBoard is terminated when the main script ends.
        _stop_tensorboard(tensorboard_process)