Jordon Brooks 2023-08-16 22:45:16 +01:00
parent 54fa90247a
commit 15d8e57da5
4 changed files with 56 additions and 12 deletions

View file

@@ -2,7 +2,7 @@
 import os
-from featureExtraction import preprocess_frame
+from featureExtraction import preprocess_frame, psnr
 from globalVars import PRESET_SPEED_CATEGORIES
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
@@ -16,10 +16,10 @@ from video_compression_model import VideoCompressionModel
 COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
 MAX_FRAMES = 0 # Limit the number of frames processed
 CRF = 51
-SPEED = PRESET_SPEED_CATEGORIES.index("veryslow")
+SPEED = PRESET_SPEED_CATEGORIES.index("ultrafast")
 # Load the trained model
-MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel})
+MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr})
 # Load the uncompressed video
 UNCOMPRESSED_VIDEO_FILE = 'test_data/training_video.mkv'
@@ -36,8 +36,8 @@ def load_frame_from_video(video_file, frame_num):
 def predict_frame(uncompressed_frame):
-    display_frame = np.clip(cv2.cvtColor(uncompressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
-    cv2.imshow("uncomp", uncompressed_frame)
+    #display_frame = np.clip(cv2.cvtColor(uncompressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
+    #cv2.imshow("uncomp", uncompressed_frame)
     frame = preprocess_frame(uncompressed_frame, CRF, SPEED)
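Note on the custom_objects change above: a model saved after being compiled with a custom metric cannot be deserialized unless every custom symbol is supplied at load time, which is why psnr is now passed alongside VideoCompressionModel. As an alternative sketch (not what this commit does), if the metric is only needed during training, loading with compile=False skips restoring the training configuration and avoids registering psnr at all:

import tensorflow as tf
from video_compression_model import VideoCompressionModel

# Inference-only load: compile=False means the saved optimizer, loss and
# metrics (including the custom psnr) are not restored, so only the custom
# model class needs to appear in custom_objects.
MODEL = tf.keras.models.load_model(
    'models/model.tf',
    custom_objects={'VideoCompressionModel': VideoCompressionModel},
    compile=False
)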

View file

@@ -2,6 +2,11 @@
 import cv2
 import numpy as np
+import os
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
+from tensorflow.keras import backend as K
 from globalVars import HEIGHT, NUM_PRESET_SPEEDS, WIDTH
@@ -38,6 +43,10 @@ def extract_histogram_features(frame, bins=64):
     return np.array(feature_vector)
+def psnr(y_true, y_pred):
+    max_pixel = 1.0
+    return 10.0 * K.log((max_pixel ** 2) / (K.mean(K.square(y_pred - y_true)))) / K.log(10.0)
 def preprocess_frame(frame, crf, speed):
     # Check frame dimensions and resize if necessary
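For reference, the psnr helper added above is the standard peak signal-to-noise ratio, PSNR = 10 * log10(MAX^2 / MSE), for frames scaled to [0, 1]; K.log is a natural logarithm, hence the division by K.log(10.0). A rough NumPy sanity check, illustrative only and not part of the commit:

import numpy as np

def psnr_numpy(y_true, y_pred, max_pixel=1.0):
    # PSNR = 10 * log10(MAX^2 / MSE), reported in decibels.
    mse = np.mean((y_pred - y_true) ** 2)
    return 10.0 * np.log10((max_pixel ** 2) / mse)

clean = np.random.rand(64, 64, 3)
noisy = np.clip(clean + np.random.normal(0.0, 0.01, clean.shape), 0.0, 1.0)
print(psnr_numpy(clean, noisy))  # around 40 dB for noise with std 0.01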

View file

@@ -46,5 +46,17 @@
         "original_video_file": "Scene9.mkv",
         "crf": 15,
         "preset_speed": "slow"
+    },
+    {
+        "compressed_video_file": "Scene10_x264_crf-23_preset-ultrafast.mkv",
+        "original_video_file": "Scene10.mkv",
+        "crf": 23,
+        "preset_speed": "ultrafast"
+    },
+    {
+        "compressed_video_file": "Scene11_x264_crf-42_preset-medium.mkv",
+        "original_video_file": "Scene11.mkv",
+        "crf": 42,
+        "preset_speed": "medium"
     }
 ]
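The two entries above extend the training metadata with Scene10 and Scene11 at new CRF/preset combinations, in line with the TODO in train_model.py about covering more parameter variations. A sketch of how such a list could be read and validated; the function body here is hypothetical, since load_video_metadata itself is not shown in this diff:

import json

REQUIRED_KEYS = {"compressed_video_file", "original_video_file", "crf", "preset_speed"}

def load_video_metadata(list_path):
    # Hypothetical reader illustrating the expected shape of the JSON list:
    # a top-level array of objects, each describing one compressed/original pair.
    with open(list_path, "r") as f:
        entries = json.load(f)
    for entry in entries:
        missing = REQUIRED_KEYS - entry.keys()
        if missing:
            raise ValueError(f"metadata entry missing keys: {missing}")
    return entries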

View file

@@ -1,8 +1,16 @@
+# train_model.py
+"""
+TODO:
+- Add more different videos with different parameters into the training set.
+- Add different scenes with the same parameters
+"""
 import argparse
 import json
 import os
-import cv2
-import numpy as np
+from featureExtraction import psnr
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
@@ -16,10 +24,12 @@ from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
 # Constants
 BATCH_SIZE = 16
 EPOCHS = 100
-LEARNING_RATE = 0.000001
+LEARNING_RATE = 0.001
+DECAY_STEPS = 40
+DECAY_RATE = 0.9
 MODEL_SAVE_FILE = "models/model.tf"
 MODEL_CHECKPOINT_DIR = "checkpoints"
-EARLY_STOP = 10
+EARLY_STOP = 5
 def save_model(model):
     try:
@@ -58,7 +68,7 @@ def load_video_metadata(list_path):
 def main():
-    global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES
+    global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES, DECAY_STEPS, DECAY_RATE
     # Argument parsing
     parser = argparse.ArgumentParser(description="Train the video compression model.")
     parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
@@ -66,12 +76,16 @@ def main():
     parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
     parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
     parser.add_argument('-m', '--max_frames', type=int, default=MAX_FRAMES, help='Maximum number of frames to process.')
+    parser.add_argument('-ds', '--decay_steps', type=int, default=DECAY_STEPS, help='Decay steps for training.')
+    parser.add_argument('-dr', '--decay_rate', type=float, default=DECAY_RATE, help='Decay rate for training.')
     args = parser.parse_args()
     BATCH_SIZE = args.batch_size
     EPOCHS = args.epochs
     LEARNING_RATE = args.learning_rate
     MAX_FRAMES = args.max_frames
+    DECAY_RATE = args.decay_rate
+    DECAY_STEPS = args.decay_steps
     # Display training configuration
     LOGGER.info("Starting the training with the given configuration.")
@@ -96,11 +110,20 @@ def main():
         model = tf.keras.models.load_model(args.continue_training)
     else:
         model = VideoCompressionModel()
+    # Define exponential decay schedule
+    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
+        initial_learning_rate=LEARNING_RATE,
+        decay_steps=DECAY_STEPS,
+        decay_rate=DECAY_RATE,
+        staircase=False
+    )
     # Set optimizer and compile the model
-    optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
-    model.compile(loss='mean_squared_error', optimizer=optimizer)
+    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
+    model.compile(loss='mse', optimizer=optimizer, metrics=[psnr])
     # Define checkpoints and early stopping
     checkpoint_callback = ModelCheckpoint(
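Taken together, the train_model.py changes raise the base learning rate, drive it down with an exponential decay schedule, and report PSNR as a training metric. A standalone sketch of how the added schedule behaves with this commit's defaults (LEARNING_RATE=0.001, DECAY_STEPS=40, DECAY_RATE=0.9); with staircase=False the rate follows lr(step) = 0.001 * 0.9 ** (step / 40):

import tensorflow as tf

# Same schedule as in the hunk above, with the commit's constant values.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.001,
    decay_steps=40,
    decay_rate=0.9,
    staircase=False
)

# A LearningRateSchedule is callable with the optimizer step number.
for step in (0, 40, 80, 400):
    print(step, float(lr_schedule(step)))
# 0 -> 0.001, 40 -> 0.0009, 80 -> 0.00081, 400 -> roughly 0.00035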