Update
This commit is contained in:
parent
54fa90247a
commit
15d8e57da5
4 changed files with 56 additions and 12 deletions
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from featureExtraction import preprocess_frame
|
from featureExtraction import preprocess_frame, psnr
|
||||||
from globalVars import PRESET_SPEED_CATEGORIES
|
from globalVars import PRESET_SPEED_CATEGORIES
|
||||||
|
|
||||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
||||||
|
@ -16,10 +16,10 @@ from video_compression_model import VideoCompressionModel
|
||||||
COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
|
COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
|
||||||
MAX_FRAMES = 0 # Limit the number of frames processed
|
MAX_FRAMES = 0 # Limit the number of frames processed
|
||||||
CRF = 51
|
CRF = 51
|
||||||
SPEED = PRESET_SPEED_CATEGORIES.index("veryslow")
|
SPEED = PRESET_SPEED_CATEGORIES.index("ultrafast")
|
||||||
|
|
||||||
# Load the trained model
|
# Load the trained model
|
||||||
MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel})
|
MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr})
|
||||||
|
|
||||||
# Load the uncompressed video
|
# Load the uncompressed video
|
||||||
UNCOMPRESSED_VIDEO_FILE = 'test_data/training_video.mkv'
|
UNCOMPRESSED_VIDEO_FILE = 'test_data/training_video.mkv'
|
||||||
|
@ -36,8 +36,8 @@ def load_frame_from_video(video_file, frame_num):
|
||||||
|
|
||||||
def predict_frame(uncompressed_frame):
|
def predict_frame(uncompressed_frame):
|
||||||
|
|
||||||
display_frame = np.clip(cv2.cvtColor(uncompressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
|
#display_frame = np.clip(cv2.cvtColor(uncompressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
|
||||||
cv2.imshow("uncomp", uncompressed_frame)
|
#cv2.imshow("uncomp", uncompressed_frame)
|
||||||
|
|
||||||
frame = preprocess_frame(uncompressed_frame, CRF, SPEED)
|
frame = preprocess_frame(uncompressed_frame, CRF, SPEED)
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,11 @@
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
||||||
|
|
||||||
|
from tensorflow.keras import backend as K
|
||||||
|
|
||||||
from globalVars import HEIGHT, NUM_PRESET_SPEEDS, WIDTH
|
from globalVars import HEIGHT, NUM_PRESET_SPEEDS, WIDTH
|
||||||
|
|
||||||
|
@ -38,6 +43,10 @@ def extract_histogram_features(frame, bins=64):
|
||||||
|
|
||||||
return np.array(feature_vector)
|
return np.array(feature_vector)
|
||||||
|
|
||||||
|
def psnr(y_true, y_pred):
|
||||||
|
max_pixel = 1.0
|
||||||
|
return 10.0 * K.log((max_pixel ** 2) / (K.mean(K.square(y_pred - y_true)))) / K.log(10.0)
|
||||||
|
|
||||||
|
|
||||||
def preprocess_frame(frame, crf, speed):
|
def preprocess_frame(frame, crf, speed):
|
||||||
# Check frame dimensions and resize if necessary
|
# Check frame dimensions and resize if necessary
|
||||||
|
|
|
@ -46,5 +46,17 @@
|
||||||
"original_video_file": "Scene9.mkv",
|
"original_video_file": "Scene9.mkv",
|
||||||
"crf": 15,
|
"crf": 15,
|
||||||
"preset_speed": "slow"
|
"preset_speed": "slow"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"compressed_video_file": "Scene10_x264_crf-23_preset-ultrafast.mkv",
|
||||||
|
"original_video_file": "Scene10.mkv",
|
||||||
|
"crf": 23,
|
||||||
|
"preset_speed": "ultrafast"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"compressed_video_file": "Scene11_x264_crf-42_preset-medium.mkv",
|
||||||
|
"original_video_file": "Scene11.mkv",
|
||||||
|
"crf": 42,
|
||||||
|
"preset_speed": "medium"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,8 +1,16 @@
|
||||||
|
# train_model.py
|
||||||
|
|
||||||
|
"""
|
||||||
|
TODO:
|
||||||
|
- Add more different videos with different parateters into the training set.
|
||||||
|
- Add different scenes with the same parameters
|
||||||
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import cv2
|
|
||||||
import numpy as np
|
from featureExtraction import psnr
|
||||||
|
|
||||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
||||||
|
|
||||||
|
@ -16,10 +24,12 @@ from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
|
||||||
# Constants
|
# Constants
|
||||||
BATCH_SIZE = 16
|
BATCH_SIZE = 16
|
||||||
EPOCHS = 100
|
EPOCHS = 100
|
||||||
LEARNING_RATE = 0.000001
|
LEARNING_RATE = 0.001
|
||||||
|
DECAY_STEPS = 40
|
||||||
|
DECAY_RATE = 0.9
|
||||||
MODEL_SAVE_FILE = "models/model.tf"
|
MODEL_SAVE_FILE = "models/model.tf"
|
||||||
MODEL_CHECKPOINT_DIR = "checkpoints"
|
MODEL_CHECKPOINT_DIR = "checkpoints"
|
||||||
EARLY_STOP = 10
|
EARLY_STOP = 5
|
||||||
|
|
||||||
def save_model(model):
|
def save_model(model):
|
||||||
try:
|
try:
|
||||||
|
@ -58,7 +68,7 @@ def load_video_metadata(list_path):
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES
|
global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES, DECAY_STEPS, DECAY_RATE
|
||||||
# Argument parsing
|
# Argument parsing
|
||||||
parser = argparse.ArgumentParser(description="Train the video compression model.")
|
parser = argparse.ArgumentParser(description="Train the video compression model.")
|
||||||
parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
|
parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
|
||||||
|
@ -66,12 +76,16 @@ def main():
|
||||||
parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
|
parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
|
||||||
parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
|
parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
|
||||||
parser.add_argument('-m', '--max_frames', type=int, default=MAX_FRAMES, help='Batch size for training.')
|
parser.add_argument('-m', '--max_frames', type=int, default=MAX_FRAMES, help='Batch size for training.')
|
||||||
|
parser.add_argument('-ds', '--decay_steps', type=int, default=DECAY_STEPS, help='Decay size for training.')
|
||||||
|
parser.add_argument('-dr', '--decay_rate', type=float, default=DECAY_RATE, help='Decay rate for training.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
BATCH_SIZE = args.batch_size
|
BATCH_SIZE = args.batch_size
|
||||||
EPOCHS = args.epochs
|
EPOCHS = args.epochs
|
||||||
LEARNING_RATE = args.learning_rate
|
LEARNING_RATE = args.learning_rate
|
||||||
MAX_FRAMES = args.max_frames
|
MAX_FRAMES = args.max_frames
|
||||||
|
DECAY_RATE = args.decay_rate
|
||||||
|
DECAY_STEPS = args.decay_steps
|
||||||
|
|
||||||
# Display training configuration
|
# Display training configuration
|
||||||
LOGGER.info("Starting the training with the given configuration.")
|
LOGGER.info("Starting the training with the given configuration.")
|
||||||
|
@ -96,11 +110,20 @@ def main():
|
||||||
model = tf.keras.models.load_model(args.continue_training)
|
model = tf.keras.models.load_model(args.continue_training)
|
||||||
else:
|
else:
|
||||||
model = VideoCompressionModel()
|
model = VideoCompressionModel()
|
||||||
|
|
||||||
|
|
||||||
|
# Define exponential decay schedule
|
||||||
|
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
|
||||||
|
initial_learning_rate=LEARNING_RATE,
|
||||||
|
decay_steps=DECAY_STEPS,
|
||||||
|
decay_rate=DECAY_RATE,
|
||||||
|
staircase=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Set optimizer and compile the model
|
# Set optimizer and compile the model
|
||||||
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
|
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
|
||||||
model.compile(loss='mean_squared_error', optimizer=optimizer)
|
model.compile(loss='mse', optimizer=optimizer, metrics=[psnr])
|
||||||
|
|
||||||
# Define checkpoints and early stopping
|
# Define checkpoints and early stopping
|
||||||
checkpoint_callback = ModelCheckpoint(
|
checkpoint_callback = ModelCheckpoint(
|
||||||
|
|
Reference in a new issue