Jordon Brooks 2023-08-16 22:45:16 +01:00
parent 54fa90247a
commit 15d8e57da5
4 changed files with 56 additions and 12 deletions

View file

@@ -2,7 +2,7 @@
import os
-from featureExtraction import preprocess_frame
+from featureExtraction import preprocess_frame, psnr
from globalVars import PRESET_SPEED_CATEGORIES
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
@@ -16,10 +16,10 @@ from video_compression_model import VideoCompressionModel
COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
MAX_FRAMES = 0 # Limit the number of frames processed
CRF = 51
-SPEED = PRESET_SPEED_CATEGORIES.index("veryslow")
+SPEED = PRESET_SPEED_CATEGORIES.index("ultrafast")
# Load the trained model
-MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel})
+MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr})
# Load the uncompressed video
UNCOMPRESSED_VIDEO_FILE = 'test_data/training_video.mkv'
@@ -36,8 +36,8 @@ def load_frame_from_video(video_file, frame_num):
def predict_frame(uncompressed_frame):
-display_frame = np.clip(cv2.cvtColor(uncompressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
-cv2.imshow("uncomp", uncompressed_frame)
+#display_frame = np.clip(cv2.cvtColor(uncompressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
+#cv2.imshow("uncomp", uncompressed_frame)
frame = preprocess_frame(uncompressed_frame, CRF, SPEED)

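Because the saved model was compiled with the custom psnr metric, reloading it only works if that function is registered in custom_objects alongside the custom model class, which is what the updated load_model call does. A minimal sketch of the pattern, assuming the module and file names used in this repo:

import tensorflow as tf

from featureExtraction import psnr
from video_compression_model import VideoCompressionModel

# Every custom class or function the saved model references must be listed here;
# leaving 'psnr' out typically makes load_model fail with an unknown-object error.
model = tf.keras.models.load_model(
    'models/model.tf',
    custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr},
)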
View file

@@ -2,6 +2,11 @@
import cv2
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
from tensorflow.keras import backend as K
from globalVars import HEIGHT, NUM_PRESET_SPEEDS, WIDTH
@@ -38,6 +43,10 @@ def extract_histogram_features(frame, bins=64):
return np.array(feature_vector)
def psnr(y_true, y_pred):
max_pixel = 1.0
return 10.0 * K.log((max_pixel ** 2) / (K.mean(K.square(y_pred - y_true)))) / K.log(10.0)
def preprocess_frame(frame, crf, speed):
# Check frame dimensions and resize if necessary

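For reference, the new metric is the standard PSNR = 10 * log10(MAX^2 / MSE) with MAX = 1.0, i.e. it assumes frames already scaled to [0, 1]. A NumPy sketch (psnr_np is an illustrative name, not part of the repo) that mirrors the Keras version and can be used to sanity-check it:

import numpy as np

def psnr_np(y_true, y_pred, max_pixel=1.0):
    # Mean squared error over all pixels, converted to decibels.
    mse = np.mean((y_pred - y_true) ** 2)
    return 10.0 * np.log10((max_pixel ** 2) / mse)

# Example: an MSE of 0.001 on [0, 1] frames gives 10 * log10(1 / 0.001) = 30 dB.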
View file

@@ -46,5 +46,17 @@
"original_video_file": "Scene9.mkv",
"crf": 15,
"preset_speed": "slow"
},
{
"compressed_video_file": "Scene10_x264_crf-23_preset-ultrafast.mkv",
"original_video_file": "Scene10.mkv",
"crf": 23,
"preset_speed": "ultrafast"
},
{
"compressed_video_file": "Scene11_x264_crf-42_preset-medium.mkv",
"original_video_file": "Scene11.mkv",
"crf": 42,
"preset_speed": "medium"
}
]

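The two new entries follow the same schema as the existing ones (compressed_video_file, original_video_file, crf, preset_speed), which train_model.py consumes via load_video_metadata. A small sketch of reading the list; the path is a placeholder, since the metadata file's name is not shown in this diff:

import json

# Placeholder path: point this at whatever file load_video_metadata() is given.
with open('video_metadata.json') as f:
    videos = json.load(f)

for entry in videos:
    print(entry['original_video_file'], '->', entry['compressed_video_file'],
          'crf', entry['crf'], 'preset', entry['preset_speed'])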
View file

@@ -1,8 +1,16 @@
# train_model.py
"""
TODO:
- Add more videos with different parameters to the training set.
- Add different scenes with the same parameters.
"""
import argparse
import json
import os
import cv2
import numpy as np
from featureExtraction import psnr
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
@@ -16,10 +24,12 @@ from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
# Constants
BATCH_SIZE = 16
EPOCHS = 100
-LEARNING_RATE = 0.000001
+LEARNING_RATE = 0.001
DECAY_STEPS = 40
DECAY_RATE = 0.9
MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints"
-EARLY_STOP = 10
+EARLY_STOP = 5
def save_model(model):
try:
@@ -58,7 +68,7 @@ def load_video_metadata(list_path):
def main():
-global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES
+global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES, DECAY_STEPS, DECAY_RATE
# Argument parsing
parser = argparse.ArgumentParser(description="Train the video compression model.")
parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
@@ -66,12 +76,16 @@ def main():
parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
parser.add_argument('-m', '--max_frames', type=int, default=MAX_FRAMES, help='Maximum number of frames to process.')
parser.add_argument('-ds', '--decay_steps', type=int, default=DECAY_STEPS, help='Decay steps for the learning rate schedule.')
parser.add_argument('-dr', '--decay_rate', type=float, default=DECAY_RATE, help='Decay rate for training.')
args = parser.parse_args()
BATCH_SIZE = args.batch_size
EPOCHS = args.epochs
LEARNING_RATE = args.learning_rate
MAX_FRAMES = args.max_frames
DECAY_RATE = args.decay_rate
DECAY_STEPS = args.decay_steps
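With the new flags parsed into DECAY_STEPS and DECAY_RATE, the decay schedule can be tuned from the command line without editing the constants; for example (the values shown are just the defaults):

python train_model.py -b 16 -e 100 -l 0.001 -ds 40 -dr 0.9
python train_model.py -c    # resume training from models/model.tf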
# Display training configuration
LOGGER.info("Starting the training with the given configuration.")
@@ -96,11 +110,20 @@ def main():
model = tf.keras.models.load_model(args.continue_training)
else:
model = VideoCompressionModel()
# Define exponential decay schedule
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate=LEARNING_RATE,
decay_steps=DECAY_STEPS,
decay_rate=DECAY_RATE,
staircase=False
)
# Set optimizer and compile the model
-optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
-model.compile(loss='mean_squared_error', optimizer=optimizer)
+optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
+model.compile(loss='mse', optimizer=optimizer, metrics=[psnr])
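The schedule follows lr(step) = initial_learning_rate * decay_rate ** (step / decay_steps), and with staircase=False the exponent is continuous rather than floored. A quick sanity-check of what the defaults produce (a sketch, not output from the repo):

# With LEARNING_RATE = 0.001, DECAY_RATE = 0.9, DECAY_STEPS = 40:
#   lr(0)   = 0.001
#   lr(40)  = 0.001 * 0.9**1  = 0.0009
#   lr(400) = 0.001 * 0.9**10 ≈ 0.00035
print(float(lr_schedule(0)), float(lr_schedule(40)), float(lr_schedule(400)))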
# Define checkpoints and early stopping
checkpoint_callback = ModelCheckpoint(