Update

commit 15d8e57da5 (parent 54fa90247a)
4 changed files with 56 additions and 12 deletions
@@ -2,7 +2,7 @@
 import os
 
-from featureExtraction import preprocess_frame
+from featureExtraction import preprocess_frame, psnr
 from globalVars import PRESET_SPEED_CATEGORIES
 
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

@@ -16,10 +16,10 @@ from video_compression_model import VideoCompressionModel
 COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
 MAX_FRAMES = 0  # Limit the number of frames processed
 CRF = 51
-SPEED = PRESET_SPEED_CATEGORIES.index("veryslow")
+SPEED = PRESET_SPEED_CATEGORIES.index("ultrafast")
 
 # Load the trained model
-MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel})
+MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr})
 
 # Load the uncompressed video
 UNCOMPRESSED_VIDEO_FILE = 'test_data/training_video.mkv'

@@ -36,8 +36,8 @@ def load_frame_from_video(video_file, frame_num):
 
 def predict_frame(uncompressed_frame):
-    display_frame = np.clip(cv2.cvtColor(uncompressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
-    cv2.imshow("uncomp", uncompressed_frame)
+    #display_frame = np.clip(cv2.cvtColor(uncompressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
+    #cv2.imshow("uncomp", uncompressed_frame)
 
     frame = preprocess_frame(uncompressed_frame, CRF, SPEED)
 
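Because the model is now compiled with the custom psnr metric (see the train_model.py hunks further down), loading it back requires handing that function to Keras, which is what the added 'psnr': psnr entry in custom_objects does. A minimal sketch of the two usual ways to load such a model, assuming TF 2.x; the compile=False variant is an alternative for inference-only use, not what this commit does:

import tensorflow as tf

from featureExtraction import psnr
from video_compression_model import VideoCompressionModel

# Option 1 (what the commit does): restore the compiled model, supplying every
# custom object that was referenced at save time.
model = tf.keras.models.load_model(
    'models/model.tf',
    custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr},
)

# Option 2 (alternative): load for inference only and skip the training config,
# so the custom metric does not have to be supplied.
inference_model = tf.keras.models.load_model(
    'models/model.tf',
    custom_objects={'VideoCompressionModel': VideoCompressionModel},
    compile=False,
)

With compile=False the optimizer state and metrics are simply not restored. The next hunks modify featureExtraction.py, where the metric is defined.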
@@ -2,6 +2,11 @@
 import cv2
 import numpy as np
+import os
+
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
+
+from tensorflow.keras import backend as K
 
 from globalVars import HEIGHT, NUM_PRESET_SPEEDS, WIDTH
 

@@ -38,6 +43,10 @@ def extract_histogram_features(frame, bins=64):
     return np.array(feature_vector)
 
+def psnr(y_true, y_pred):
+    max_pixel = 1.0
+    return 10.0 * K.log((max_pixel ** 2) / (K.mean(K.square(y_pred - y_true)))) / K.log(10.0)
+
 def preprocess_frame(frame, crf, speed):
     # Check frame dimensions and resize if necessary
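The new psnr helper implements the standard peak signal-to-noise ratio, PSNR = 10 * log10(max_pixel^2 / MSE), with max_pixel = 1.0 because frames are scaled to [0, 1]. A quick sanity-check sketch, assuming TF 2.x: for a single image the value should agree closely with TensorFlow's built-in tf.image.psnr.

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

def psnr(y_true, y_pred):
    # Same definition as in featureExtraction.py: 10 * log10(max^2 / MSE)
    max_pixel = 1.0
    return 10.0 * K.log((max_pixel ** 2) / (K.mean(K.square(y_pred - y_true)))) / K.log(10.0)

# Synthetic test pair in [0, 1]: a frame and a lightly perturbed copy of it.
y_true = tf.random.uniform((1, 64, 64, 3))
y_pred = tf.clip_by_value(y_true + tf.random.normal(y_true.shape, stddev=0.05), 0.0, 1.0)

print(float(psnr(y_true, y_pred)))                           # custom metric
print(float(tf.image.psnr(y_true, y_pred, max_val=1.0)[0]))  # built-in reference

The next hunk extends the JSON manifest of training videos.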
@@ -46,5 +46,17 @@
         "original_video_file": "Scene9.mkv",
         "crf": 15,
         "preset_speed": "slow"
-    }
+    },
+    {
+        "compressed_video_file": "Scene10_x264_crf-23_preset-ultrafast.mkv",
+        "original_video_file": "Scene10.mkv",
+        "crf": 23,
+        "preset_speed": "ultrafast"
+    },
+    {
+        "compressed_video_file": "Scene11_x264_crf-42_preset-medium.mkv",
+        "original_video_file": "Scene11.mkv",
+        "crf": 42,
+        "preset_speed": "medium"
+    }
 ]
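This manifest is what train_model.py's load_video_metadata(list_path) reads; the function body is not part of this diff, so the reader below is only an illustrative sketch. It assumes the file is a top-level JSON array with exactly the keys shown, and the path in the commented call is hypothetical:

import json

from globalVars import PRESET_SPEED_CATEGORIES  # includes "ultrafast" ... "veryslow"

def load_video_metadata(list_path):
    """Illustrative reader for the manifest format above; not the repo's actual code."""
    with open(list_path, "r") as f:
        entries = json.load(f)
    for entry in entries:
        # CRF 0-51 is the valid x264 range; preset names must exist in
        # PRESET_SPEED_CATEGORIES so .index() lookups (as in the inference script) succeed.
        assert 0 <= entry["crf"] <= 51, f"CRF out of x264 range: {entry['crf']}"
        assert entry["preset_speed"] in PRESET_SPEED_CATEGORIES, entry["preset_speed"]
    return entries

# videos = load_video_metadata("test_data/training.json")  # hypothetical path

The remaining hunks are train_model.py.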
@@ -1,8 +1,16 @@
 # train_model.py
 
+"""
+TODO:
+- Add more videos with different parameters to the training set.
+- Add different scenes with the same parameters.
+"""
+
 import argparse
 import json
 import os
 import cv2
 import numpy as np
 
+from featureExtraction import psnr
+
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
@@ -16,10 +24,12 @@ from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
 # Constants
 BATCH_SIZE = 16
 EPOCHS = 100
-LEARNING_RATE = 0.000001
+LEARNING_RATE = 0.001
+DECAY_STEPS = 40
+DECAY_RATE = 0.9
 MODEL_SAVE_FILE = "models/model.tf"
 MODEL_CHECKPOINT_DIR = "checkpoints"
-EARLY_STOP = 10
+EARLY_STOP = 5
 
 def save_model(model):
     try:
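The jump from LEARNING_RATE = 0.000001 to 0.001 works together with the exponential-decay schedule added in the last hunk of this file: with DECAY_STEPS = 40, DECAY_RATE = 0.9 and staircase=False, the rate decays smoothly as lr(step) = LEARNING_RATE * DECAY_RATE ** (step / DECAY_STEPS), where step counts optimizer updates. A small sketch of the resulting values:

LEARNING_RATE = 0.001
DECAY_STEPS = 40
DECAY_RATE = 0.9

def lr_at(step):
    # Closed form of tf.keras.optimizers.schedules.ExponentialDecay with staircase=False
    return LEARNING_RATE * DECAY_RATE ** (step / DECAY_STEPS)

for step in (0, 40, 400, 4000):
    print(step, f"{lr_at(step):.2e}")
# Roughly: 1.00e-03 at step 0, 9.00e-04 at 40, 3.49e-04 at 400, 2.66e-08 at 4000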
@@ -58,7 +68,7 @@ def load_video_metadata(list_path):
 
 def main():
-    global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES
+    global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES, DECAY_STEPS, DECAY_RATE
 
     # Argument parsing
     parser = argparse.ArgumentParser(description="Train the video compression model.")
     parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
@@ -66,12 +76,16 @@ def main():
     parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
     parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
     parser.add_argument('-m', '--max_frames', type=int, default=MAX_FRAMES, help='Limit on the number of frames processed.')
+    parser.add_argument('-ds', '--decay_steps', type=int, default=DECAY_STEPS, help='Decay steps for the learning rate schedule.')
+    parser.add_argument('-dr', '--decay_rate', type=float, default=DECAY_RATE, help='Decay rate for the learning rate schedule.')
     args = parser.parse_args()
 
     BATCH_SIZE = args.batch_size
     EPOCHS = args.epochs
     LEARNING_RATE = args.learning_rate
     MAX_FRAMES = args.max_frames
+    DECAY_RATE = args.decay_rate
+    DECAY_STEPS = args.decay_steps
 
     # Display training configuration
     LOGGER.info("Starting the training with the given configuration.")
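A note on the -c/--continue_training flag shown above: because it uses nargs='?' with const=MODEL_SAVE_FILE, passing -c with no value resumes from the default saved model, passing -c with a path resumes from that path, and omitting the flag trains from scratch. A tiny standalone demonstration of that argparse behaviour (the explicit path below is illustrative only):

import argparse

MODEL_SAVE_FILE = "models/model.tf"

parser = argparse.ArgumentParser()
parser.add_argument('-c', '--continue_training', type=str, nargs='?',
                    const=MODEL_SAVE_FILE, default=None)

print(parser.parse_args([]).continue_training)        # None -> train from scratch
print(parser.parse_args(['-c']).continue_training)    # 'models/model.tf' (const)
print(parser.parse_args(['-c', 'checkpoints/last.tf']).continue_training)  # explicit, hypothetical path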
@@ -96,11 +110,20 @@ def main():
         model = tf.keras.models.load_model(args.continue_training)
     else:
         model = VideoCompressionModel()
 
+    # Define exponential decay schedule
+    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
+        initial_learning_rate=LEARNING_RATE,
+        decay_steps=DECAY_STEPS,
+        decay_rate=DECAY_RATE,
+        staircase=False
+    )
+
     # Set optimizer and compile the model
-    optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
-    model.compile(loss='mean_squared_error', optimizer=optimizer)
+    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
+    model.compile(loss='mse', optimizer=optimizer, metrics=[psnr])
 
     # Define checkpoints and early stopping
     checkpoint_callback = ModelCheckpoint(
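The hunk above breaks off just as checkpoint_callback = ModelCheckpoint( is being constructed, so its actual arguments are not visible in this view. Purely as a hedged sketch, the following shows typical Keras callback wiring consistent with the constants in this file (MODEL_CHECKPOINT_DIR = "checkpoints", EARLY_STOP = 5, assumed here to act as the EarlyStopping patience); the filepath pattern and monitored quantity are assumptions, not the repository's code:

import os
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

MODEL_CHECKPOINT_DIR = "checkpoints"
EARLY_STOP = 5  # assumed to be the early-stopping patience

# Assumed arguments: the diff truncates before ModelCheckpoint's parameters.
checkpoint_callback = ModelCheckpoint(
    filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch-{epoch:02d}.ckpt"),
    save_weights_only=True,
    save_freq='epoch',
)
early_stop_callback = EarlyStopping(monitor='loss', patience=EARLY_STOP, restore_best_weights=True)

# model.fit(..., callbacks=[checkpoint_callback, early_stop_callback])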