"""Train a VideoCompressionModel on paired uncompressed/compressed video frames.

Frame pairs are read from the videos listed in ``test_data/training.json`` and
``test_data/validation.json``; the trained model is saved to MODEL_SAVE_FILE.
"""

import argparse
import json
import os

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

from video_compression_model import NUM_CHANNELS, VideoCompressionModel, PRESET_SPEED_CATEGORIES

print("GPUs Detected:", tf.config.list_physical_devices('GPU'))

# Constants (defaults; main() overrides several of them from the command line)
BATCH_SIZE = 16
EPOCHS = 40
LEARNING_RATE = 0.00001
TRAIN_SAMPLES = 100
MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints"
CONTINUE_TRAINING = None
DATA_DIR = "test_data/"


def load_list(list_path):
    """Return the parsed JSON contents of ``list_path``.

    The script expects a list of per-video dicts; see load_video_from_list for
    the keys that are read from each entry.
    """
    with open(list_path, "r", encoding="utf-8") as json_file:
        return json.load(json_file)


def _read_frame_pairs(uncompressed_path, compressed_path, max_frames):
    """Read up to ``max_frames`` aligned, preprocessed (uncompressed, compressed) RGB frame pairs."""
    cap = cv2.VideoCapture(compressed_path)
    cap_uncompressed = cv2.VideoCapture(uncompressed_path)
    pairs = []
    try:
        if not cap.isOpened() or not cap_uncompressed.isOpened():
            print(f"Warning: could not open {compressed_path} and/or {uncompressed_path}; skipping")
            return pairs

        for _ in range(max_frames):
            ret, frame_compressed = cap.read()
            ret_uncompressed, frame = cap_uncompressed.read()
            # Stop at the end (or first read failure) of either stream. Skipping
            # a one-sided failure would consume a frame from only one stream and
            # misalign every subsequent pair.
            if not ret or not ret_uncompressed:
                break
            # OpenCV decodes to BGR; the model is trained on RGB.
            pairs.append((
                preprocess(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)),
                preprocess(cv2.cvtColor(frame_compressed, cv2.COLOR_BGR2RGB)),
            ))
    finally:
        # Release the decoders even if reading or preprocessing raises.
        cap.release()
        cap_uncompressed.release()
    return pairs


