81 lines
2.7 KiB
Python
81 lines
2.7 KiB
Python
import os
|
|
import tensorflow as tf
|
|
import numpy as np
|
|
import cv2
|
|
from video_compression_model import VideoCompressionModel
|
|
|
|
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Training hyper-parameters.
NUM_CHANNELS = 3  # colour channels per frame (frames are converted to RGB -> 3)
BATCH_SIZE = 32   # mini-batch size handed to model.fit
EPOCHS = 20       # number of passes over the training frames

# ---------------------------------------------------------------------------
# Step 1: Data Preparation
# ---------------------------------------------------------------------------

# NOTE(review): the file names look swapped relative to their roles
# ("training_video.mkv" feeds *validation*) -- confirm this is intended.
TRAIN_VIDEO_FILE = 'native_video.mkv'  # video the training frames are read from
VAL_VIDEO_FILE = 'training_video.mkv'  # video the validation frames are read from

# Upper bounds on how many frames are pulled from each video (fewer are
# returned if the video ends first).
TRAIN_SAMPLES = 2
VAL_SAMPLES = 2
def load_frames_from_video(video_file, num_frames=None):
    """Read up to ``num_frames`` frames from ``video_file`` as RGB images.

    OpenCV decodes frames in BGR channel order; each frame is converted to
    RGB before it is collected.

    Args:
        video_file: Path of the video, opened with ``cv2.VideoCapture``.
        num_frames: Maximum number of frames to read. ``None`` reads until
            the video ends; a value <= 0 reads nothing.

    Returns:
        A ``(frames, frame_width, frame_height)`` tuple. ``frames`` is a list
        of the decoded RGB frames. Width and height are taken from the first
        frame and are both ``None`` when no frame could be read.

    Raises:
        OSError: If OpenCV cannot open ``video_file``.
    """
    print("Extracting video frames...")

    cap = cv2.VideoCapture(video_file)
    frames = []
    frame_width, frame_height = None, None
    try:
        # Without this check a missing/unreadable file silently yields zero
        # frames and the caller only fails much later.
        if not cap.isOpened():
            raise OSError(f"Could not open video file: {video_file}")

        # Test the limit *before* reading so that num_frames <= 0 reads
        # nothing (checking after the read always consumed one frame).
        while num_frames is None or len(frames) < num_frames:
            ret, frame = cap.read()
            if not ret:
                break  # no further frame could be read
            if frame_width is None:
                # Image arrays are (rows, cols, ...) == (height, width, ...).
                frame_height, frame_width = frame.shape[:2]
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture even if decoding or conversion raises.
        cap.release()

    return frames, frame_width, frame_height
# Pull the first TRAIN_SAMPLES / VAL_SAMPLES frames from each video. The
# training clip's dimensions become the module-level FRAME_WIDTH and
# FRAME_HEIGHT; the validation clip's dimensions are thrown away.
train_frames, FRAME_WIDTH, FRAME_HEIGHT = load_frames_from_video(
    TRAIN_VIDEO_FILE,
    num_frames=TRAIN_SAMPLES,
)
val_frames, _, _ = load_frames_from_video(
    VAL_VIDEO_FILE,
    num_frames=VAL_SAMPLES,
)

print("Number of training frames:", len(train_frames))
print("Number of validation frames:", len(val_frames))
def preprocess(frames, dtype=np.float32):
    """Stack ``frames`` into one array and divide every value by 255.

    For 8-bit input (such as the uint8 RGB frames produced by OpenCV) the
    result lies in ``[0, 1]``.

    Args:
        frames: Sequence of equally-shaped frames, or an existing array.
        dtype: Floating-point dtype of the result. Defaults to ``float32``,
            Keras' default float type; the old implicit ``float64`` doubled
            memory with no accuracy benefit for 8-bit pixels. Pass
            ``np.float64`` to get the previous behaviour.

    Returns:
        ``np.ndarray`` of shape ``(len(frames), *frame_shape)`` and dtype
        ``dtype`` holding ``frames / 255``.
    """
    # Convert to the target dtype *before* dividing so no float64
    # intermediate is materialised. The division is not in place, so an
    # array passed in by the caller is never modified.
    return np.asarray(frames, dtype=dtype) / 255.0
train_frames = preprocess(train_frames)
val_frames = preprocess(val_frames)

print("training frames:", len(train_frames))
print("validation frames:", len(val_frames))

# Fail fast with an actionable message. Previously an empty training set
# surfaced as an opaque IndexError at train_frames[0], and an empty
# validation set only as a Keras shape error after a full training epoch.
if len(train_frames) == 0:
    raise ValueError(f"No frames could be read from training video {TRAIN_VIDEO_FILE!r}")
if len(val_frames) == 0:
    raise ValueError(f"No frames could be read from validation video {VAL_VIDEO_FILE!r}")

# Step 2: Model Architecture
model = VideoCompressionModel()

# NOTE(review): run_eagerly=True disables graph compilation and slows
# training; confirm it is still needed (e.g. left over from debugging).
model.compile(loss='mean_squared_error', optimizer='adam', run_eagerly=True)

# Spatial size of the preprocessed training frames. Not read anywhere in the
# visible script; kept so the module-level names remain available.
frame_height, frame_width = train_frames[0].shape[:2]

# Each frame is its own reconstruction target (input == target). No resizing
# is performed anywhere in this script.
train_targets = train_frames
val_targets = val_frames

# Create the "models" directory if it doesn't exist.
# NOTE(review): the trained model is saved to the working directory below,
# not into "models/" -- confirm which location is intended.
os.makedirs("models", exist_ok=True)

print("\nTraining the model...")

# NOTE(review): the model presumably has two outputs; the second is trained
# toward all-zeros (tf.zeros_like). Confirm what that output represents.
model.fit(
    train_frames,
    [train_targets, tf.zeros_like(train_targets)],
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(val_frames, [val_targets, tf.zeros_like(val_targets)]),
)

print("\nTraining completed.")

# Step 3: Save the trained model
model.save('ai_rate_control_model.keras')

print("Model saved successfully!")