import json
import os

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping

from video_compression_model import NUM_CHANNELS, VideoCompressionModel, PRESET_SPEED_CATEGORIES

print(tf.config.list_physical_devices('GPU'))

# Constants
BATCH_SIZE = 8
EPOCHS = 50
TRAIN_SAMPLES = 5  # NOTE(review): not referenced anywhere in this file

# Divisor applied to the raw "crf" value from the JSON lists.
# NOTE(review): presumably the encoder's maximum CRF — confirm against the data generator.
MAX_CRF = 63.0


def load_list(list_path):
    """Load and return the parsed JSON content of *list_path*."""
    with open(list_path, "r", encoding="utf-8") as json_file:
        video_details_list = json.load(json_file)
    return video_details_list


def load_frame_from_video(video_file):
    """Return the first frame of *video_file* converted to RGB, or None if no frame can be read.

    The capture handle is released on every path, including when the read fails.
    """
    print("Extracting video frame...")
    cap = cv2.VideoCapture(video_file)
    try:
        ret, frame = cap.read()
    finally:
        cap.release()
    if not ret:
        return None
    # OpenCV decodes frames in BGR channel order; the model is fed RGB.
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


def preprocess(frame):
    """Divide pixel values by 255.0 (maps 8-bit pixel values into [0, 1])."""
    return frame / 255.0


def save_model(model, file):
    """Save *model* as ``models/<file>``, creating the ``models`` directory if needed."""
    os.makedirs("models", exist_ok=True)
    model.save(os.path.join("models", file))
    print("Model saved successfully!")


def load_video_from_list(list_path):
    """Load the first frame and encoding parameters of every video listed in *list_path*.

    Args:
        list_path: Path to a JSON file holding a list of objects with
            ``video_file``, ``crf`` and ``preset_speed`` keys. Video paths are
            resolved relative to ``test_data/``.

    Returns:
        A list of dicts with keys ``frame`` (preprocessed first frame),
        ``crf`` (raw CRF divided by MAX_CRF), ``preset_speed`` (index into
        PRESET_SPEED_CATEGORIES) and ``video_file``. Videos whose first frame
        cannot be read are skipped.

    Raises:
        ValueError: if an entry's ``preset_speed`` is not in PRESET_SPEED_CATEGORIES.
    """
    all_details = []
    for video_details in load_list(list_path):
        video_file = video_details["video_file"]
        crf = video_details["crf"] / MAX_CRF
        # Encode the preset name as its category index (the input dict is left untouched).
        preset_speed = PRESET_SPEED_CATEGORIES.index(video_details["preset_speed"])

        frame = load_frame_from_video(os.path.join("test_data", video_file))
        if frame is None:
            continue
        all_details.append({
            # Store the normalized frame so training actually sees preprocessed input.
            "frame": preprocess(frame),
            "crf": crf,
            "preset_speed": preset_speed,
            "video_file": video_file,
        })
    return all_details


def _stack_details(details_list):
    """Split a list of detail dicts into (frames, crf, preset_speed) numpy arrays."""
    # NOTE(review): frames of differing resolutions cannot be stacked into one
    # regular array — assumes all videos share a resolution.
    frames = np.array([details["frame"] for details in details_list])
    crf = np.array([details["crf"] for details in details_list])
    preset_speed = np.array([details["preset_speed"] for details in details_list])
    return frames, crf, preset_speed


def main():
    """Load the training/validation lists, train the model and save it to models/model.keras."""
    all_video_details_train = load_video_from_list("test_data/training.json")
    all_video_details_val = load_video_from_list("test_data/validation.json")
    if not all_video_details_train or not all_video_details_val:
        raise RuntimeError("No readable training or validation frames were found.")

    model = VideoCompressionModel(NUM_CHANNELS)
    model.compile(loss='mean_squared_error', optimizer='adam')
    early_stop = EarlyStopping(monitor='val_loss', patience=3, verbose=1, restore_best_weights=True)

    # Build each split independently so neither is truncated to the other's length.
    train_frames, train_crf, train_preset_speed = _stack_details(all_video_details_train)
    val_frames, val_crf, val_preset_speed = _stack_details(all_video_details_val)

    print("\nTraining the model on frame pairs...")
    model.fit(
        {"frame": train_frames, "crf": train_crf, "preset_speed": train_preset_speed},
        # Target is the input frame itself, mirroring validation_data below.
        # NOTE(review): no separate compressed-frame target is loaded by this script.
        train_frames,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_data=(
            {"frame": val_frames, "crf": val_crf, "preset_speed": val_preset_speed},
            val_frames,
        ),
        callbacks=[early_stop],
    )
    print("\nTraining completed!")
    save_model(model, 'model.keras')


if __name__ == "__main__":
    main()