Jordon Brooks 2023-08-19 01:45:58 +01:00
parent 9f34cf8074
commit 2b4664fcbb
5 changed files with 34 additions and 11 deletions

View file

@@ -2,7 +2,7 @@
 import os
-from featureExtraction import preprocess_frame, psnr, scale_crf, scale_speed_preset
+from featureExtraction import combined, preprocess_frame, psnr, scale_crf, scale_speed_preset, ssim
 from globalVars import PRESET_SPEED_CATEGORIES
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
@@ -14,15 +14,15 @@ from video_compression_model import VideoCompressionModel, combine_batch
 # Constants
 COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
-MAX_FRAMES = 0  # Limit the number of frames processed
-CRF = 0
+MAX_FRAMES = 200  # Limit the number of frames processed
+CRF = 51
 SPEED = PRESET_SPEED_CATEGORIES.index("ultrafast")
 # Load the trained model
-MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr})
+MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr, 'ssim': ssim, 'combined': combined})
 # Load the uncompressed video
-UNCOMPRESSED_VIDEO_FILE = 'test_data/B4_t02.mkv'
+UNCOMPRESSED_VIDEO_FILE = 'test_data/x264_crf-5_preset-veryslow.mkv'
 def load_frame_from_video(video_file, frame_num):
     cap = cv2.VideoCapture(video_file)
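The model now depends on three custom objects, so every load site has to register all of them or deserialization fails with an unknown-object error. A minimal sketch of an equivalent load using tf.keras.utils.custom_object_scope, assuming the metrics remain importable from featureExtraction:

import tensorflow as tf
from featureExtraction import combined, psnr, ssim
from video_compression_model import VideoCompressionModel

# Everything loaded inside the scope can resolve these names without
# passing custom_objects to each load_model call.
with tf.keras.utils.custom_object_scope({'VideoCompressionModel': VideoCompressionModel,
                                         'psnr': psnr, 'ssim': ssim, 'combined': combined}):
    MODEL = tf.keras.models.load_model('models/model.tf')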

View file

@@ -6,6 +6,7 @@ import os
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
 import tensorflow as tf
+from tensorflow.keras import backend as K
 from globalVars import HEIGHT, NUM_PRESET_SPEEDS, WIDTH
@@ -49,11 +50,20 @@ def extract_histogram_features(frame, bins=64):
     return np.array(feature_vector)
 
 def psnr(y_true, y_pred):
     max_pixel = 1.0
     return 10.0 * K.log((max_pixel ** 2) / (K.mean(K.square(y_pred - y_true)))) / K.log(10.0)
 
+def ssim(y_true, y_pred):
+    return (tf.image.ssim(y_true, y_pred, max_val=1.0) + 1) * 50  # Normalize SSIM from [-1, 1] to [0, 100]
+
+def combined(y_true, y_pred):
+    return (psnr(y_true, y_pred) + ssim(y_true, y_pred)) / 2
+
 def preprocess_frame(frame, resize=True):
     # Preprocesses a single frame, cropping it if needed
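A quick sanity check of the metric scales (a sketch, assuming float inputs in [0, 1] and frames of at least 11x11 pixels, the default tf.image.ssim filter size; avoid identical tensors, which drive psnr to infinity):

import tensorflow as tf
from featureExtraction import combined, psnr, ssim  # or run inside this module

y_true = tf.random.uniform((1, 32, 32, 3))
y_pred = tf.clip_by_value(y_true + tf.random.normal((1, 32, 32, 3), stddev=0.05), 0.0, 1.0)

print(float(psnr(y_true, y_pred)))                      # roughly 25-30 dB at this noise level
print(float(tf.reduce_mean(ssim(y_true, y_pred))))      # rescaled into [0, 100]
print(float(tf.reduce_mean(combined(y_true, y_pred))))  # mean of the two, so both terms matter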

View file

@@ -2,7 +2,7 @@
 import log
-LOGGER = log.Logger(level="DEBUG", logfile="training.log", reset_logfile=True)
+LOGGER = log.Logger(level="TRACE", logfile="training.log", reset_logfile=True)
 PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
 NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)

View file

@@ -9,8 +9,9 @@ TODO:
 import argparse
 import json
 import os
+import random
-from featureExtraction import psnr
+from featureExtraction import combined, psnr, ssim
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
@@ -110,17 +111,29 @@ def main():
     LOGGER.info(f"Learning rate: {LEARNING_RATE}")
     LOGGER.info(f"Max Frames: {MAX_FRAMES}")
     LOGGER.info(f"Continue training from: {MODEL_SAVE_FILE}")
+    LOGGER.info(f"Decay Steps: {DECAY_STEPS}")
+    LOGGER.info(f"Decay Rate: {DECAY_RATE}")
     LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")
     # Load all video metadata
     all_videos = load_video_metadata("test_data/validation/validation.json")
+    # Specify the random seed
+    random_seed = 4576  # You can change this to any desired value
+    # Shuffle the data using the specified seed
+    random.seed(random_seed)
+    random.shuffle(all_videos)
     # Split into training and validation
-    split_index = int(0.8 * len(all_videos))
+    split_index = int(0.6 * len(all_videos))
     training_videos = all_videos[:split_index]
     validation_videos = all_videos[split_index:]
+    LOGGER.info(f"Training videos: {training_videos}")
+    LOGGER.info("==================================")
+    LOGGER.info(f"Validation videos: {validation_videos}")
     training_dataset = create_dataset(training_videos, BATCH_SIZE, MAX_FRAMES)
     validation_dataset = create_dataset(validation_videos, BATCH_SIZE, MAX_FRAMES)
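An alternative that leaves the interpreter-wide RNG untouched is a local random.Random instance. A sketch with a stand-in list (the real code shuffles the metadata returned by load_video_metadata), using the same seed and 60/40 split as above:

import random

all_videos = [f"video_{i}" for i in range(10)]  # stand-in for load_video_metadata(...)
rng = random.Random(4576)                       # local generator, seeded once
rng.shuffle(all_videos)
split_index = int(0.6 * len(all_videos))
training_videos = all_videos[:split_index]      # 60% train
validation_videos = all_videos[split_index:]    # 40% validation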
@@ -142,7 +155,7 @@ def main():
     # Set optimizer and compile the model
     optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
-    model.compile(loss='mse', optimizer=optimizer, metrics=[psnr])
+    model.compile(loss='mse', optimizer=optimizer, metrics=[psnr, ssim, combined])
     # Define checkpoints and early stopping
     checkpoint_callback = ModelCheckpoint(
@@ -152,7 +165,7 @@ def main():
         verbose=1,
         save_format="tf"
     )
-    early_stop = EarlyStopping(monitor='val_psnr', mode='max', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
+    early_stop = EarlyStopping(monitor='val_combined', mode='max', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
     # Custom garbage collection callback
     gc_callback = GarbageCollectorCallback()
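monitor='val_combined' resolves because Keras names each metric after the function's __name__ and prefixes validation-set values with val_. A toy end-to-end sketch of that plumbing (the one-layer model and random data are stand-ins, not the repo's model):

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping
from featureExtraction import combined, psnr, ssim

# 16x16 inputs keep tf.image.ssim's default 11x11 filter valid.
model = tf.keras.Sequential([tf.keras.layers.Conv2D(3, 3, padding='same', input_shape=(16, 16, 3))])
model.compile(loss='mse', optimizer='adam', metrics=[psnr, ssim, combined])

x = np.random.rand(8, 16, 16, 3).astype('float32')
early_stop = EarlyStopping(monitor='val_combined', mode='max', patience=1, restore_best_weights=True)
history = model.fit(x, x, validation_split=0.25, epochs=2, callbacks=[early_stop], verbose=0)
print(sorted(history.history))  # includes 'combined' and 'val_combined' alongside psnr/ssim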

View file

@@ -77,7 +77,7 @@ def frame_generator(videos, max_frames=None):
         SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(video["preset_speed"]))
         compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
-        uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
+        uncompressed_combined = combine_batch(uncompressed_frame, 10, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
         yield uncompressed_combined, compressed_combined
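The 0 -> 10 change gives the "uncompressed" reference a nonzero CRF control value; combine_batch and the scalers are not shown in this diff, and whether combine_batch normalizes the CRF internally is not visible here. A hypothetical sketch of linear [0, 1] normalization consistent with CRF = 51 and the nine-entry preset list (the repo's actual scale_crf / scale_speed_preset may differ):

# Hypothetical normalizers -- illustrative assumptions, not the repo's code.
def scale_crf(crf, max_crf=51):
    """Map an x264 CRF value 0..51 onto [0, 1]."""
    return crf / max_crf

def scale_speed_preset(index, num_presets=9):
    """Map a preset index 0..8 ('ultrafast'..'veryslow') onto [0, 1]."""
    return index / (num_presets - 1)

print(scale_crf(10))           # ~0.196 for the uncompressed reference above
print(scale_speed_preset(8))   # 1.0 for 'veryslow'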