"""Train a VideoCompressionModel on frames extracted from two video files.

Reads a few frames from a training video and a validation video, scales the
pixel values to [0, 1], fits the model, and saves it to
``ai_rate_control_model.keras``.
"""

import os

import tensorflow as tf
import numpy as np
import cv2

from video_compression_model import VideoCompressionModel

# Constants
NUM_CHANNELS = 3  # Number of color channels in the video frames (RGB images have 3 channels)
BATCH_SIZE = 32  # Batch size used during training
EPOCHS = 20  # Number of training epochs

# Step 1: Data Preparation
TRAIN_VIDEO_FILE = 'native_video.mkv'  # The training video file name
# NOTE(review): the validation source is named 'training_video.mkv' -- confirm
# this is intended and not swapped with TRAIN_VIDEO_FILE.
VAL_VIDEO_FILE = 'training_video.mkv'  # The validation video file name
TRAIN_SAMPLES = 2  # Maximum number of video frames read for training
VAL_SAMPLES = 2  # Maximum number of video frames read for validation


def load_frames_from_video(video_file, num_frames):
    """Read up to ``num_frames`` frames from the start of ``video_file``.

    Args:
        video_file: Path to a video that ``cv2.VideoCapture`` can open.
        num_frames: Maximum number of frames to read. Values <= 0 read none.

    Returns:
        A tuple ``(frames, frame_width, frame_height)``:
        - ``frames`` is a list of frames converted from OpenCV's BGR channel
          order to RGB.
        - The dimensions come from the first frame read, and are ``None``
          when no frame was read.

    Raises:
        OSError: If OpenCV cannot open ``video_file``.
    """
    print("Extracting video frames...")
    cap = cv2.VideoCapture(video_file)
    # VideoCapture does not raise for a missing/unreadable file; it simply
    # yields no frames. Fail here with a clear message instead of training on
    # an empty dataset.
    if not cap.isOpened():
        cap.release()
        raise OSError(f"Could not open video file: {video_file!r}")

    frames = []
    frame_width, frame_height = None, None
    try:
        # Check the limit before reading so num_frames <= 0 reads nothing.
        while len(frames) < num_frames:
            ret, frame = cap.read()
            if not ret:
                break  # No more frames could be read.
            if frame_width is None:
                # Take the dimensions from the first frame;
                # shape[:2] is (height, width).
                frame_height, frame_width = frame.shape[:2]
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always free the capture handle, even if decoding/conversion raises.
        cap.release()
    return frames, frame_width, frame_height  # Return frames and frame dimensions


def preprocess(frames):
    """Stack ``frames`` into one array and divide pixel values by 255.

    For 8-bit video this maps [0, 255] to [0, 1]. The array is float32,
    matching Keras' default float type. It uses half the memory of NumPy's
    default float64.
    """
    return np.asarray(frames, dtype=np.float32) / 255.0


train_frames, FRAME_WIDTH, FRAME_HEIGHT = load_frames_from_video(TRAIN_VIDEO_FILE, num_frames=TRAIN_SAMPLES)
val_frames, _, _ = load_frames_from_video(VAL_VIDEO_FILE, num_frames=VAL_SAMPLES)

print("Number of training frames:", len(train_frames))
print("Number of validation frames:", len(val_frames))

# model.fit() cannot run without training data; report the cause directly
# instead of failing later with an unrelated-looking IndexError/shape error.
if not train_frames:
    raise ValueError(f"No frames could be read from {TRAIN_VIDEO_FILE!r}")

train_frames = preprocess(train_frames)
val_frames = preprocess(val_frames)

print("training frames:", len(train_frames))
print("validation frames:", len(val_frames))

# Step 2: Model Architecture
model = VideoCompressionModel()
# NOTE(review): run_eagerly=True skips graph compilation (slower); presumably
# left on for debugging -- confirm before long training runs.
model.compile(loss='mean_squared_error', optimizer='adam', run_eagerly=True)

# The first target is the input frames themselves (reconstruction). The second
# target is an all-zero tensor of the same shape -- NOTE(review): presumably a
# secondary model output that should be driven toward zero; confirm against
# VideoCompressionModel.
train_targets = train_frames
val_targets = val_frames

# NOTE(review): this directory is created, but the model below is saved to the
# working directory rather than into "models/" -- confirm the intended path.
os.makedirs("models", exist_ok=True)

print("\nTraining the model...")
model.fit(
    train_frames,
    [train_targets, tf.zeros_like(train_targets)],
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(val_frames, [val_targets, tf.zeros_like(val_targets)]),
)
print("\nTraining completed.")

# Step 3: Save the trained model
model.save('ai_rate_control_model.keras')
print("Model saved successfully!")