updated
parent 93ccce5ec1
commit ed5eb91578
6 changed files with 181 additions and 171 deletions
train_model.py (128 changed lines)
@@ -1,15 +1,14 @@
 # train_model.py
 
+import math
 import os
 
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
 
 import json
 import numpy as np
 import cv2
 import argparse
 import tensorflow as tf
-from video_compression_model import NUM_CHANNELS, VideoCompressionModel, PRESET_SPEED_CATEGORIES, VideoDataGenerator
+from video_compression_model import WIDTH, HEIGHT, VideoCompressionModel, PRESET_SPEED_CATEGORIES, VideoDataGenerator
 from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
 
 from global_train import LOGGER
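A note on the import block above: os.environ['TF_CPP_MIN_LOG_LEVEL'] only takes effect if it is set before `import tensorflow` runs, because the C++ backend reads the variable when the library loads. That is why the assignment sits between `import os` and the TensorFlow import. A minimal standalone illustration of the ordering (not code from this repo):

    import os
    # Must be set before TensorFlow is imported; the native logger reads it at load time.
    # '0' = all messages, '1' = filter INFO, '2' = also filter WARNING, '3' = also filter ERROR.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
    import tensorflow as tf  # INFO-level startup chatter is now suppressed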
@@ -17,13 +16,12 @@ from global_train import LOGGER
 # Constants
 BATCH_SIZE = 4
 EPOCHS = 100
-LEARNING_RATE = 0.000001
+LEARNING_RATE = 0.01
 TRAIN_SAMPLES = 100
 MODEL_SAVE_FILE = "models/model.tf"
 MODEL_CHECKPOINT_DIR = "checkpoints"
 EARLY_STOP = 10
-WIDTH = 638
-HEIGHT = 360
 
 
 def load_video_metadata(list_path):
     LOGGER.trace(f"Entering: load_video_metadata({list_path})")
@@ -40,92 +38,30 @@ def load_video_metadata(list_path):
         raise
 
 def load_video_samples(list_path, samples=TRAIN_SAMPLES):
     """
     Load video samples from the metadata list.
 
     Args:
     - list_path (str): Path to the metadata JSON file.
     - samples (int): Number of total samples to be extracted.
 
     Returns:
     - list: Extracted video samples.
     """
-    LOGGER.trace(f"Entering: load_video_samples({list_path}, {samples})" )
+    LOGGER.trace(f"Entering: load_video_samples({list_path}, {samples})")
     details_list = load_video_metadata(list_path)
     all_samples = []
     num_videos = len(details_list)
-    frames_per_video = int(samples / num_videos)
+    frames_per_video = math.ceil(samples / num_videos)
+    LOGGER.info(f"Loading {frames_per_video} frames from {num_videos} videos")
 
     for video_details in details_list:
-        video_file = video_details["video_file"]
-        uncompressed_video_file = video_details["uncompressed_video_file"]
-        crf = video_details['crf'] / 63.0
+        compressed_video_file = video_details["compressed_video_file"]
+        original_video_file = video_details["original_video_file"]
+        crf = video_details['crf'] / 51
         preset_speed = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed'])
         video_details['preset_speed'] = preset_speed
 
-        compressed_frames, uncompressed_frames = [], []
-
-        try:
-            cap = cv2.VideoCapture(os.path.join(os.path.dirname(list_path), video_file))
-            cap_uncompressed = cv2.VideoCapture(os.path.join(os.path.dirname(list_path), uncompressed_video_file))
-
-            if not cap.isOpened() or not cap_uncompressed.isOpened():
-                raise RuntimeError(f"Could not open video files {video_file} or {uncompressed_video_file}, searched under: {os.path.dirname(list_path)}")
-
-            for _ in range(frames_per_video):
-                ret, frame_compressed = cap.read()
-                ret_uncompressed, frame = cap_uncompressed.read()
-
-                if not ret or not ret_uncompressed:
-                    continue
-
-                # Check frame dimensions and resize if necessary
-                if frame.shape[:2] != (WIDTH, HEIGHT):
-                    LOGGER.warn(f"Resizing video: {video_file}")
-                    frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA)
-                if frame_compressed.shape[:2] != (WIDTH, HEIGHT):
-                    LOGGER.warn(f"Resizing video: {uncompressed_video_file}")
-                    frame_compressed = cv2.resize(frame_compressed, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA)
-
-                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                frame_compressed = cv2.cvtColor(frame_compressed, cv2.COLOR_BGR2RGB)
-
-                uncompressed_frames.append(normalize(frame))
-                compressed_frames.append(normalize(frame_compressed))
-
-            all_samples.extend({
-                "frame": frame,
-                "compressed_frame": frame_compressed,
-                "crf": crf,
-                "preset_speed": preset_speed,
-                "video_file": video_file
-            } for frame, frame_compressed in zip(uncompressed_frames, compressed_frames))
-
-        except Exception as e:
-            LOGGER.error(f"Error during video sample loading: {e}")
-            raise
-
-        finally:
-            cap.release()
-            cap_uncompressed.release()
+        # Store video details without loading frames
+        all_samples.extend({
+            "frames_per_video": frames_per_video,
+            "crf": crf,
+            "preset_speed": preset_speed,
+            "compressed_video_file": os.path.join(os.path.dirname(list_path), compressed_video_file),
+            "original_video_file": os.path.join(os.path.dirname(list_path), original_video_file)
+        } for _ in range(frames_per_video))
 
     return all_samples
 
 def normalize(frame):
     """
     Normalize pixel values of the frame to range [0, 1].
 
     Args:
     - frame (ndarray): Image frame.
 
     Returns:
     - ndarray: Normalized frame.
     """
     LOGGER.trace(f"Normalizing frame")
     return frame / 255.0
 
 def save_model(model):
     try:
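This hunk is the core of the commit: load_video_samples no longer decodes any frames up front. It now emits one lightweight descriptor dict per frame slot (duplicated frames_per_video times by the trailing generator expression), and the actual decoding presumably moves into VideoDataGenerator, which is defined in video_compression_model.py and not shown in this diff. Two side changes are worth noting: int(samples / num_videos) truncated and could under-deliver samples, while math.ceil rounds up; and the CRF normalization switches from /63.0 to /51, matching the 0-51 CRF scale of x264/x265 (0-63 is the scale used by VP9/AV1 encoders). A hypothetical sketch of how a generator might consume these descriptors lazily, under the assumption that decoding happens per batch (the class name and details here are illustrative, not the repo's actual implementation):

    import math
    import cv2
    import numpy as np
    import tensorflow as tf

    class LazyFrameBatches(tf.keras.utils.Sequence):
        """Illustrative stand-in for VideoDataGenerator: decodes on demand."""

        def __init__(self, samples, batch_size):
            self.samples = samples        # descriptor dicts from load_video_samples
            self.batch_size = batch_size

        def __len__(self):
            return math.ceil(len(self.samples) / self.batch_size)

        def __getitem__(self, idx):
            batch = self.samples[idx * self.batch_size:(idx + 1) * self.batch_size]
            frames = []
            for sample in batch:
                # Frames are only read when the batch is actually requested.
                cap = cv2.VideoCapture(sample["compressed_video_file"])
                ret, frame = cap.read()
                cap.release()
                if ret:
                    frames.append(frame / 255.0)  # same [0, 1] scaling as normalize()
            return np.array(frames)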
@@ -138,6 +74,7 @@ def save_model(model):
         raise
 
 def main():
+    global BATCH_SIZE, EPOCHS, TRAIN_SAMPLES, LEARNING_RATE, MODEL_SAVE_FILE
     # Argument parsing
     parser = argparse.ArgumentParser(description="Train the video compression model.")
     parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
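The added `global` declaration is what makes the assignments in the next hunk work: without it, `BATCH_SIZE = args.batch_size` inside main() would create a function-local name and leave the module-level constant untouched. A short self-contained demonstration of the distinction:

    LIMIT = 10

    def set_limit_wrong(value):
        LIMIT = value        # creates a local; module-level LIMIT stays 10

    def set_limit_right(value):
        global LIMIT
        LIMIT = value        # rebinds the module-level name

    set_limit_wrong(99)
    print(LIMIT)  # 10
    set_limit_right(99)
    print(LIMIT)  # 99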
@@ -147,23 +84,30 @@ def main():
     parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
 
     args = parser.parse_args()
 
+    BATCH_SIZE = args.batch_size
+    EPOCHS = args.epochs
+    TRAIN_SAMPLES = args.training_samples
+    LEARNING_RATE = args.learning_rate
+
     # Display training configuration
     LOGGER.info("Starting the training with the given configuration.")
     LOGGER.info("Training configuration:")
-    LOGGER.info(f"Batch size: {args.batch_size}")
-    LOGGER.info(f"Epochs: {args.epochs}")
-    LOGGER.info(f"Training samples: {args.training_samples}")
-    LOGGER.info(f"Learning rate: {args.learning_rate}")
-    LOGGER.info(f"Continue training from: {args.continue_training}")
+    LOGGER.info(f"Batch size: {BATCH_SIZE}")
+    LOGGER.info(f"Epochs: {EPOCHS}")
+    LOGGER.info(f"Training samples: {TRAIN_SAMPLES}")
+    LOGGER.info(f"Learning rate: {LEARNING_RATE}")
+    LOGGER.info(f"Continue training from: {MODEL_SAVE_FILE}")
+
+    LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")
 
     # Load training and validation samples
     LOGGER.debug("Loading training and validation samples.")
-    training_samples = load_video_samples("test_data/training/training.json")
-    validation_samples = load_video_samples("test_data/validation/validation.json", args.training_samples // 2)
+    training_samples = load_video_samples("test_data/training/training.json", TRAIN_SAMPLES)
+    validation_samples = load_video_samples("test_data/validation/validation.json", math.ceil(TRAIN_SAMPLES / 10))
 
-    train_generator = VideoDataGenerator(training_samples, args.batch_size)
-    val_generator = VideoDataGenerator(validation_samples, args.batch_size)
+    train_generator = VideoDataGenerator(training_samples, BATCH_SIZE)
+    val_generator = VideoDataGenerator(validation_samples, BATCH_SIZE)
 
     # Load or initialize model
     if args.continue_training:
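Besides routing everything through the module-level constants, this hunk also shrinks the validation split from half the training sample count (args.training_samples // 2) to roughly a tenth (math.ceil(TRAIN_SAMPLES / 10)). The reads of args.epochs, args.training_samples, and args.learning_rate imply matching long flags defined in an unchanged part of the file; only -b/--batch_size and -c/--continue_training are visible in this diff. A sketch of the argparse name mapping being relied on (the flag spelling below is an assumption inferred from the attribute name):

    import argparse

    p = argparse.ArgumentParser()
    # argparse turns '--training_samples' into the attribute args.training_samples
    p.add_argument('--training_samples', type=int, default=100)
    args = p.parse_args(['--training_samples', '200'])
    print(args.training_samples)  # 200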
@@ -172,7 +116,7 @@ def main():
         model = VideoCompressionModel()
 
     # Set optimizer and compile the model
-    optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
+    optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
     model.compile(loss='mean_squared_error', optimizer=optimizer)
 
     # Define checkpoints and early stopping
@@ -190,7 +134,7 @@ def main():
     model.fit(
         train_generator,
         steps_per_epoch=len(train_generator),
-        epochs=args.epochs,
+        epochs=EPOCHS,
         validation_data=val_generator,
         validation_steps=len(val_generator),
         callbacks=[early_stop, checkpoint_callback]
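The early_stop and checkpoint_callback objects passed to model.fit() are defined in the unchanged "# Define checkpoints and early stopping" section that this diff skips over. A minimal sketch of what such definitions plausibly look like, assuming they use the EARLY_STOP and MODEL_CHECKPOINT_DIR constants from the top of the file (the exact arguments and checkpoint filename are assumptions, not code from this commit):

    import os
    from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

    # Stop once val_loss has not improved for EARLY_STOP (10) epochs.
    early_stop = EarlyStopping(
        monitor='val_loss',
        patience=EARLY_STOP,
        restore_best_weights=True
    )
    # Save the best model seen so far into the checkpoints directory.
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(MODEL_CHECKPOINT_DIR, "ckpt-{epoch:02d}.tf"),
        monitor='val_loss',
        save_best_only=True
    )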