Improved model

Jordon Brooks 2023-07-30 16:48:51 +01:00
parent 9167ff27d4
commit 60c6c97071
8 changed files with 327 additions and 112 deletions


@@ -1,4 +1,9 @@
# train_model.py
import os
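# Silence TensorFlow's C++ INFO logs ('1' filters INFO; '2' would also filter warnings)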
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
import json
import numpy as np
import cv2
@@ -7,82 +12,122 @@ import tensorflow as tf
from video_compression_model import NUM_CHANNELS, VideoCompressionModel, PRESET_SPEED_CATEGORIES, VideoDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
print("GPUs Detected:", tf.config.list_physical_devices('GPU'))
from global_train import LOGGER
# Constants
BATCH_SIZE = 4
EPOCHS = 100
LEARNING_RATE = 0.000001
TRAIN_SAMPLES = 500
TRAIN_SAMPLES = 50
MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints"
CONTINUE_TRAINING = None
EARLY_STOP = 10
def load_list(list_path):
with open(list_path, "r") as json_file:
video_details_list = json.load(json_file)
return video_details_list
def load_video_metadata(list_path):
LOGGER.trace(f"Entering: load_video_metadata({list_path})")
try:
with open(list_path, "r") as json_file:
file = json.load(json_file)
LOGGER.trace(f"load_video_metadata returning: {file}")
return file
except FileNotFoundError:
LOGGER.error(f"Metadata file {list_path} not found.")
raise
except json.JSONDecodeError:
LOGGER.error(f"Error decoding JSON from {list_path}.")
raise
def load_video_from_list(list_path, samples = TRAIN_SAMPLES):
details_list = load_list(list_path)
all_details = []
def load_video_samples(list_path, samples=TRAIN_SAMPLES):
"""
Load video samples from the metadata list.
Args:
- list_path (str): Path to the metadata JSON file.
- samples (int): Number of total samples to be extracted.
Returns:
- list: Extracted video samples.
"""
LOGGER.trace(f"Entering: load_video_samples({list_path}, {samples})" )
details_list = load_video_metadata(list_path)
all_samples = []
num_videos = len(details_list)
frames_per_video = int(samples / num_videos)
print(f"Loading {frames_per_video} frames across {num_videos} videos")
LOGGER.info(f"Loading {frames_per_video} frames from {num_videos} videos")
for video_details in details_list:
VIDEO_FILE = video_details["video_file"]
UNCOMPRESSED_VIDEO_FILE = video_details["uncompressed_video_file"]
CRF = video_details['crf'] / 63.0
PRESET_SPEED = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed'])
video_details['preset_speed'] = PRESET_SPEED
video_file = video_details["video_file"]
uncompressed_video_file = video_details["uncompressed_video_file"]
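# Scale CRF to [0, 1]; 63 is presumably the encoder's maximum CRF (the VP9/AV1 range, whereas x264/x265 stop at 51)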
crf = video_details['crf'] / 63.0
preset_speed = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed'])
video_details['preset_speed'] = preset_speed
compressed_frames, uncompressed_frames = [], []
frames = []
frames_compressed = []
cap = cv2.VideoCapture(os.path.join("test_data/", VIDEO_FILE))
cap_uncompressed = cv2.VideoCapture(os.path.join("test_data/", UNCOMPRESSED_VIDEO_FILE))
for _ in range(frames_per_video):
ret, frame_compressed = cap.read()
ret_uncompressed, frame = cap_uncompressed.read()
if not ret or not ret_uncompressed:
continue
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_compressed = cv2.cvtColor(frame_compressed, cv2.COLOR_BGR2RGB)
frames.append(preprocess(frame))
frames_compressed.append(preprocess(frame_compressed))
for uncompressed_frame, compressed_frame in zip(frames, frames_compressed):
all_details.append({
"frame": uncompressed_frame,
"compressed_frame": compressed_frame,
"crf": CRF,
"preset_speed": PRESET_SPEED,
"video_file": VIDEO_FILE
})
cap.release()
cap_uncompressed.release()
return all_details
def preprocess(frame):
try:
cap = cv2.VideoCapture(os.path.join("test_data/", video_file))
cap_uncompressed = cv2.VideoCapture(os.path.join("test_data/", uncompressed_video_file))
if not cap.isOpened() or not cap_uncompressed.isOpened():
raise RuntimeError(f"Could not open video files {video_file} or {uncompressed_video_file}")
for _ in range(frames_per_video):
ret, frame_compressed = cap.read()
ret_uncompressed, frame = cap_uncompressed.read()
if not ret or not ret_uncompressed:
continue
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_compressed = cv2.cvtColor(frame_compressed, cv2.COLOR_BGR2RGB)
uncompressed_frames.append(normalize(frame))
compressed_frames.append(normalize(frame_compressed))
all_samples.extend({
"frame": frame,
"compressed_frame": frame_compressed,
"crf": crf,
"preset_speed": preset_speed,
"video_file": video_file
} for frame, frame_compressed in zip(uncompressed_frames, compressed_frames))
except Exception as e:
LOGGER.error(f"Error during video sample loading: {e}")
raise
finally:
cap.release()
cap_uncompressed.release()
return all_samples
def normalize(frame):
"""
Normalize pixel values of the frame to range [0, 1].
Args:
- frame (ndarray): Image frame.
Returns:
- ndarray: Normalized frame.
"""
LOGGER.trace(f"Normalizing frame")
return frame / 255.0
def save_model(model):
os.makedirs("models", exist_ok=True)
model.save(MODEL_SAVE_FILE, save_format='tf')
print("Model saved successfully!")
try:
LOGGER.debug("Attempting to save the model.")
os.makedirs("models", exist_ok=True)
model.save(MODEL_SAVE_FILE, save_format='tf')
LOGGER.info("Model saved successfully!")
except Exception as e:
LOGGER.error(f"Error saving the model: {e}")
raise
def main():
global BATCH_SIZE, EPOCHS, TRAIN_SAMPLES, LEARNING_RATE, CONTINUE_TRAINING
# Argument parsing
parser = argparse.ArgumentParser(description="Train the video compression model.")
parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
@@ -92,37 +137,35 @@ def main():
parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
args = parser.parse_args()
# Use the parsed arguments in your script
BATCH_SIZE = args.batch_size
EPOCHS = args.epochs
TRAIN_SAMPLES = args.training_samples
LEARNING_RATE = args.learning_rate
CONTINUE_TRAINING = args.continue_training
print("Training configuration:")
print(f"Batch size: {BATCH_SIZE}")
print(f"Epochs: {EPOCHS}")
print(f"Training samples: {TRAIN_SAMPLES}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Continue training from: {CONTINUE_TRAINING}")
all_video_details_train = load_video_from_list("test_data/training.json")
all_video_details_val = load_video_from_list("test_data/validation.json", TRAIN_SAMPLES / 2)
train_generator = VideoDataGenerator(all_video_details_train, BATCH_SIZE)
val_generator = VideoDataGenerator(all_video_details_val, BATCH_SIZE)
if CONTINUE_TRAINING:
print("loading model:", CONTINUE_TRAINING)
model = tf.keras.models.load_model(CONTINUE_TRAINING) # Load from the specified file
# Display training configuration
LOGGER.info("Starting the training with the given configuration.")
LOGGER.info("Training configuration:")
LOGGER.info(f"Batch size: {args.batch_size}")
LOGGER.info(f"Epochs: {args.epochs}")
LOGGER.info(f"Training samples: {args.training_samples}")
LOGGER.info(f"Learning rate: {args.learning_rate}")
LOGGER.info(f"Continue training from: {args.continue_training}")
# Load training and validation samples
LOGGER.debug("Loading training and validation samples.")
training_samples = load_video_samples("test_data/training.json")
validation_samples = load_video_samples("test_data/validation.json", args.training_samples // 2)
train_generator = VideoDataGenerator(training_samples, args.batch_size)
val_generator = VideoDataGenerator(validation_samples, args.batch_size)
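# VideoDataGenerator comes from video_compression_model.py (also changed in this commit); it is presumably a keras.utils.Sequence that batches the sample dicts for model.fit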
# Load or initialize model
if args.continue_training:
model = tf.keras.models.load_model(args.continue_training)
else:
model = VideoCompressionModel()
# Define the optimizer with a specific learning rate
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
os.makedirs(MODEL_CHECKPOINT_DIR, exist_ok=True)
# Set optimizer and compile the model
optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
model.compile(loss='mean_squared_error', optimizer=optimizer)
# Define checkpoints and early stopping
checkpoint_callback = ModelCheckpoint(
filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch-{epoch:02d}.tf"),
save_weights_only=False,
@@ -130,24 +173,25 @@ def main():
verbose=1,
save_format="tf"
)
early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
#tf.config.run_functions_eagerly(True)
model.compile(loss='mean_squared_error', optimizer=optimizer)
early_stop = EarlyStopping(monitor='val_loss', patience=5, verbose=1, restore_best_weights=True)
print("\nTraining the model...")
# Train the model
LOGGER.info("Starting model training.")
model.fit(
train_generator,
steps_per_epoch=len(train_generator),
epochs=EPOCHS,
epochs=args.epochs,
validation_data=val_generator,
validation_steps=len(val_generator),
callbacks=[early_stop, checkpoint_callback]
)
print("\nTraining completed!")
LOGGER.info("Model training completed.")
save_model(model)
if __name__ == "__main__":
main()
try:
main()
except Exception as e:
LOGGER.error(f"Unexpected error during training: {e}")
raise
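
Note on the LOGGER import: trace() is not a level that Python's standard logging module provides, so global_train presumably registers a custom TRACE level and attaches a trace() method to Logger. A minimal sketch of what such a module could look like (the real global_train.py is among the files changed in this commit but is not shown here; everything below is an assumption):

# global_train.py -- hypothetical sketch, not the actual file from this commit
import logging

TRACE = 5  # custom level below DEBUG (10)
logging.addLevelName(TRACE, "TRACE")

def _trace(self, message, *args, **kwargs):
    # Emit a record at the custom TRACE level when it is enabled
    if self.isEnabledFor(TRACE):
        self._log(TRACE, message, args, **kwargs)

logging.Logger.trace = _trace

logging.basicConfig(level=TRACE, format="%(asctime)s [%(levelname)s] %(message)s")
LOGGER = logging.getLogger("train")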