def load_video_from_list(list_path, train_samples=None, data_dir=DATA_DIR):
    """Load paired (uncompressed, compressed) frames for every video in a list file.

    Args:
        list_path: Path to a JSON file holding a list of dicts with the keys
            "video_file", "uncompressed_video_file", "crf" and "preset_speed".
        train_samples: Total frame-pair budget, split evenly across the videos.
            Defaults to the module-level TRAIN_SAMPLES, read at call time so the
            command-line override applied in main() takes effect.
        data_dir: Directory the video file names in the list are relative to.

    Returns:
        A list of dicts, one per frame pair, with keys "frame" (uncompressed,
        RGB, float32 in [0, 1]), "compressed_frame" (same), "crf" (CRF / 63),
        "preset_speed" (index into PRESET_SPEED_CATEGORIES) and "video_file".

    Raises:
        ValueError: if the list contains no videos, or (from ``list.index``)
            if a "preset_speed" value is not in PRESET_SPEED_CATEGORIES.
    """
    if train_samples is None:
        train_samples = TRAIN_SAMPLES

    details_list = load_list(list_path)
    num_videos = len(details_list)
    if num_videos == 0:
        raise ValueError(f"No videos listed in {list_path}")

    # At least one frame per video, so a budget smaller than the number of
    # videos does not silently produce an empty dataset.
    frames_per_video = max(1, train_samples // num_videos)
    print(f"Loading {frames_per_video} across {num_videos} videos")

    all_details = []
    for video_details in details_list:
        video_file = video_details["video_file"]
        uncompressed_video_file = video_details["uncompressed_video_file"]
        # CRF scaled by 63 (presumably the encoder's maximum CRF - confirm).
        crf = video_details['crf'] / 63.0
        preset_speed = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed'])

        frame_pairs = _read_frame_pairs(
            os.path.join(data_dir, uncompressed_video_file),
            os.path.join(data_dir, video_file),
            frames_per_video,
        )
        all_details.extend(
            {
                "frame": frame,
                "compressed_frame": compressed_frame,
                "crf": crf,
                "preset_speed": preset_speed,
                "video_file": video_file,
            }
            for frame, compressed_frame in frame_pairs
        )
    return all_details


def preprocess(frame):
    """Scale 8-bit pixel values to float32 in [0, 1].

    float32 matches Keras' default float type and halves the memory of the
    float64 array that ``frame / 255.0`` would produce for uint8 input.
    """
    return np.asarray(frame, dtype=np.float32) / 255.0


def save_model(model):
    """Save ``model`` in TensorFlow SavedModel format to MODEL_SAVE_FILE."""
    os.makedirs("models", exist_ok=True)
    model.save(MODEL_SAVE_FILE, save_format='tf')
    print("Model saved successfully!")


def _build_inputs(all_details):
    """Stack per-frame dicts into the (inputs dict, target array) pair passed to model.fit."""
    inputs = {
        "uncompressed_frame": np.array([d["frame"] for d in all_details]),
        "compressed_frame": np.array([d["compressed_frame"] for d in all_details]),
        "crf": np.array([d["crf"] for d in all_details]),
        "preset_speed": np.array([d["preset_speed"] for d in all_details]),
    }
    # Target is the compressed frame.
    # NOTE(review): the target is the same array as the "compressed_frame" input;
    # confirm VideoCompressionModel cannot just pass that input straight through.
    return inputs, inputs["compressed_frame"]


def _parse_args():
    """Parse the training command-line options."""
    parser = argparse.ArgumentParser(description="Train the video compression model.")
    parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
    parser.add_argument('-e', '--epochs', type=int, default=EPOCHS, help='Number of epochs for training.')
    parser.add_argument('-s', '--training_samples', type=int, default=TRAIN_SAMPLES, help='Number of training samples.')
    parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
    parser.add_argument(
        '-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None,
        help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.',
    )
    return parser.parse_args()


def main():
    """Parse options, load training/validation frames, train the model and save it."""
    global BATCH_SIZE, EPOCHS, TRAIN_SAMPLES, LEARNING_RATE, CONTINUE_TRAINING

    args = _parse_args()

    # Write the CLI values back to the module-level settings; load_video_from_list
    # reads TRAIN_SAMPLES as its default sample budget.
    BATCH_SIZE = args.batch_size
    EPOCHS = args.epochs
    TRAIN_SAMPLES = args.training_samples
    LEARNING_RATE = args.learning_rate
    CONTINUE_TRAINING = args.continue_training

    print("Training configuration:")
    print(f"Batch size: {BATCH_SIZE}")
    print(f"Epochs: {EPOCHS}")
    print(f"Training samples: {TRAIN_SAMPLES}")
    print(f"Learning rate: {LEARNING_RATE}")
    print(f"Continue training from: {CONTINUE_TRAINING}")

    # Both splits use the same TRAIN_SAMPLES frame budget.
    train_inputs, train_targets = _build_inputs(
        load_video_from_list(os.path.join(DATA_DIR, "training.json")))
    val_inputs, val_targets = _build_inputs(
        load_video_from_list(os.path.join(DATA_DIR, "validation.json")))

    if CONTINUE_TRAINING:
        print("loading model:", CONTINUE_TRAINING)
        model = tf.keras.models.load_model(CONTINUE_TRAINING)  # Load from the specified file
    else:
        model = VideoCompressionModel()

    # Define the optimizer with a specific learning rate
    optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

    os.makedirs(MODEL_CHECKPOINT_DIR, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch-{epoch:02d}.tf"),
        save_weights_only=False,
        save_best_only=False,
        verbose=1,
        save_format="tf",
    )

    # For debugging, tf.config.run_functions_eagerly(True) can be enabled here.
    model.compile(loss='mean_squared_error', optimizer=optimizer)
    early_stop = EarlyStopping(monitor='val_loss', patience=5, verbose=1, restore_best_weights=True)

    print("\nTraining the model...")
    model.fit(
        train_inputs,
        train_targets,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_data=(val_inputs, val_targets),
        callbacks=[early_stop, checkpoint_callback],
    )
    print("\nTraining completed!")

    save_model(model)


if __name__ == "__main__":
    main()