Working GPU model

This commit is contained in:
Jordon Brooks 2023-07-30 11:49:19 +01:00
parent 5085c87300
commit dea59068fb
3 changed files with 190 additions and 108 deletions

View file

@ -2,106 +2,168 @@ import os
import json
import numpy as np
import cv2
import argparse
import tensorflow as tf
from video_compression_model import NUM_CHANNELS, VideoCompressionModel, PRESET_SPEED_CATEGORIES
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
print(tf.config.list_physical_devices('GPU'))
print("GPUs Detected:", tf.config.list_physical_devices('GPU'))
# Constants
BATCH_SIZE = 8
EPOCHS = 50
TRAIN_SAMPLES = 5
BATCH_SIZE = 16
EPOCHS = 40
LEARNING_RATE = 0.00001
TRAIN_SAMPLES = 100
MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints"
CONTINUE_TRAINING = None
def load_list(list_path):
with open(list_path, "r") as json_file:
video_details_list = json.load(json_file)
return video_details_list
def load_frame_from_video(video_file):
print("Extracting video frame...")
cap = cv2.VideoCapture(video_file)
ret, frame = cap.read()
if not ret:
return None
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
cap.release()
return frame
def preprocess(frame):
return frame / 255.0
def save_model(model, file):
os.makedirs("models", exist_ok=True)
model.save(os.path.join("models/", file))
print("Model saved successfully!")
def load_video_from_list(list_path):
details_list = load_list(list_path)
all_frames = []
all_details = []
num_videos = len(details_list)
frames_per_video = int(TRAIN_SAMPLES / num_videos)
print(f"Loading {frames_per_video} across {num_videos} videos")
for video_details in details_list:
VIDEO_FILE = video_details["video_file"]
UNCOMPRESSED_VIDEO_FILE = video_details["uncompressed_video_file"]
CRF = video_details['crf'] / 63.0
PRESET_SPEED = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed'])
video_details['preset_speed'] = PRESET_SPEED
frame = load_frame_from_video(os.path.join("test_data/", VIDEO_FILE))
if frame is not None:
all_frames.append(preprocess(frame))
frames = []
frames_compressed = []
cap = cv2.VideoCapture(os.path.join("test_data/", VIDEO_FILE))
cap_uncompressed = cv2.VideoCapture(os.path.join("test_data/", UNCOMPRESSED_VIDEO_FILE))
for _ in range(frames_per_video):
ret, frame_compressed = cap.read()
ret_uncompressed, frame = cap_uncompressed.read()
if not ret or not ret_uncompressed:
continue
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_compressed = cv2.cvtColor(frame_compressed, cv2.COLOR_BGR2RGB)
frames.append(preprocess(frame))
frames_compressed.append(preprocess(frame_compressed))
for uncompressed_frame, compressed_frame in zip(frames, frames_compressed):
all_details.append({
"frame": frame,
"frame": uncompressed_frame,
"compressed_frame": compressed_frame,
"crf": CRF,
"preset_speed": PRESET_SPEED,
"video_file": VIDEO_FILE
})
cap.release()
cap_uncompressed.release()
return all_details
def preprocess(frame):
return frame / 255.0
def save_model(model):
os.makedirs("models", exist_ok=True)
model.save(MODEL_SAVE_FILE, save_format='tf')
print("Model saved successfully!")
def main():
global BATCH_SIZE, EPOCHS, TRAIN_SAMPLES, LEARNING_RATE, CONTINUE_TRAINING
# Argument parsing
parser = argparse.ArgumentParser(description="Train the video compression model.")
parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
parser.add_argument('-e', '--epochs', type=int, default=EPOCHS, help='Number of epochs for training.')
parser.add_argument('-s', '--training_samples', type=int, default=TRAIN_SAMPLES, help='Number of training samples.')
parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
args = parser.parse_args()
# Use the parsed arguments in your script
BATCH_SIZE = args.batch_size
EPOCHS = args.epochs
TRAIN_SAMPLES = args.training_samples
LEARNING_RATE = args.learning_rate
CONTINUE_TRAINING = args.continue_training
print("Training configuration:")
print(f"Batch size: {BATCH_SIZE}")
print(f"Epochs: {EPOCHS}")
print(f"Training samples: {TRAIN_SAMPLES}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Continue training from: {CONTINUE_TRAINING}")
all_video_details_train = load_video_from_list("test_data/training.json")
all_video_details_val = load_video_from_list("test_data/validation.json")
model = VideoCompressionModel(NUM_CHANNELS)
model.compile(loss='mean_squared_error', optimizer='adam')
early_stop = EarlyStopping(monitor='val_loss', patience=3, verbose=1, restore_best_weights=True)
# Prepare data
all_train_frames = []
all_val_frames = []
all_crf_train = []
all_crf_val = []
all_preset_speed_train = []
all_preset_speed_val = []
for video_details_train, video_details_val in zip(all_video_details_train, all_video_details_val):
all_train_frames.append(video_details_train["frame"])
all_val_frames.append(video_details_val["frame"])
all_crf_train.append(video_details_train['crf'])
all_crf_val.append(video_details_val['crf'])
all_preset_speed_train.append(video_details_train['preset_speed'])
all_preset_speed_val.append(video_details_val['preset_speed'])
all_train_frames = [video_details["frame"] for video_details in all_video_details_train]
all_train_compressed_frames = [video_details["compressed_frame"] for video_details in all_video_details_train]
all_val_frames = [video_details["frame"] for video_details in all_video_details_val]
all_val_compressed_frames = [video_details["compressed_frame"] for video_details in all_video_details_val]
all_crf_train = [video_details['crf'] for video_details in all_video_details_train]
all_crf_val = [video_details['crf'] for video_details in all_video_details_val]
all_preset_speed_train = [video_details['preset_speed'] for video_details in all_video_details_train]
all_preset_speed_val = [video_details['preset_speed'] for video_details in all_video_details_val]
# Convert lists to numpy arrays
all_train_frames = np.array(all_train_frames)
all_train_compressed_frames = np.array(all_train_compressed_frames)
all_val_frames = np.array(all_val_frames)
all_val_compressed_frames = np.array(all_val_compressed_frames)
all_crf_train = np.array(all_crf_train)
all_crf_val = np.array(all_crf_val)
all_preset_speed_train = np.array(all_preset_speed_train)
all_preset_speed_val = np.array(all_preset_speed_val)
if CONTINUE_TRAINING:
print("loading model:", CONTINUE_TRAINING)
model = tf.keras.models.load_model(CONTINUE_TRAINING) # Load from the specified file
else:
model = VideoCompressionModel()
# Define the optimizer with a specific learning rate
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
os.makedirs(MODEL_CHECKPOINT_DIR, exist_ok=True)
checkpoint_callback = ModelCheckpoint(
filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch-{epoch:02d}.tf"),
save_weights_only=False,
save_best_only=False,
verbose=1,
save_format="tf"
)
print("\nTraining the model on frame pairs...")
#tf.config.run_functions_eagerly(True)
model.compile(loss='mean_squared_error', optimizer=optimizer)
early_stop = EarlyStopping(monitor='val_loss', patience=5, verbose=1, restore_best_weights=True)
print("\nTraining the model...")
model.fit(
{"frame": all_train_frames, "crf": all_crf_train, "preset_speed": all_preset_speed_train},
all_val_frames, # Target is the compressed frame
{"uncompressed_frame": all_train_frames, "compressed_frame": all_train_compressed_frames, "crf": all_crf_train, "preset_speed": all_preset_speed_train},
all_train_compressed_frames, # Target is the compressed frame
batch_size=BATCH_SIZE,
epochs=EPOCHS,
validation_data=({"frame": all_val_frames, "crf": all_crf_val, "preset_speed": all_preset_speed_val}, all_val_frames),
callbacks=[early_stop]
validation_data=({"uncompressed_frame": all_val_frames, "compressed_frame": all_val_compressed_frames, "crf": all_crf_val, "preset_speed": all_preset_speed_val}, all_val_compressed_frames),
callbacks=[early_stop, checkpoint_callback]
)
print("\nTraining completed!")
save_model(model, 'model.keras')
save_model(model)
if __name__ == "__main__":
main()