This commit is contained in:
Jordon Brooks 2023-08-13 02:06:45 +01:00
parent ed5eb91578
commit 9ae5921e2b
3 changed files with 153 additions and 210 deletions

View file

@ -1,29 +1,77 @@
# train_model.py
import math
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
import json
import argparse
import os
import cv2
import numpy as np
from train_model_V2 import VideoCompressionModel
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
import tensorflow as tf
from video_compression_model import WIDTH, HEIGHT, VideoCompressionModel, PRESET_SPEED_CATEGORIES, VideoDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from global_train import LOGGER
# Constants
BATCH_SIZE = 4
EPOCHS = 100
BATCH_SIZE = 16
EPOCHS = 5
LEARNING_RATE = 0.01
TRAIN_SAMPLES = 100
MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints"
EARLY_STOP = 10
NUM_CHANNELS = 3
WIDTH = 640
HEIGHT = 360
def save_model(model):
try:
LOGGER.debug("Attempting to save the model.")
os.makedirs("models", exist_ok=True)
model.save(MODEL_SAVE_FILE, save_format='tf')
LOGGER.info("Model saved successfully!")
except Exception as e:
LOGGER.error(f"Error saving the model: {e}")
raise
def extract_edge_features(frame):
"""
Extract edge features using Canny edge detection.
Args:
- frame (ndarray): Image frame.
Returns:
- ndarray: Edge feature map.
"""
edges = cv2.Canny(frame, threshold1=100, threshold2=200)
return edges.astype(np.float32) / 255.0
def extract_histogram_features(frame, bins=64):
"""
Extract histogram features from a frame.
Args:
- frame (ndarray): Image frame.
- bins (int): Number of bins for the histogram.
Returns:
- ndarray: Normalized histogram feature vector.
"""
histogram, _ = np.histogram(frame.flatten(), bins=bins, range=[0, 255])
return histogram.astype(np.float32) / frame.size
def load_video_metadata(list_path):
"""
Load video metadata from a JSON file.
Args:
- json_path (str): Path to the JSON file containing video metadata.
Returns:
- list: List of dictionaries, each containing video details.
"""
LOGGER.trace(f"Entering: load_video_metadata({list_path})")
try:
with open(list_path, "r") as json_file:
@ -36,42 +84,47 @@ def load_video_metadata(list_path):
except json.JSONDecodeError:
LOGGER.error(f"Error decoding JSON from {list_path}.")
raise
def data_generator(videos, batch_size):
while True:
for video_details in videos:
video_path = os.path.join(os.path.dirname("test_data/validation/validation.json"), video_details["compressed_video_file"])
cap = cv2.VideoCapture(video_path)
def load_video_samples(list_path, samples=TRAIN_SAMPLES):
LOGGER.trace(f"Entering: load_video_samples({list_path}, {samples})")
details_list = load_video_metadata(list_path)
all_samples = []
num_videos = len(details_list)
frames_per_video = math.ceil(samples / num_videos)
LOGGER.info(f"Loading {frames_per_video} frames from {num_videos} videos")
for video_details in details_list:
compressed_video_file = video_details["compressed_video_file"]
original_video_file = video_details["original_video_file"]
crf = video_details['crf'] / 51
preset_speed = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed'])
video_details['preset_speed'] = preset_speed
feature_batch = []
compressed_frame_batch = []
# Store video details without loading frames
all_samples.extend({
"frames_per_video": frames_per_video,
"crf": crf,
"preset_speed": preset_speed,
"compressed_video_file": os.path.join(os.path.dirname(list_path), compressed_video_file),
"original_video_file": os.path.join(os.path.dirname(list_path), original_video_file)
} for _ in range(frames_per_video))
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
return all_samples
# Check frame dimensions and resize if necessary
if frame.shape[:2] != (HEIGHT, WIDTH):
frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_NEAREST)
# Extract features
edge_feature = extract_edge_features(frame)
histogram_feature = extract_histogram_features(frame)
histogram_feature_image = np.full((HEIGHT, WIDTH), histogram_feature.mean()) # Convert histogram feature to image-like shape
combined_feature = np.stack([edge_feature, histogram_feature_image], axis=-1)
compressed_frame = frame / 255.0 # Assuming the frame is uint8, scale to [0, 1]
feature_batch.append(combined_feature)
compressed_frame_batch.append(compressed_frame)
if len(feature_batch) == batch_size:
yield (np.array(feature_batch), np.array(compressed_frame_batch))
feature_batch = []
compressed_frame_batch = []
cap.release()
# If there are frames left that don't fill a whole batch, send them anyway
if len(feature_batch) > 0:
yield (np.array(feature_batch), np.array(compressed_frame_batch))
def save_model(model):
try:
LOGGER.debug("Attempting to save the model.")
os.makedirs("models", exist_ok=True)
model.save(MODEL_SAVE_FILE, save_format='tf')
LOGGER.info("Model saved successfully!")
except Exception as e:
LOGGER.error(f"Error saving the model: {e}")
raise
def main():
global BATCH_SIZE, EPOCHS, TRAIN_SAMPLES, LEARNING_RATE, MODEL_SAVE_FILE
@ -100,25 +153,22 @@ def main():
LOGGER.info(f"Continue training from: {MODEL_SAVE_FILE}")
LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")
LOGGER.trace("Hello, World!")
# Load training and validation samples
LOGGER.debug("Loading training and validation samples.")
training_samples = load_video_samples("test_data/training/training.json", TRAIN_SAMPLES)
validation_samples = load_video_samples("test_data/validation/validation.json", math.ceil(TRAIN_SAMPLES / 10))
# Load all video metadata
all_videos = load_video_metadata("test_data/validation/validation.json")
train_generator = VideoDataGenerator(training_samples, BATCH_SIZE)
val_generator = VideoDataGenerator(validation_samples, BATCH_SIZE)
# Split into training and validation
split_index = int(0.8 * len(all_videos))
training_videos = all_videos[:split_index]
validation_videos = all_videos[split_index:]
# Load or initialize model
if args.continue_training:
model = tf.keras.models.load_model(args.continue_training)
else:
model = VideoCompressionModel()
model = VideoCompressionModel()
# Set optimizer and compile the model
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(loss='mean_squared_error', optimizer=optimizer)
# Define checkpoints and early stopping
checkpoint_callback = ModelCheckpoint(
filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch-{epoch:02d}.tf"),
@ -129,23 +179,31 @@ def main():
)
early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
# Calculate steps per epoch for training and validation
average_frames_per_video = 2880 # Given 2 minutes @ 24 fps
total_frames_train = average_frames_per_video * len(training_videos)
total_frames_validation = average_frames_per_video * len(validation_videos)
steps_per_epoch_train = total_frames_train // BATCH_SIZE
steps_per_epoch_validation = total_frames_validation // BATCH_SIZE
# Train the model
LOGGER.info("Starting model training.")
model.fit(
train_generator,
steps_per_epoch=len(train_generator),
epochs=EPOCHS,
validation_data=val_generator,
validation_steps=len(val_generator),
data_generator(training_videos, BATCH_SIZE),
epochs=EPOCHS,
steps_per_epoch=steps_per_epoch_train,
validation_data=data_generator(validation_videos, BATCH_SIZE), # Add validation data here
validation_steps=steps_per_epoch_validation, # Add validation steps here
callbacks=[early_stop, checkpoint_callback]
)
LOGGER.info("Model training completed.")
save_model(model)
if __name__ == "__main__":
try:
main()
except Exception as e:
LOGGER.error(f"Unexpected error during training: {e}")
raise
raise