Improved model

commit 60c6c97071
parent 9167ff27d4

8 changed files with 327 additions and 112 deletions
.gitignore (vendored) | 4
@@ -10,3 +10,7 @@
 !DeepEncode.py
 !train_model.py
 !video_compression_model.py
+!global_train.py
+!log.py
+!test_data/training.json
+!test_data/validation.json
DeepEncode.py
@@ -41,8 +41,7 @@ def predict_frame(uncompressed_frame, model, crf_value, preset_speed_value):
     #cv2.waitKey(10)
 
     compressed_frame = model.predict({
-        "uncompressed_frame": uncompressed_frame,
         "compressed_frame": uncompressed_frame,
         "crf": crf_array,
         "preset_speed": preset_speed_array
     })
global_train.py (new file) | 3
@@ -0,0 +1,3 @@
+import log
+
+LOGGER = log.Logger(level="INFO", logfile="training.log", reset_logfile=True)
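global_train.py exists so every module shares a single logger instance instead of configuring its own. A minimal usage sketch, assuming only what this commit defines (the message text is illustrative):

    # any project module, e.g. a future helper script
    from global_train import LOGGER

    LOGGER.info("shared logger ready")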
log.py (new file) | 66
@@ -0,0 +1,66 @@
+import datetime
+import inspect
+
+class TerminalColors:
+    HEADER = '\033[95m'
+    OKBLUE = '\033[94m'
+    OKCYAN = '\033[96m'
+    OKGREEN = '\033[92m'
+    WARNING = '\033[93m'
+    FAIL = '\033[91m'
+    ENDC = '\033[0m'
+    BOLD = '\033[1m'
+    UNDERLINE = '\033[4m'
+
+class Logger:
+    LEVELS = {"TRACE": 0, "INFO": 1, "DEBUG": 2, "WARN": 3, "ERROR": 4}
+    COLORS = {"TRACE": TerminalColors.HEADER, "INFO": TerminalColors.OKCYAN, "DEBUG": TerminalColors.OKGREEN,
+              "WARN": TerminalColors.WARNING, "ERROR": TerminalColors.FAIL}
+
+    def __init__(self, level="INFO", logfile=None, log_format="{timestamp} {level} {message}", reset_logfile=False):
+        self.level = level
+        self.logfile = logfile
+        self.log_format = log_format
+
+        if reset_logfile and logfile:
+            with open(logfile, 'w') as file:
+                file.truncate(0)  # This will clear the content of the file
+
+    def _get_caller_info(self):
+        frame = inspect.stack()[3]
+        filename = frame.filename.split('/')[-1]  # Extracts the last part after the final '/'
+        line_number = frame.lineno
+        return filename, line_number
+
+    def _log_to_file(self, message):
+        if self.logfile:
+            with open(self.logfile, 'a') as file:
+                file.write(message + '\n')
+
+    def _print_log(self, level_name, *args):
+        if Logger.LEVELS[level_name] >= Logger.LEVELS[self.level]:
+            timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+            message = " ".join(map(str, args))
+
+            if level_name in ["TRACE", "DEBUG"]:
+                filename, line_number = self._get_caller_info()
+                message = f"({filename}:{line_number}) {message}"
+
+            log_message = self.log_format.format(timestamp=timestamp, level=level_name, message=message)
+            print(f"{Logger.COLORS[level_name]}{log_message}{TerminalColors.ENDC}")
+            self._log_to_file(log_message)
+
+    def trace(self, *args):
+        self._print_log("TRACE", *args)
+
+    def info(self, *args):
+        self._print_log("INFO", *args)
+
+    def warn(self, *args):
+        self._print_log("WARN", *args)
+
+    def debug(self, *args):
+        self._print_log("DEBUG", *args)
+
+    def error(self, *args):
+        self._print_log("ERROR", *args)
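Logger filters by severity, colors terminal output per level, prefixes TRACE and DEBUG messages with the caller's file and line, and mirrors everything to the logfile. A short sketch of the API exactly as defined above (messages are illustrative):

    from log import Logger

    logger = Logger(level="TRACE", logfile="demo.log", reset_logfile=True)
    logger.trace("gets a (file:line) prefix because it is TRACE")
    logger.info("plain informational message")
    logger.error("printed in red and appended to demo.log")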
test_data/training.json (new file) | 74
@@ -0,0 +1,74 @@
+[
+    {
+        "video_file": "x264_crf-51_preset-ultrafast.mkv",
+        "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
+        "crf": 51,
+        "preset_speed": "ultrafast"
+    },
+    {
+        "video_file": "x264_crf-16_preset-veryslow.mkv",
+        "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
+        "crf": 16,
+        "preset_speed": "veryslow"
+    },
+    {
+        "video_file": "x264_crf-18_preset-ultrafast.mkv",
+        "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
+        "crf": 18,
+        "preset_speed": "ultrafast"
+    },
+    {
+        "video_file": "x264_crf-18_preset-veryslow.mkv",
+        "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
+        "crf": 18,
+        "preset_speed": "veryslow"
+    },
+    {
+        "video_file": "x264_crf-50_preset-veryslow.mkv",
+        "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
+        "crf": 50,
+        "preset_speed": "veryslow"
+    },
+    {
+        "video_file": "x264_crf-51_preset-fast.mkv",
+        "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
+        "crf": 51,
+        "preset_speed": "fast"
+    },
+    {
+        "video_file": "x264_crf-51_preset-faster.mkv",
+        "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
+        "crf": 51,
+        "preset_speed": "faster"
+    },
+    {
+        "video_file": "x264_crf-51_preset-medium.mkv",
+        "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
+        "crf": 51,
+        "preset_speed": "medium"
+    },
+    {
+        "video_file": "x264_crf-51_preset-slow.mkv",
+        "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
+        "crf": 51,
+        "preset_speed": "slow"
+    },
+    {
+        "video_file": "x264_crf-51_preset-slower.mkv",
+        "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
+        "crf": 51,
+        "preset_speed": "slower"
+    },
+    {
+        "video_file": "x264_crf-51_preset-superfast.mkv",
+        "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
+        "crf": 51,
+        "preset_speed": "superfast"
+    },
+    {
+        "video_file": "x264_crf-51_preset-veryfast.mkv",
+        "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
+        "crf": 51,
+        "preset_speed": "veryfast"
+    }
+]
test_data/validation.json (new file) | 8
@@ -0,0 +1,8 @@
+[
+    {
+        "video_file": "x264_crf-16_preset-veryslow.mkv",
+        "uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
+        "crf": 16,
+        "preset_speed": "veryslow"
+    }
+]
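Each entry pairs one compressed rendition with the near-lossless reference (CRF 5, veryslow) and records the encoder settings used. A hedged sketch of how these fields become model inputs, mirroring the conversions in train_model.py below (the 63.0 divisor and the preset list are taken from this commit):

    import json

    PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]

    with open("test_data/validation.json") as f:
        entry = json.load(f)[0]

    crf = entry["crf"] / 63.0                                      # normalized CRF
    preset = PRESET_SPEED_CATEGORIES.index(entry["preset_speed"])  # integer index
    print(crf, preset)  # ~0.254, 8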
train_model.py | 226
@@ -1,4 +1,9 @@
+# train_model.py
+
 import os
+
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
+
 import json
 import numpy as np
 import cv2
@@ -7,82 +12,122 @@ import tensorflow as tf
 from video_compression_model import NUM_CHANNELS, VideoCompressionModel, PRESET_SPEED_CATEGORIES, VideoDataGenerator
 from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
 
-print("GPUs Detected:", tf.config.list_physical_devices('GPU'))
+from global_train import LOGGER
 
 # Constants
 BATCH_SIZE = 4
 EPOCHS = 100
 LEARNING_RATE = 0.000001
-TRAIN_SAMPLES = 500
+TRAIN_SAMPLES = 50
 MODEL_SAVE_FILE = "models/model.tf"
 MODEL_CHECKPOINT_DIR = "checkpoints"
-CONTINUE_TRAINING = None
+EARLY_STOP = 10
 
-def load_list(list_path):
-    with open(list_path, "r") as json_file:
-        video_details_list = json.load(json_file)
-        return video_details_list
+def load_video_metadata(list_path):
+    LOGGER.trace(f"Entering: load_video_metadata({list_path})")
+    try:
+        with open(list_path, "r") as json_file:
+            file = json.load(json_file)
+            LOGGER.trace(f"load_video_metadata returning: {file}")
+            return file
+    except FileNotFoundError:
+        LOGGER.error(f"Metadata file {list_path} not found.")
+        raise
+    except json.JSONDecodeError:
+        LOGGER.error(f"Error decoding JSON from {list_path}.")
+        raise
 
-def load_video_from_list(list_path, samples = TRAIN_SAMPLES):
-    details_list = load_list(list_path)
-    all_details = []
+def load_video_samples(list_path, samples=TRAIN_SAMPLES):
+    """
+    Load video samples from the metadata list.
+
+    Args:
+    - list_path (str): Path to the metadata JSON file.
+    - samples (int): Number of total samples to be extracted.
+
+    Returns:
+    - list: Extracted video samples.
+    """
+    LOGGER.trace(f"Entering: load_video_samples({list_path}, {samples})")
+
+    details_list = load_video_metadata(list_path)
+    all_samples = []
     num_videos = len(details_list)
     frames_per_video = int(samples / num_videos)
 
-    print(f"Loading {frames_per_video} frames across {num_videos} videos")
+    LOGGER.info(f"Loading {frames_per_video} frames from {num_videos} videos")
 
     for video_details in details_list:
-        VIDEO_FILE = video_details["video_file"]
-        UNCOMPRESSED_VIDEO_FILE = video_details["uncompressed_video_file"]
-        CRF = video_details['crf'] / 63.0
-        PRESET_SPEED = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed'])
-        video_details['preset_speed'] = PRESET_SPEED
+        video_file = video_details["video_file"]
+        uncompressed_video_file = video_details["uncompressed_video_file"]
+        crf = video_details['crf'] / 63.0
+        preset_speed = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed'])
+        video_details['preset_speed'] = preset_speed
 
-        frames = []
-        frames_compressed = []
+        compressed_frames, uncompressed_frames = [], []
 
-        cap = cv2.VideoCapture(os.path.join("test_data/", VIDEO_FILE))
-        cap_uncompressed = cv2.VideoCapture(os.path.join("test_data/", UNCOMPRESSED_VIDEO_FILE))
+        try:
+            cap = cv2.VideoCapture(os.path.join("test_data/", video_file))
+            cap_uncompressed = cv2.VideoCapture(os.path.join("test_data/", uncompressed_video_file))
 
-        for _ in range(frames_per_video):
-            ret, frame_compressed = cap.read()
-            ret_uncompressed, frame = cap_uncompressed.read()
+            if not cap.isOpened() or not cap_uncompressed.isOpened():
+                raise RuntimeError(f"Could not open video files {video_file} or {uncompressed_video_file}")
 
-            if not ret or not ret_uncompressed:
-                continue
+            for _ in range(frames_per_video):
+                ret, frame_compressed = cap.read()
+                ret_uncompressed, frame = cap_uncompressed.read()
 
-            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-            frame_compressed = cv2.cvtColor(frame_compressed, cv2.COLOR_BGR2RGB)
+                if not ret or not ret_uncompressed:
+                    continue
 
-            frames.append(preprocess(frame))
-            frames_compressed.append(preprocess(frame_compressed))
+                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                frame_compressed = cv2.cvtColor(frame_compressed, cv2.COLOR_BGR2RGB)
 
-        for uncompressed_frame, compressed_frame in zip(frames, frames_compressed):
-            all_details.append({
-                "frame": uncompressed_frame,
-                "compressed_frame": compressed_frame,
-                "crf": CRF,
-                "preset_speed": PRESET_SPEED,
-                "video_file": VIDEO_FILE
-            })
+                uncompressed_frames.append(normalize(frame))
+                compressed_frames.append(normalize(frame_compressed))
 
-        cap.release()
-        cap_uncompressed.release()
+            all_samples.extend({
+                "frame": frame,
+                "compressed_frame": frame_compressed,
+                "crf": crf,
+                "preset_speed": preset_speed,
+                "video_file": video_file
+            } for frame, frame_compressed in zip(uncompressed_frames, compressed_frames))
 
-    return all_details
+        except Exception as e:
+            LOGGER.error(f"Error during video sample loading: {e}")
+            raise
 
-def preprocess(frame):
+        finally:
+            cap.release()
+            cap_uncompressed.release()
+
+    return all_samples
+
+def normalize(frame):
+    """
+    Normalize pixel values of the frame to range [0, 1].
+
+    Args:
+    - frame (ndarray): Image frame.
+
+    Returns:
+    - ndarray: Normalized frame.
+    """
+    LOGGER.trace(f"Normalizing frame")
    return frame / 255.0
 
 def save_model(model):
-    os.makedirs("models", exist_ok=True)
-    model.save(MODEL_SAVE_FILE, save_format='tf')
-    print("Model saved successfully!")
+    try:
+        LOGGER.debug("Attempting to save the model.")
+        os.makedirs("models", exist_ok=True)
+        model.save(MODEL_SAVE_FILE, save_format='tf')
+        LOGGER.info("Model saved successfully!")
+    except Exception as e:
+        LOGGER.error(f"Error saving the model: {e}")
+        raise
 
 def main():
-    global BATCH_SIZE, EPOCHS, TRAIN_SAMPLES, LEARNING_RATE, CONTINUE_TRAINING
-
     # Argument parsing
     parser = argparse.ArgumentParser(description="Train the video compression model.")
     parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
@@ -92,37 +137,35 @@ def main():
     parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
 
     args = parser.parse_args()
 
-    # Use the parsed arguments in your script
-    BATCH_SIZE = args.batch_size
-    EPOCHS = args.epochs
-    TRAIN_SAMPLES = args.training_samples
-    LEARNING_RATE = args.learning_rate
-    CONTINUE_TRAINING = args.continue_training
-
-    print("Training configuration:")
-    print(f"Batch size: {BATCH_SIZE}")
-    print(f"Epochs: {EPOCHS}")
-    print(f"Training samples: {TRAIN_SAMPLES}")
-    print(f"Learning rate: {LEARNING_RATE}")
-    print(f"Continue training from: {CONTINUE_TRAINING}")
-
-    all_video_details_train = load_video_from_list("test_data/training.json")
-    all_video_details_val = load_video_from_list("test_data/validation.json", TRAIN_SAMPLES / 2)
-
-    train_generator = VideoDataGenerator(all_video_details_train, BATCH_SIZE)
-    val_generator = VideoDataGenerator(all_video_details_val, BATCH_SIZE)
-
-    if CONTINUE_TRAINING:
-        print("loading model:", CONTINUE_TRAINING)
-        model = tf.keras.models.load_model(CONTINUE_TRAINING)  # Load from the specified file
+    # Display training configuration
+    LOGGER.info("Starting the training with the given configuration.")
+    LOGGER.info("Training configuration:")
+    LOGGER.info(f"Batch size: {args.batch_size}")
+    LOGGER.info(f"Epochs: {args.epochs}")
+    LOGGER.info(f"Training samples: {args.training_samples}")
+    LOGGER.info(f"Learning rate: {args.learning_rate}")
+    LOGGER.info(f"Continue training from: {args.continue_training}")
+
+    # Load training and validation samples
+    LOGGER.debug("Loading training and validation samples.")
+    training_samples = load_video_samples("test_data/training.json")
+    validation_samples = load_video_samples("test_data/validation.json", args.training_samples // 2)
+
+    train_generator = VideoDataGenerator(training_samples, args.batch_size)
+    val_generator = VideoDataGenerator(validation_samples, args.batch_size)
+
+    # Load or initialize model
+    if args.continue_training:
+        model = tf.keras.models.load_model(args.continue_training)
     else:
         model = VideoCompressionModel()
 
-    # Define the optimizer with a specific learning rate
-    optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
-    os.makedirs(MODEL_CHECKPOINT_DIR, exist_ok=True)
+    # Set optimizer and compile the model
+    optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
+    model.compile(loss='mean_squared_error', optimizer=optimizer)
+
+    # Define checkpoints and early stopping
     checkpoint_callback = ModelCheckpoint(
         filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch-{epoch:02d}.tf"),
         save_weights_only=False,
@@ -130,24 +173,25 @@ def main():
         verbose=1,
         save_format="tf"
     )
+    early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
 
-    #tf.config.run_functions_eagerly(True)
-    model.compile(loss='mean_squared_error', optimizer=optimizer)
-    early_stop = EarlyStopping(monitor='val_loss', patience=5, verbose=1, restore_best_weights=True)
-
-    print("\nTraining the model...")
+    # Train the model
+    LOGGER.info("Starting model training.")
     model.fit(
         train_generator,
         steps_per_epoch=len(train_generator),
-        epochs=EPOCHS,
+        epochs=args.epochs,
         validation_data=val_generator,
         validation_steps=len(val_generator),
        callbacks=[early_stop, checkpoint_callback]
     )
-    print("\nTraining completed!")
+    LOGGER.info("Model training completed.")
 
     save_model(model)
 
 if __name__ == "__main__":
-    main()
+    try:
+        main()
+    except Exception as e:
+        LOGGER.error(f"Unexpected error during training: {e}")
+        raise
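With main() now reading everything from argparse instead of mutating globals, a run is configured entirely on the command line. A hedged invocation sketch: '-b'/'--batch_size' and '-c'/'--continue_training' appear in this diff, while '--epochs', '--training_samples', and '--learning_rate' are inferred from the args.* attributes in the truncated parser section:

    python train_model.py -b 8 --epochs 20 --training_samples 100 --learning_rate 0.000001
    python train_model.py -c    # resume from MODEL_SAVE_FILE ("models/model.tf")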
video_compression_model.py
@@ -3,12 +3,15 @@
 import numpy as np
 import tensorflow as tf
 
+from global_train import LOGGER
+
 PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
 NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
 NUM_CHANNELS = 3
 
 class VideoDataGenerator(tf.keras.utils.Sequence):
     def __init__(self, video_details_list, batch_size):
+        LOGGER.debug("Initializing VideoDataGenerator with batch size: {}".format(batch_size))
         self.video_details_list = video_details_list
         self.batch_size = batch_size
 
@@ -16,25 +19,34 @@ class VideoDataGenerator(tf.keras.utils.Sequence):
         return int(np.ceil(len(self.video_details_list) / float(self.batch_size)))
 
     def __getitem__(self, idx):
-        start_idx = idx * self.batch_size
-        end_idx = (idx + 1) * self.batch_size
+        try:
+            start_idx = idx * self.batch_size
+            end_idx = (idx + 1) * self.batch_size
 
-        batch_data = self.video_details_list[start_idx:end_idx]
+            batch_data = self.video_details_list[start_idx:end_idx]
 
-        x1 = np.array([item["frame"] for item in batch_data])
-        x2 = np.array([item["compressed_frame"] for item in batch_data])
-        x3 = np.array([item["crf"] for item in batch_data])
-        x4 = np.array([item["preset_speed"] for item in batch_data])
+            x1 = np.array([item["frame"] for item in batch_data])
+            x2 = np.array([item["compressed_frame"] for item in batch_data])
+            x3 = np.array([item["crf"] for item in batch_data])
+            x4 = np.array([item["preset_speed"] for item in batch_data])
 
-        y = x2
+            y = x2
 
-        inputs = {"uncompressed_frame": x1, "compressed_frame": x2, "crf": x3, "preset_speed": x4}
-        return inputs, y
+            inputs = {"uncompressed_frame": x1, "compressed_frame": x2, "crf": x3, "preset_speed": x4}
+            return inputs, y
+        except IndexError:
+            LOGGER.error(f"Index {idx} out of bounds in VideoDataGenerator.")
+            raise
+        except Exception as e:
+            LOGGER.error(f"Unexpected error in VideoDataGenerator: {e}")
+            raise
 
 class VideoCompressionModel(tf.keras.Model):
     def __init__(self):
         super(VideoCompressionModel, self).__init__()
+        LOGGER.debug("Initializing VideoCompressionModel.")
 
         # Inputs
         self.crf_input = tf.keras.layers.InputLayer(name='crf', input_shape=(1,))
@@ -68,14 +80,19 @@ class VideoCompressionModel(tf.keras.Model):
         ])
 
     def model_summary(self):
-        x1 = tf.keras.Input(shape=(None, None, NUM_CHANNELS), name='uncompressed_frame')
-        x2 = tf.keras.Input(shape=(None, None, NUM_CHANNELS), name='compressed_frame')
-        x3 = tf.keras.Input(shape=(1,), name='crf')
-        x4 = tf.keras.Input(shape=(1,), name='preset_speed')
-        return tf.keras.Model(inputs=[x1, x2, x3, x4], outputs=self.call({'uncompressed_frame': x1, 'compressed_frame': x2, 'crf': x3, 'preset_speed': x4})).summary()
+        try:
+            LOGGER.info("Generating model summary.")
+            x1 = tf.keras.Input(shape=(None, None, NUM_CHANNELS), name='uncompressed_frame')
+            x2 = tf.keras.Input(shape=(None, None, NUM_CHANNELS), name='compressed_frame')
+            x3 = tf.keras.Input(shape=(1,), name='crf')
+            x4 = tf.keras.Input(shape=(1,), name='preset_speed')
+            return tf.keras.Model(inputs=[x1, x2, x3, x4], outputs=self.call({'uncompressed_frame': x1, 'compressed_frame': x2, 'crf': x3, 'preset_speed': x4})).summary()
+        except Exception as e:
+            LOGGER.error(f"Unexpected error during model summary generation: {e}")
+            raise
 
     def call(self, inputs):
+        LOGGER.trace("Calling VideoCompressionModel.")
         uncompressed_frame, compressed_frame, crf, preset_speed = inputs['uncompressed_frame'], inputs['compressed_frame'], inputs['crf'], inputs['preset_speed']
 
         # Convert frames to float32
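A hedged end-to-end sketch tying VideoDataGenerator to the sample loader from train_model.py in this same commit; batch shapes depend on the source video resolution:

    from train_model import load_video_samples
    from video_compression_model import VideoDataGenerator

    samples = load_video_samples("test_data/validation.json", samples=8)
    gen = VideoDataGenerator(samples, batch_size=4)

    inputs, y = gen[0]  # dict of the four named input arrays; y equals inputs["compressed_frame"]
    print(inputs["uncompressed_frame"].shape)  # (4, H, W, 3) for an H x W source video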