updated
This commit is contained in:
parent
ed5eb91578
commit
9ae5921e2b
3 changed files with 153 additions and 210 deletions
|
@ -4,5 +4,23 @@
|
|||
"original_video_file": "Scene2_x264_crf-5_preset-veryslow.mkv",
|
||||
"crf": 51,
|
||||
"preset_speed": "veryslow"
|
||||
},
|
||||
{
|
||||
"compressed_video_file": "Scene3_x264_crf-51_preset-ultrafast.mkv",
|
||||
"original_video_file": "Scene3.mkv",
|
||||
"crf": 51,
|
||||
"preset_speed": "ultrafast"
|
||||
},
|
||||
{
|
||||
"compressed_video_file": "Scene4_x264_crf-51_preset-veryslow.mkv",
|
||||
"original_video_file": "Scene4.mkv",
|
||||
"crf": 51,
|
||||
"preset_speed": "veryslow"
|
||||
},
|
||||
{
|
||||
"compressed_video_file": "Scene5_x264_crf-51_preset-veryslow.mkv",
|
||||
"original_video_file": "Scene5.mkv",
|
||||
"crf": 51,
|
||||
"preset_speed": "veryslow"
|
||||
}
|
||||
]
|
||||
|
|
182
train_model.py
182
train_model.py
|
@ -1,29 +1,77 @@
|
|||
# train_model.py
|
||||
|
||||
import math
|
||||
import os
|
||||
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from train_model_V2 import VideoCompressionModel
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
||||
import tensorflow as tf
|
||||
from video_compression_model import WIDTH, HEIGHT, VideoCompressionModel, PRESET_SPEED_CATEGORIES, VideoDataGenerator
|
||||
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
|
||||
|
||||
from global_train import LOGGER
|
||||
|
||||
# Constants
|
||||
BATCH_SIZE = 4
|
||||
EPOCHS = 100
|
||||
BATCH_SIZE = 16
|
||||
EPOCHS = 5
|
||||
LEARNING_RATE = 0.01
|
||||
TRAIN_SAMPLES = 100
|
||||
MODEL_SAVE_FILE = "models/model.tf"
|
||||
MODEL_CHECKPOINT_DIR = "checkpoints"
|
||||
EARLY_STOP = 10
|
||||
|
||||
NUM_CHANNELS = 3
|
||||
WIDTH = 640
|
||||
HEIGHT = 360
|
||||
|
||||
def save_model(model):
|
||||
try:
|
||||
LOGGER.debug("Attempting to save the model.")
|
||||
os.makedirs("models", exist_ok=True)
|
||||
model.save(MODEL_SAVE_FILE, save_format='tf')
|
||||
LOGGER.info("Model saved successfully!")
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Error saving the model: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def extract_edge_features(frame):
|
||||
"""
|
||||
Extract edge features using Canny edge detection.
|
||||
|
||||
Args:
|
||||
- frame (ndarray): Image frame.
|
||||
|
||||
Returns:
|
||||
- ndarray: Edge feature map.
|
||||
"""
|
||||
edges = cv2.Canny(frame, threshold1=100, threshold2=200)
|
||||
return edges.astype(np.float32) / 255.0
|
||||
|
||||
def extract_histogram_features(frame, bins=64):
|
||||
"""
|
||||
Extract histogram features from a frame.
|
||||
|
||||
Args:
|
||||
- frame (ndarray): Image frame.
|
||||
- bins (int): Number of bins for the histogram.
|
||||
|
||||
Returns:
|
||||
- ndarray: Normalized histogram feature vector.
|
||||
"""
|
||||
histogram, _ = np.histogram(frame.flatten(), bins=bins, range=[0, 255])
|
||||
return histogram.astype(np.float32) / frame.size
|
||||
|
||||
|
||||
def load_video_metadata(list_path):
|
||||
"""
|
||||
Load video metadata from a JSON file.
|
||||
|
||||
Args:
|
||||
- json_path (str): Path to the JSON file containing video metadata.
|
||||
|
||||
Returns:
|
||||
- list: List of dictionaries, each containing video details.
|
||||
"""
|
||||
|
||||
LOGGER.trace(f"Entering: load_video_metadata({list_path})")
|
||||
try:
|
||||
with open(list_path, "r") as json_file:
|
||||
|
@ -36,42 +84,47 @@ def load_video_metadata(list_path):
|
|||
except json.JSONDecodeError:
|
||||
LOGGER.error(f"Error decoding JSON from {list_path}.")
|
||||
raise
|
||||
|
||||
def data_generator(videos, batch_size):
|
||||
while True:
|
||||
for video_details in videos:
|
||||
video_path = os.path.join(os.path.dirname("test_data/validation/validation.json"), video_details["compressed_video_file"])
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
|
||||
def load_video_samples(list_path, samples=TRAIN_SAMPLES):
|
||||
LOGGER.trace(f"Entering: load_video_samples({list_path}, {samples})")
|
||||
details_list = load_video_metadata(list_path)
|
||||
all_samples = []
|
||||
num_videos = len(details_list)
|
||||
frames_per_video = math.ceil(samples / num_videos)
|
||||
LOGGER.info(f"Loading {frames_per_video} frames from {num_videos} videos")
|
||||
for video_details in details_list:
|
||||
compressed_video_file = video_details["compressed_video_file"]
|
||||
original_video_file = video_details["original_video_file"]
|
||||
crf = video_details['crf'] / 51
|
||||
preset_speed = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed'])
|
||||
video_details['preset_speed'] = preset_speed
|
||||
feature_batch = []
|
||||
compressed_frame_batch = []
|
||||
|
||||
# Store video details without loading frames
|
||||
all_samples.extend({
|
||||
"frames_per_video": frames_per_video,
|
||||
"crf": crf,
|
||||
"preset_speed": preset_speed,
|
||||
"compressed_video_file": os.path.join(os.path.dirname(list_path), compressed_video_file),
|
||||
"original_video_file": os.path.join(os.path.dirname(list_path), original_video_file)
|
||||
} for _ in range(frames_per_video))
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
return all_samples
|
||||
# Check frame dimensions and resize if necessary
|
||||
if frame.shape[:2] != (HEIGHT, WIDTH):
|
||||
frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
# Extract features
|
||||
edge_feature = extract_edge_features(frame)
|
||||
histogram_feature = extract_histogram_features(frame)
|
||||
histogram_feature_image = np.full((HEIGHT, WIDTH), histogram_feature.mean()) # Convert histogram feature to image-like shape
|
||||
combined_feature = np.stack([edge_feature, histogram_feature_image], axis=-1)
|
||||
|
||||
compressed_frame = frame / 255.0 # Assuming the frame is uint8, scale to [0, 1]
|
||||
|
||||
feature_batch.append(combined_feature)
|
||||
compressed_frame_batch.append(compressed_frame)
|
||||
|
||||
if len(feature_batch) == batch_size:
|
||||
yield (np.array(feature_batch), np.array(compressed_frame_batch))
|
||||
feature_batch = []
|
||||
compressed_frame_batch = []
|
||||
|
||||
cap.release()
|
||||
|
||||
# If there are frames left that don't fill a whole batch, send them anyway
|
||||
if len(feature_batch) > 0:
|
||||
yield (np.array(feature_batch), np.array(compressed_frame_batch))
|
||||
|
||||
def save_model(model):
|
||||
try:
|
||||
LOGGER.debug("Attempting to save the model.")
|
||||
os.makedirs("models", exist_ok=True)
|
||||
model.save(MODEL_SAVE_FILE, save_format='tf')
|
||||
LOGGER.info("Model saved successfully!")
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Error saving the model: {e}")
|
||||
raise
|
||||
|
||||
def main():
|
||||
global BATCH_SIZE, EPOCHS, TRAIN_SAMPLES, LEARNING_RATE, MODEL_SAVE_FILE
|
||||
|
@ -100,25 +153,22 @@ def main():
|
|||
LOGGER.info(f"Continue training from: {MODEL_SAVE_FILE}")
|
||||
|
||||
LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")
|
||||
LOGGER.trace("Hello, World!")
|
||||
|
||||
# Load training and validation samples
|
||||
LOGGER.debug("Loading training and validation samples.")
|
||||
training_samples = load_video_samples("test_data/training/training.json", TRAIN_SAMPLES)
|
||||
validation_samples = load_video_samples("test_data/validation/validation.json", math.ceil(TRAIN_SAMPLES / 10))
|
||||
# Load all video metadata
|
||||
all_videos = load_video_metadata("test_data/validation/validation.json")
|
||||
|
||||
train_generator = VideoDataGenerator(training_samples, BATCH_SIZE)
|
||||
val_generator = VideoDataGenerator(validation_samples, BATCH_SIZE)
|
||||
# Split into training and validation
|
||||
split_index = int(0.8 * len(all_videos))
|
||||
training_videos = all_videos[:split_index]
|
||||
validation_videos = all_videos[split_index:]
|
||||
|
||||
# Load or initialize model
|
||||
if args.continue_training:
|
||||
model = tf.keras.models.load_model(args.continue_training)
|
||||
else:
|
||||
model = VideoCompressionModel()
|
||||
model = VideoCompressionModel()
|
||||
|
||||
# Set optimizer and compile the model
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
|
||||
model.compile(loss='mean_squared_error', optimizer=optimizer)
|
||||
|
||||
|
||||
# Define checkpoints and early stopping
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch-{epoch:02d}.tf"),
|
||||
|
@ -129,23 +179,31 @@ def main():
|
|||
)
|
||||
early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
|
||||
|
||||
# Calculate steps per epoch for training and validation
|
||||
average_frames_per_video = 2880 # Given 2 minutes @ 24 fps
|
||||
total_frames_train = average_frames_per_video * len(training_videos)
|
||||
total_frames_validation = average_frames_per_video * len(validation_videos)
|
||||
steps_per_epoch_train = total_frames_train // BATCH_SIZE
|
||||
steps_per_epoch_validation = total_frames_validation // BATCH_SIZE
|
||||
|
||||
# Train the model
|
||||
LOGGER.info("Starting model training.")
|
||||
model.fit(
|
||||
train_generator,
|
||||
steps_per_epoch=len(train_generator),
|
||||
epochs=EPOCHS,
|
||||
validation_data=val_generator,
|
||||
validation_steps=len(val_generator),
|
||||
data_generator(training_videos, BATCH_SIZE),
|
||||
epochs=EPOCHS,
|
||||
steps_per_epoch=steps_per_epoch_train,
|
||||
validation_data=data_generator(validation_videos, BATCH_SIZE), # Add validation data here
|
||||
validation_steps=steps_per_epoch_validation, # Add validation steps here
|
||||
callbacks=[early_stop, checkpoint_callback]
|
||||
)
|
||||
LOGGER.info("Model training completed.")
|
||||
|
||||
|
||||
save_model(model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Unexpected error during training: {e}")
|
||||
raise
|
||||
raise
|
|
@ -9,7 +9,7 @@ from global_train import LOGGER
|
|||
PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
|
||||
NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
|
||||
NUM_CHANNELS = 3
|
||||
WIDTH = 638
|
||||
WIDTH = 640
|
||||
HEIGHT = 360
|
||||
|
||||
#from tensorflow.keras.mixed_precision import Policy
|
||||
|
@ -19,164 +19,31 @@ HEIGHT = 360
|
|||
|
||||
|
||||
|
||||
def normalize(frame):
|
||||
"""
|
||||
Normalize pixel values of the frame to range [0, 1].
|
||||
|
||||
Args:
|
||||
- frame (ndarray): Image frame.
|
||||
|
||||
Returns:
|
||||
- ndarray: Normalized frame.
|
||||
"""
|
||||
LOGGER.trace(f"Normalizing frame")
|
||||
return frame / 255.0
|
||||
|
||||
class VideoDataGenerator(tf.keras.utils.Sequence):
|
||||
def __init__(self, video_details_list, batch_size):
|
||||
LOGGER.debug("Initializing VideoDataGenerator with batch size: {}".format(batch_size))
|
||||
self.video_details_list = video_details_list
|
||||
self.batch_size = batch_size
|
||||
|
||||
def __len__(self):
|
||||
return int(np.ceil(len(self.video_details_list) / float(self.batch_size)))
|
||||
|
||||
def __getitem__(self, idx):
|
||||
start_idx = idx * self.batch_size
|
||||
end_idx = (idx + 1) * self.batch_size
|
||||
batch_data = self.video_details_list[start_idx:end_idx]
|
||||
|
||||
# Determine the number of videos and frames per video
|
||||
num_videos = len(batch_data)
|
||||
frames_per_video = batch_data[0]['frames_per_video'] # Assuming all videos have the same number of frames
|
||||
|
||||
# Pre-allocate arrays for the batch data
|
||||
x1 = np.empty((num_videos * frames_per_video, HEIGHT, WIDTH, NUM_CHANNELS))
|
||||
x2 = np.empty_like(x1)
|
||||
x3 = np.empty((num_videos * frames_per_video, 1))
|
||||
x4 = np.empty_like(x3)
|
||||
|
||||
# Iterate over the videos and frames, filling the pre-allocated arrays
|
||||
for i, item in enumerate(batch_data):
|
||||
compressed_video_file = item["compressed_video_file"]
|
||||
original_video_file = item["original_video_file"]
|
||||
crf = item["crf"]
|
||||
preset_speed = item["preset_speed"]
|
||||
|
||||
cap_compressed = cv2.VideoCapture(compressed_video_file)
|
||||
cap_original = cv2.VideoCapture(original_video_file)
|
||||
for j in range(frames_per_video):
|
||||
compressed_ret, compressed_frame = cap_compressed.read()
|
||||
original_ret, original_frame = cap_original.read()
|
||||
if not compressed_ret or not original_ret:
|
||||
continue
|
||||
|
||||
# Check frame dimensions and resize if necessary
|
||||
if original_frame.shape[:2] != (WIDTH, HEIGHT):
|
||||
LOGGER.info(f"Resizing video: {original_video_file}")
|
||||
original_frame = cv2.resize(original_frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA)
|
||||
if compressed_frame.shape[:2] != (WIDTH, HEIGHT):
|
||||
LOGGER.info(f"Resizing video: {compressed_video_file}")
|
||||
compressed_frame = cv2.resize(compressed_frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA)
|
||||
|
||||
original_frame = cv2.cvtColor(original_frame, cv2.COLOR_BGR2RGB)
|
||||
compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Store the processed frames and metadata directly in the pre-allocated arrays
|
||||
x1[i * frames_per_video + j] = normalize(original_frame)
|
||||
x2[i * frames_per_video + j] = normalize(compressed_frame)
|
||||
x3[i * frames_per_video + j] = crf
|
||||
x4[i * frames_per_video + j] = preset_speed
|
||||
|
||||
cap_original.release()
|
||||
cap_compressed.release()
|
||||
|
||||
y = x2
|
||||
inputs = {"uncompressed_frame": x1, "compressed_frame": x2, "crf": x3, "preset_speed": x4}
|
||||
return inputs, y
|
||||
|
||||
|
||||
|
||||
class VideoCompressionModel(tf.keras.Model):
|
||||
def __init__(self):
|
||||
super(VideoCompressionModel, self).__init__()
|
||||
LOGGER.debug("Initializing VideoCompressionModel.")
|
||||
|
||||
# Inputs
|
||||
self.crf_input = tf.keras.layers.InputLayer(name='crf', input_shape=(1,))
|
||||
self.preset_speed_input = tf.keras.layers.InputLayer(name='preset_speed', input_shape=(1,))
|
||||
self.uncompressed_frame_input = tf.keras.layers.InputLayer(name='uncompressed_frame', input_shape=(None, None, NUM_CHANNELS))
|
||||
self.compressed_frame_input = tf.keras.layers.InputLayer(name='compressed_frame', input_shape=(None, None, NUM_CHANNELS))
|
||||
|
||||
# Embedding for speed preset and FC layer for CRF and preset speed
|
||||
self.embedding = tf.keras.layers.Embedding(NUM_PRESET_SPEEDS, 16)
|
||||
self.fc = tf.keras.layers.Dense(32, activation='relu')
|
||||
# Add an additional channel for the histogram features
|
||||
input_shape_with_histogram = (HEIGHT, WIDTH, 2) # 1 channel for edges, 1 for histogram
|
||||
|
||||
# Encoder layers
|
||||
self.encoder = tf.keras.Sequential([
|
||||
tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(None, None, 2 * NUM_CHANNELS + 32)),
|
||||
tf.keras.layers.BatchNormalization(),
|
||||
tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
|
||||
tf.keras.layers.BatchNormalization(),
|
||||
tf.keras.layers.MaxPooling2D((2, 2)),
|
||||
tf.keras.layers.Dropout(0.3)
|
||||
tf.keras.layers.InputLayer(input_shape=input_shape_with_histogram),
|
||||
tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
|
||||
tf.keras.layers.MaxPooling2D((2, 2), padding='same'),
|
||||
tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
|
||||
tf.keras.layers.MaxPooling2D((2, 2), padding='same')
|
||||
])
|
||||
|
||||
# Decoder layers
|
||||
self.decoder = tf.keras.Sequential([
|
||||
tf.keras.layers.Conv2DTranspose(128, (3, 3), activation='relu', padding='same'),
|
||||
tf.keras.layers.BatchNormalization(),
|
||||
tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'),
|
||||
tf.keras.layers.BatchNormalization(),
|
||||
tf.keras.layers.Conv2DTranspose(32, (3, 3), activation='relu', padding='same'),
|
||||
tf.keras.layers.UpSampling2D((2, 2)),
|
||||
tf.keras.layers.Dropout(0.3),
|
||||
tf.keras.layers.Conv2D(NUM_CHANNELS, (3, 3), activation='sigmoid', padding='same') # Output layer for video frames
|
||||
tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'),
|
||||
tf.keras.layers.UpSampling2D((2, 2)),
|
||||
tf.keras.layers.Conv2DTranspose(1, (3, 3), activation='sigmoid', padding='same')
|
||||
])
|
||||
|
||||
def call(self, inputs):
|
||||
LOGGER.trace("Calling VideoCompressionModel.")
|
||||
uncompressed_frame, compressed_frame, crf, preset_speed = inputs['uncompressed_frame'], inputs['compressed_frame'], inputs['crf'], inputs['preset_speed']
|
||||
|
||||
# Convert frames to float32
|
||||
uncompressed_frame = tf.cast(uncompressed_frame, tf.float16)
|
||||
compressed_frame = tf.cast(compressed_frame, tf.float16)
|
||||
|
||||
# Embedding for preset speed
|
||||
preset_speed_embedded = self.embedding(preset_speed)
|
||||
preset_speed_embedded = tf.keras.layers.Flatten()(preset_speed_embedded)
|
||||
|
||||
# Reshaping CRF to match the shape of preset_speed_embedded
|
||||
crf_expanded = tf.keras.layers.Flatten()(tf.repeat(crf, 16, axis=-1))
|
||||
|
||||
|
||||
# Concatenating the CRF and preset speed information
|
||||
integrated_info = tf.keras.layers.Concatenate(axis=-1)([crf_expanded, preset_speed_embedded])
|
||||
integrated_info = self.fc(integrated_info)
|
||||
|
||||
# Integrate the CRF and preset speed information into the frames as additional channels (features)
|
||||
_, height, width, _ = uncompressed_frame.shape
|
||||
current_shape = tf.shape(inputs["uncompressed_frame"])
|
||||
|
||||
height = current_shape[1]
|
||||
width = current_shape[2]
|
||||
integrated_info_repeated = tf.tile(tf.reshape(integrated_info, [-1, 1, 1, 32]), [1, height, width, 1])
|
||||
|
||||
# Merge uncompressed and compressed frames
|
||||
frames_merged = tf.keras.layers.Concatenate(axis=-1)([uncompressed_frame, compressed_frame, integrated_info_repeated])
|
||||
|
||||
compressed_representation = self.encoder(frames_merged)
|
||||
reconstructed_frame = self.decoder(compressed_representation)
|
||||
|
||||
return reconstructed_frame
|
||||
|
||||
def model_summary(self):
|
||||
try:
|
||||
LOGGER.info("Generating model summary.")
|
||||
x1 = tf.keras.Input(shape=(None, None, NUM_CHANNELS), name='uncompressed_frame')
|
||||
x2 = tf.keras.Input(shape=(None, None, NUM_CHANNELS), name='compressed_frame')
|
||||
x3 = tf.keras.Input(shape=(1,), name='crf')
|
||||
x4 = tf.keras.Input(shape=(1,), name='preset_speed')
|
||||
return tf.keras.Model(inputs=[x1, x2, x3, x4], outputs=self.call({'uncompressed_frame': x1, 'compressed_frame': x2, 'crf': x3, 'preset_speed': x4})).summary()
|
||||
except Exception as e:
|
||||
LOGGER.error(f"Unexpected error during model summary generation: {e}")
|
||||
raise
|
||||
encoded = self.encoder(inputs)
|
||||
decoded = self.decoder(encoded)
|
||||
return decoded
|
||||
|
|
Reference in a new issue