Better

commit 2b4664fcbb
parent 9f34cf8074
5 changed files with 34 additions and 11 deletions
@@ -2,7 +2,7 @@
 import os
-from featureExtraction import preprocess_frame, psnr, scale_crf, scale_speed_preset
+from featureExtraction import combined, preprocess_frame, psnr, scale_crf, scale_speed_preset, ssim
 from globalVars import PRESET_SPEED_CATEGORIES
 
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
@@ -14,15 +14,15 @@ from video_compression_model import VideoCompressionModel, combine_batch
 
 # Constants
 COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
-MAX_FRAMES = 0 # Limit the number of frames processed
-CRF = 0
+MAX_FRAMES = 200 # Limit the number of frames processed
+CRF = 51
 SPEED = PRESET_SPEED_CATEGORIES.index("ultrafast")
 
 # Load the trained model
-MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr})
+MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr, 'ssim': ssim, 'combined': combined})
 
 # Load the uncompressed video
-UNCOMPRESSED_VIDEO_FILE = 'test_data/B4_t02.mkv'
+UNCOMPRESSED_VIDEO_FILE = 'test_data/x264_crf-5_preset-veryslow.mkv'
 
 def load_frame_from_video(video_file, frame_num):
     cap = cv2.VideoCapture(video_file)
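A note on the load_model change above: a model saved after being compiled with custom metrics can only be deserialized if every custom symbol is supplied again through custom_objects, which is why adding ssim and combined at training time forces this call to grow too. If a script only needs inference, a leaner alternative is to skip restoring the training configuration entirely; a minimal sketch, assuming the same paths and classes as in the diff:

import tensorflow as tf

from video_compression_model import VideoCompressionModel

# compile=False skips restoring the optimizer/loss/metrics, so the custom
# metric functions (psnr, ssim, combined) are not needed just to predict.
MODEL = tf.keras.models.load_model(
    'models/model.tf',
    custom_objects={'VideoCompressionModel': VideoCompressionModel},
    compile=False,
)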
@@ -6,6 +6,7 @@ import os
 
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
 
+import tensorflow as tf
 from tensorflow.keras import backend as K
 
 from globalVars import HEIGHT, NUM_PRESET_SPEEDS, WIDTH
@@ -49,11 +50,20 @@ def extract_histogram_features(frame, bins=64):
 
     return np.array(feature_vector)
 
 
 def psnr(y_true, y_pred):
     max_pixel = 1.0
     return 10.0 * K.log((max_pixel ** 2) / (K.mean(K.square(y_pred - y_true)))) / K.log(10.0)
 
+
+def ssim(y_true, y_pred):
+    return (tf.image.ssim(y_true, y_pred, max_val=1.0) + 1) * 50 # Normalize SSIM from [-1, 1] to [0, 100]
+
+
+def combined(y_true, y_pred):
+    return (psnr(y_true, y_pred) + ssim(y_true, y_pred)) / 2
+
+
 
 def preprocess_frame(frame, resize=True):
     #Preprocesses a single frame, cropping it if needed
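Worth flagging about the new combined metric: psnr is an open-ended decibel score (typically 20-50 for images in [0, 1]), while ssim here is rescaled to [0, 100], so the average mixes two different scales and will be dominated by whichever term moves more. A quick self-contained sanity check of the scales, assuming featureExtraction is importable (printed values are illustrative):

import tensorflow as tf

from featureExtraction import combined, psnr, ssim  # the functions added above

# A batch of "frames" in [0, 1]; SSIM's default 11x11 filter needs a
# spatial size of at least 11, so 64x64 is comfortable.
y_true = tf.random.uniform((1, 64, 64, 3))
noise = tf.random.normal(y_true.shape, stddev=0.05)
y_pred = tf.clip_by_value(y_true + noise, 0.0, 1.0)

print("psnr    :", float(tf.reduce_mean(psnr(y_true, y_pred))))      # dB, roughly 26 here
print("ssim    :", float(tf.reduce_mean(ssim(y_true, y_pred))))      # rescaled to [0, 100]
print("combined:", float(tf.reduce_mean(combined(y_true, y_pred))))  # average of the two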
@@ -2,7 +2,7 @@
 import log
 
-LOGGER = log.Logger(level="DEBUG", logfile="training.log", reset_logfile=True)
+LOGGER = log.Logger(level="TRACE", logfile="training.log", reset_logfile=True)
 
 PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
 NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
@@ -9,8 +9,9 @@ TODO:
 import argparse
 import json
 import os
+import random
 
-from featureExtraction import psnr
+from featureExtraction import combined, psnr, ssim
 
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
@@ -110,17 +111,29 @@ def main():
     LOGGER.info(f"Learning rate: {LEARNING_RATE}")
     LOGGER.info(f"Max Frames: {MAX_FRAMES}")
     LOGGER.info(f"Continue training from: {MODEL_SAVE_FILE}")
+    LOGGER.info(f"Decay Steps: {DECAY_STEPS}")
+    LOGGER.info(f"Decay Rate: {DECAY_RATE}")
 
     LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")
 
     # Load all video metadata
     all_videos = load_video_metadata("test_data/validation/validation.json")
 
+    # Specify the random seed
+    random_seed = 4576 # You can change this to any desired value
+
+    # Shuffle the data using the specified seed
+    random.shuffle(all_videos, random.seed(random_seed))
+
     # Split into training and validation
-    split_index = int(0.8 * len(all_videos))
+    split_index = int(0.6 * len(all_videos))
     training_videos = all_videos[:split_index]
     validation_videos = all_videos[split_index:]
 
+    LOGGER.info(f"Training videos: {training_videos}")
+    LOGGER.info(f"==================================")
+    LOGGER.info(f"Validation videos: {validation_videos}")
+
     training_dataset = create_dataset(training_videos, BATCH_SIZE, MAX_FRAMES)
     validation_dataset = create_dataset(validation_videos, BATCH_SIZE, MAX_FRAMES)
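A caution on the shuffle line added above: random.shuffle(all_videos, random.seed(random_seed)) is only deterministic by accident. random.seed() returns None; seeding the global RNG happens as a side effect of evaluating the argument, and the None lands in shuffle's optional second parameter, which was deprecated in Python 3.9 and removed in 3.11, where this call raises TypeError. A version-safe equivalent, sketched with a stand-in list:

import random

all_videos = list(range(10))  # stand-in for the loaded metadata list

rng = random.Random(4576)  # seeded local RNG; leaves the global random state alone
rng.shuffle(all_videos)
print(all_videos)  # identical order on every run for a fixed seed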
@@ -142,7 +155,7 @@ def main():
 
     # Set optimizer and compile the model
     optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
-    model.compile(loss='mse', optimizer=optimizer, metrics=[psnr])
+    model.compile(loss='mse', optimizer=optimizer, metrics=[psnr, ssim, combined])
 
     # Define checkpoints and early stopping
     checkpoint_callback = ModelCheckpoint(
@@ -152,7 +165,7 @@ def main():
         verbose=1,
         save_format="tf"
     )
-    early_stop = EarlyStopping(monitor='val_psnr', mode='max', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
+    early_stop = EarlyStopping(monitor='val_combined', mode='max', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
 
     # Custom garbage collection callback
     gc_callback = GarbageCollectorCallback()
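One detail that makes the monitor='val_combined' change work: Keras names each logged metric after the metric function's __name__ and prefixes validation-set values with val_, so compiling with metrics=[psnr, ssim, combined] is exactly what creates the val_combined key, and mode='max' is required because all three are quality scores where higher is better. A minimal self-contained illustration with a toy model and data (not the project's):

import numpy as np
import tensorflow as tf

def combined(y_true, y_pred):
    # Toy stand-in for featureExtraction.combined; only the name matters here.
    return -tf.reduce_mean(tf.square(y_true - y_pred))

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(loss="mse", optimizer="adam", metrics=[combined])
hist = model.fit(x, y, validation_split=0.25, epochs=1, verbose=0)
print(sorted(hist.history))  # includes 'combined' and 'val_combined'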
@@ -77,7 +77,7 @@ def frame_generator(videos, max_frames=None):
         SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(video["preset_speed"]))
 
         compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
-        uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
+        uncompressed_combined = combine_batch(uncompressed_frame, 10, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
 
         yield uncompressed_combined, compressed_combined