This repository has been archived on 2025-05-04. You can view files and clone it, but you cannot make any changes to its state, such as pushing and creating new issues, pull requests or comments.
DeepEncode/train_model.py

81 lines
2.7 KiB
Python

import os
import tensorflow as tf
import numpy as np
import cv2
from video_compression_model import VideoCompressionModel
# Constants

# Color channels per video frame (RGB images have 3 channels).
NUM_CHANNELS = 3
# Batch size handed to model.fit during training.
BATCH_SIZE = 32
# Number of training epochs.
EPOCHS = 20

# Step 1: Data Preparation

# File name of the video the training frames are read from.
# NOTE(review): the validation file below is the one called
# 'training_video.mkv' -- confirm the two names are not swapped.
TRAIN_VIDEO_FILE = 'native_video.mkv'
# File name of the video the validation frames are read from.
VAL_VIDEO_FILE = 'training_video.mkv'
# Maximum number of frames read from the training video.
TRAIN_SAMPLES = 2
# Maximum number of frames read from the validation video.
VAL_SAMPLES = 2
def load_frames_from_video(video_file, num_frames):
    """Read up to ``num_frames`` frames from the start of a video as RGB arrays.

    Args:
        video_file: Path of the video file handed to ``cv2.VideoCapture``.
        num_frames: Maximum number of frames to read. Zero or a negative
            value yields no frames.

    Returns:
        A ``(frames, frame_width, frame_height)`` tuple. ``frames`` is a list
        of the frames in file order, each converted from BGR to RGB.
        ``frame_width`` and ``frame_height`` are taken from the first frame
        and are both ``None`` when the very first read fails, in which case
        ``frames`` is empty.
    """
    print("Extracting video frames...")
    cap = cv2.VideoCapture(video_file)
    frames = []
    frame_width, frame_height = None, None
    try:
        # Checking the limit before reading means a non-positive num_frames
        # returns an empty list instead of one stray frame.
        while len(frames) < num_frames:
            ret, frame = cap.read()
            if not ret:
                # End of stream, or the capture could not deliver a frame.
                break
            if frame_width is None or frame_height is None:
                # The first two axes of the frame array are (height, width).
                frame_height, frame_width = frame.shape[:2]
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture handle even if frame conversion raises.
        cap.release()
    return frames, frame_width, frame_height
# Load the leading frames of both videos. FRAME_WIDTH / FRAME_HEIGHT come from
# the training video only; the validation video's dimensions are discarded.
train_frames, FRAME_WIDTH, FRAME_HEIGHT = load_frames_from_video(
    TRAIN_VIDEO_FILE, num_frames=TRAIN_SAMPLES
)
val_frames, _, _ = load_frames_from_video(VAL_VIDEO_FILE, num_frames=VAL_SAMPLES)
print("Number of training frames:", len(train_frames))
print("Number of validation frames:", len(val_frames))
# Fail fast with a clear message: an unreadable or missing video produces an
# empty frame list, which would otherwise surface later as an opaque error
# when the first training frame is indexed or when training starts.
if len(train_frames) == 0:
    raise RuntimeError(
        f"No frames could be read from training video {TRAIN_VIDEO_FILE!r}"
    )
if len(val_frames) == 0:
    raise RuntimeError(
        f"No frames could be read from validation video {VAL_VIDEO_FILE!r}"
    )
def preprocess(frames):
    """Scale pixel values by 1/255 and return them as a float32 array.

    Args:
        frames: A sequence of equally shaped frames (or an array of them)
            holding pixel values; 8-bit values 0..255 map to 0.0..1.0.

    Returns:
        A ``numpy.ndarray`` of dtype ``float32`` with the same shape as
        ``np.asarray(frames)``.
    """
    # Convert to float32 before dividing: dividing an integer array by the
    # Python float 255.0 would produce float64 and double the memory needed
    # for the frame data.
    scaled = np.asarray(frames, dtype=np.float32) / np.float32(255.0)
    return scaled
# Replace both frame lists with their preprocessed array form, then report
# how many frames each set holds.
train_frames, val_frames = preprocess(train_frames), preprocess(val_frames)
print("training frames:", len(train_frames))
print("validation frames:", len(val_frames))
# Step 2: Model Architecture
model = VideoCompressionModel()
model.compile(
    optimizer='adam',
    loss='mean_squared_error',
    run_eagerly=True,
)

# Height and width of the first preprocessed training frame. Nothing further
# down in this script reads these two names.
frame_height, frame_width = train_frames[0].shape[:2]

# The targets are the input frames themselves; no resizing is applied here.
train_targets, val_targets = train_frames, val_frames

# Create the "models" directory if it doesn't exist.
# NOTE(review): the model below is saved to the working directory, not into
# "models" -- confirm whether this directory is used elsewhere.
os.makedirs("models", exist_ok=True)

# Each target list pairs the frames with an all-zero tensor of the same shape.
# NOTE(review): presumably VideoCompressionModel has two outputs and the
# second is trained toward zero -- confirm against the model definition.
train_outputs = [train_targets, tf.zeros_like(train_targets)]
val_outputs = [val_targets, tf.zeros_like(val_targets)]

print("\nTraining the model...")
model.fit(
    train_frames,
    train_outputs,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(val_frames, val_outputs),
)
print("\nTraining completed.")

# Step 3: Save the trained model
model.save('ai_rate_control_model.keras')
print("Model saved successfully!")