Initial Commit
This commit is contained in:
parent
645b6c29f7
commit
c7306a9d48
4 changed files with 191 additions and 159 deletions
91
train_model.py
Normal file
91
train_model.py
Normal file
|
@ -0,0 +1,91 @@
|
|||
import os
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
import cv2
|
||||
from video_compression_model import VideoCompressionModel
|
||||
|
||||
# Constants
|
||||
NUM_CHANNELS = 3 # Number of color channels in the video frames (RGB images have 3 channels)
|
||||
BATCH_SIZE = 32 # Batch size used during training
|
||||
EPOCHS = 20 # Number of training epochs
|
||||
CHECKPOINT_FILEPATH = "models/checkpoint-{epoch:02d}.keras"
|
||||
|
||||
# Step 1: Data Preparation
|
||||
TRAIN_VIDEO_FILE = 'native_video.mkv' # The training video file name
|
||||
VAL_VIDEO_FILE = 'training_video.mkv' # The validation video file name
|
||||
TRAIN_SAMPLES = 2 # Number of video frames used for training
|
||||
VAL_SAMPLES = 2 # Number of video frames used for validation
|
||||
|
||||
def load_frames_from_video(video_file, num_frames):
|
||||
print("Extracting video frames...")
|
||||
cap = cv2.VideoCapture(video_file)
|
||||
frames = []
|
||||
count = 0
|
||||
frame_width, frame_height = None, None # Initialize the frame dimensions
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
if frame_width is None or frame_height is None:
|
||||
frame_height, frame_width = frame.shape[:2] # Get the frame dimensions from the first frame
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
frames.append(frame)
|
||||
count += 1
|
||||
if count >= num_frames:
|
||||
break
|
||||
cap.release()
|
||||
return frames, frame_width, frame_height # Return frames and frame dimensions
|
||||
|
||||
train_frames, FRAME_WIDTH, FRAME_HEIGHT = load_frames_from_video(TRAIN_VIDEO_FILE, num_frames=TRAIN_SAMPLES)
|
||||
val_frames, _, _ = load_frames_from_video(VAL_VIDEO_FILE, num_frames=VAL_SAMPLES)
|
||||
|
||||
|
||||
print("Number of training frames:", len(train_frames))
|
||||
print("Number of validation frames:", len(val_frames))
|
||||
|
||||
def preprocess(frames):
|
||||
frames = np.array(frames) / 255.0
|
||||
return frames
|
||||
|
||||
train_frames = preprocess(train_frames)
|
||||
val_frames = preprocess(val_frames)
|
||||
|
||||
print("training frames:", len(train_frames))
|
||||
print("validation frames:", len(val_frames))
|
||||
|
||||
# Step 2: Model Architecture
|
||||
model = VideoCompressionModel()
|
||||
|
||||
model.compile(loss='mean_squared_error', optimizer='adam', run_eagerly=True)
|
||||
|
||||
# Adjusting the input shape for training and validation
|
||||
frame_height, frame_width = train_frames[0].shape[:2]
|
||||
|
||||
# Use the resized frames as target data
|
||||
train_targets = train_frames
|
||||
val_targets = val_frames
|
||||
|
||||
# Create the "models" directory if it doesn't exist
|
||||
os.makedirs("models", exist_ok=True)
|
||||
|
||||
# Create the ModelCheckpoint callback
|
||||
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||
filepath=CHECKPOINT_FILEPATH,
|
||||
save_weights_only=False, # Save the entire model (including architecture)
|
||||
monitor='val_loss', # Metric to monitor for saving the best model (optional)
|
||||
save_best_only=True # Save only the best model based on the monitored metric (optional)
|
||||
)
|
||||
|
||||
print("\nTraining the model...")
|
||||
model.fit(
|
||||
train_frames, [train_targets, tf.zeros_like(train_targets)],
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS,
|
||||
validation_data=(val_frames, [val_targets, tf.zeros_like(val_targets)]),
|
||||
callbacks=[model_checkpoint_callback] # Add the ModelCheckpoint callback
|
||||
)
|
||||
print("\nTraining completed.")
|
||||
|
||||
# Step 3: Save the trained model
|
||||
model.save('ai_rate_control_model.keras')
|
||||
print("Model saved successfully!")
|
Reference in a new issue