working

This commit is contained in:
parent f4512bba99
commit db43239b3d

5 changed files with 311 additions and 197 deletions
DeepEncode.py (181)

@@ -1,90 +1,145 @@
# DeepEncode.py

import os
import argparse
import cv2
import numpy as np

from featureExtraction import combined, preprocess_frame, psnr, scale_crf, scale_speed_preset, ssim
from globalVars import PRESET_SPEED_CATEGORIES

# Set TensorFlow log level before any other imports
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

import tensorflow as tf
import numpy as np
import cv2
from featureExtraction import combined, combined_loss, psnr, scale_crf, scale_speed_preset, ssim
from globalVars import PRESET_SPEED_CATEGORIES, clear_screen
from video_compression_model import VideoCompressionModel, combine_batch

# Constants
COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
MAX_FRAMES = 200  # Limit the number of frames processed
CRF = 51
SPEED = PRESET_SPEED_CATEGORIES.index("ultrafast")

# Load the trained model
MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr, 'ssim': ssim, 'combined': combined})

# Load the uncompressed video
CRF = 10
SPEED = "ultrafast"
MODEL_PATH = 'models/model.tf'
UNCOMPRESSED_VIDEO_FILE = 'test_data/x264_crf-5_preset-veryslow.mkv'
DISPLAY_OUTPUT = False
CROP_DIMENSIONS = None


def parse_arguments():
    global COMPRESSED_VIDEO_FILE, MAX_FRAMES, CRF, SPEED, MODEL_PATH, UNCOMPRESSED_VIDEO_FILE, DISPLAY_OUTPUT, CROP_DIMENSIONS
    parser = argparse.ArgumentParser(description='Deep Encoding of Videos')
    parser.add_argument('-o', '--compressed_video_file', default=COMPRESSED_VIDEO_FILE, help='Path to the compressed video file')
    parser.add_argument('-m', '--max_frames', type=int, default=MAX_FRAMES, help='Maximum number of frames to process')
    parser.add_argument('-c', '--crf', type=int, default=CRF, help='CRF value for video compression')
    parser.add_argument('-s', '--speed', default=SPEED, choices=PRESET_SPEED_CATEGORIES, help='Video compression speed category')
    parser.add_argument('-p', '--model_path', default=MODEL_PATH, help='Path to the trained model')
    parser.add_argument('-i', '--uncompressed_video_file', default=UNCOMPRESSED_VIDEO_FILE, help='Path to the uncompressed video file')
    parser.add_argument('-d', '--display_output', action='store_true', default=DISPLAY_OUTPUT, help='Display real-time output to screen')
    parser.add_argument('--keep_black_bars', action='store_true', help='Keep black bars from the video', default=False)

    args = parser.parse_args()

    COMPRESSED_VIDEO_FILE = args.compressed_video_file
    MAX_FRAMES = args.max_frames
    CRF = args.crf
    SPEED = args.speed
    MODEL_PATH = args.model_path
    UNCOMPRESSED_VIDEO_FILE = args.uncompressed_video_file
    DISPLAY_OUTPUT = args.display_output

    if not args.keep_black_bars:
        CROP_DIMENSIONS = find_crop_dimensions(UNCOMPRESSED_VIDEO_FILE)

def crop_black_bars(frame):
    # Convert to grayscale for easier processing
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    # Threshold the image to make everything below a certain gray value black, and everything else white
    _, thresh = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)

    # Find the contours of the white regions
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Find the bounding box that contains all the contours
    x_min = y_min = float('inf')
    x_max = y_max = 0
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        x_min = min(x_min, x)
        y_min = min(y_min, y)
        x_max = max(x_max, x + w)
        y_max = max(y_max, y + h)

    return x_min, y_min, x_max, y_max

def find_crop_dimensions(video_file):
    cap = cv2.VideoCapture(video_file)
    while True:
        ret, frame = cap.read()
        if not ret:
            print("Error: Unable to find a non-black frame.")
            cap.release()
            exit()

        # Check if the frame is entirely black
        if np.any(frame > 0):
            x_min, y_min, x_max, y_max = crop_black_bars(frame)
            cap.release()
            return x_min, y_min, x_max, y_max
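
A quick way to sanity-check the cropping helpers is to run crop_black_bars on a synthetic letterboxed frame. This is only an illustrative sketch (the frame and values below are made up; it assumes the function above is in scope):

import numpy as np

frame = np.zeros((360, 640, 3), dtype=np.uint8)     # all-black canvas, HEIGHT x WIDTH
frame[60:300, :, :] = 128                            # grey picture area between two black bars
x_min, y_min, x_max, y_max = crop_black_bars(frame)
print(x_min, y_min, x_max, y_max)                    # expected roughly: 0 60 640 300
cropped = frame[y_min:y_max, x_min:x_max]            # same slicing main() applies later
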

def load_frame_from_video(video_file, frame_num):
    cap = cv2.VideoCapture(video_file)
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
    ret, frame = cap.read()
    if not ret:
        return None
    cap.release()

    return frame
    return frame if ret else None

def predict_frame(uncompressed_frame):

    #display_frame = np.clip(cv2.cvtColor(uncompressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
    #cv2.imshow("uncomp", uncompressed_frame)
    scaled_crf = scale_crf(CRF)
    scaled_speed = scale_speed_preset(SPEED)

def predict_frame(uncompressed_frame, model, crf, speed):
    scaled_crf = scale_crf(crf)
    scaled_speed = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(speed))
    frame = combine_batch(uncompressed_frame, scaled_crf, scaled_speed, resize=False)

    compressed_frame = MODEL.predict([np.expand_dims(frame, axis=0)])[0]

    compressed_frame = compressed_frame[:, :, :3] # Keep only the first 3 channels (BGR)

    compressed_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)

    cv2.imshow("comp", compressed_frame)
    cv2.waitKey(1)

    return compressed_frame
    compressed_frame = model.predict([np.expand_dims(frame, axis=0)])[0]
    return np.clip(compressed_frame[:, :, :3] * 255.0, 0, 255).astype(np.uint8)

cap = cv2.VideoCapture(UNCOMPRESSED_VIDEO_FILE)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
height, width = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
fps = int(cap.get(cv2.CAP_PROP_FPS))
cap.release()

fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter(COMPRESSED_VIDEO_FILE, fourcc, fps, (width, height), True)

if not out.isOpened():
    print("Error: VideoWriter could not be opened.")
    exit()

if MAX_FRAMES != 0 and total_frames > MAX_FRAMES:
    total_frames = MAX_FRAMES

for i in range(total_frames):
    uncompressed_frame = load_frame_from_video(UNCOMPRESSED_VIDEO_FILE, frame_num=i)
    compressed_frame = predict_frame(uncompressed_frame)

def main():
    model = tf.keras.models.load_model(MODEL_PATH, custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr, 'ssim': ssim, 'combined': combined, 'combined_loss': combined_loss})
    cap = cv2.VideoCapture(UNCOMPRESSED_VIDEO_FILE)

    compressed_frame = cv2.resize(compressed_frame, (width, height))
    total_frames = min(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), MAX_FRAMES)
    height, width, fps = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FPS))

    #compressed_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)

    #compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_RGB2BGR)

    out.write(compressed_frame)
    cap.release()

    #if i % 10 == 0: # Print progress every 10 frames
    # print(f"Processed {i} / {total_frames} frames")
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(COMPRESSED_VIDEO_FILE, fourcc, fps, (width, height), True)

    out.release()
    print("Compression completed.")
    if not out.isOpened():
        print("Error: VideoWriter could not be opened.")
        exit()

    for i in range(total_frames):
        uncompressed_frame = load_frame_from_video(UNCOMPRESSED_VIDEO_FILE, frame_num=i)

        if CROP_DIMENSIONS:
            x_min, y_min, x_max, y_max = CROP_DIMENSIONS
            uncompressed_frame = uncompressed_frame[y_min:y_max, x_min:x_max]

        compressed_frame = predict_frame(uncompressed_frame, model, CRF, SPEED)
        compressed_frame = cv2.resize(compressed_frame, (width, height))

        out.write(compressed_frame)

        if DISPLAY_OUTPUT:
            cv2.imshow('Compressed Video', compressed_frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    out.release()
    print("Compression completed.")


if __name__ == '__main__':
    clear_screen()
    parse_arguments()
    main()

featureExtraction.py

@@ -9,51 +9,21 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
import tensorflow as tf
from tensorflow.keras import backend as K

from globalVars import HEIGHT, NUM_PRESET_SPEEDS, WIDTH
from globalVars import HEIGHT, LOGGER, NUM_PRESET_SPEEDS, WIDTH

def scale_crf(crf):
    return crf / 51


def scale_speed_preset(speed):
    return speed / NUM_PRESET_SPEEDS


def extract_edge_features(frame):
    """
    Extract edge features using Canny edge detection.

    Args:
    - frame (ndarray): Image frame.

    Returns:
    - ndarray: Edge feature map.
    """
    edges = cv2.Canny(frame, threshold1=100, threshold2=200)
    return edges.astype(np.float32) / 255.0

def extract_histogram_features(frame, bins=64):
    """
    Extract histogram features from a frame with 3 channels.

    Args:
    - frame (ndarray): Image frame with shape (height, width, 3).
    - bins (int): Number of bins for the histogram.

    Returns:
    - ndarray: Normalized histogram feature vector.
    """
    feature_vector = []
    for channel in range(3):
        histogram, _ = np.histogram(frame[:,:,channel].flatten(), bins=bins, range=[0, 255])
        normalized_histogram = histogram.astype(np.float32) / frame[:,:,channel].size
        feature_vector.extend(normalized_histogram)

    return np.array(feature_vector)


def psnr(y_true, y_pred):
    #LOGGER.info(f"[psnr function] y_true: {y_true.shape}, y_pred: {y_pred.shape}")
    max_pixel = 1.0
    return 10.0 * K.log((max_pixel ** 2) / (K.mean(K.square(y_pred - y_true)))) / K.log(10.0)
    mse = K.mean(K.square(y_pred - y_true))
    return 20.0 * K.log(max_pixel / K.sqrt(mse)) / K.log(10.0)
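
The rewritten psnr computes the same quantity as the line it replaces, since 10·log10(max²/MSE) = 20·log10(max/√MSE). A quick NumPy check with an illustrative MSE value:

import numpy as np

mse, max_pixel = 0.01, 1.0
old_form = 10.0 * np.log10(max_pixel ** 2 / mse)
new_form = 20.0 * np.log10(max_pixel / np.sqrt(mse))
assert np.isclose(old_form, new_form)   # both evaluate to 20.0 dB
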

def ssim(y_true, y_pred):

@@ -64,14 +34,41 @@ def combined(y_true, y_pred):
    return (psnr(y_true, y_pred) + ssim(y_true, y_pred)) / 2


def preprocess_frame(frame, resize=True):
    #Preprocesses a single frame, cropping it if needed
def combined_loss(y_true, y_pred):
    return -combined(y_true, y_pred)  # The goal is to maximize the combined value


def detect_noise(image, threshold=15):
    # Convert to grayscale if it's a color image
    if len(image.shape) == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Compute the standard deviation
    std_dev = np.std(image)

    # If the standard deviation is higher than a threshold, it might be considered noisy
    return std_dev > threshold


def frame_difference(frame1, frame2):
    # Ensure both frames are of the same size and type
    if frame1.shape != frame2.shape:
        raise ValueError("Frames must have the same dimensions and number of channels")

    # Calculate the absolute difference between the frames
    difference = cv2.absdiff(frame1, frame2)

    return difference


def preprocess_frame(frame, resize=True, scale=True):

    # Check frame dimensions and resize if necessary
    if resize and frame.shape[:2] != (HEIGHT, WIDTH):
        frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_LINEAR)

    if scale:
        # Scale frame to [0, 1]
        compressed_frame = frame / 255.0
        frame = frame / 255.0

    return compressed_frame
    return frame
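
A rough usage sketch of preprocess_frame, assuming the test clip referenced in DeepEncode.py is available: the function resizes to the globalVars WIDTH x HEIGHT and scales pixels into [0, 1].

import cv2

cap = cv2.VideoCapture('test_data/x264_crf-5_preset-veryslow.mkv')
ret, raw = cap.read()
cap.release()
if ret:
    prepared = preprocess_frame(raw)                       # resized to (WIDTH, HEIGHT), scaled to [0, 1]
    print(prepared.shape, prepared.min(), prepared.max())
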
globalVars.py

@@ -1,6 +1,9 @@
# globalVars.py

import json
import log
import platform
import os

LOGGER = log.Logger(level="TRACE", logfile="training.log", reset_logfile=True)

@@ -9,4 +12,35 @@ NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
NUM_COLOUR_CHANNELS = 3
WIDTH = 640
HEIGHT = 360
MAX_FRAMES = 0

def clear_screen():
    system_name = platform.system()
    if system_name == "Windows":
        os.system('cls')
    else:
        os.system('clear')

def load_video_metadata(list_path):
    """
    Load video metadata from a JSON file.

    Args:
    - json_path (str): Path to the JSON file containing video metadata.

    Returns:
    - list: List of dictionaries, each containing video details.
    """

    LOGGER.trace(f"Entering: load_video_metadata({list_path})")
    try:
        with open(list_path, "r") as json_file:
            file = json.load(json_file)
            LOGGER.trace(f"load_video_metadata returning: {file}")
            return file
    except FileNotFoundError:
        LOGGER.error(f"Metadata file {list_path} not found.")
        raise
    except json.JSONDecodeError:
        LOGGER.error(f"Error decoding JSON from {list_path}.")
        raise
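
load_video_metadata simply parses a JSON list of per-video dictionaries. A hypothetical entry, with key names inferred from how frame_generator in video_compression_model.py consumes them (file names and values below are made up):

example_metadata = [
    {
        "original_video_file": "original/clip01.mkv",
        "compressed_video_file": "compressed/clip01_crf25_ultrafast.mkv",
        "crf": 25,
        "preset_speed": "ultrafast",
    },
]

videos = load_video_metadata("test_data/validation/validation.json")   # same call train_model.py makes
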
train_model.py (138)

@@ -7,17 +7,25 @@ TODO:
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import cv2
|
||||
import subprocess
|
||||
import signal
|
||||
|
||||
from featureExtraction import combined, psnr, ssim
|
||||
import numpy as np
|
||||
|
||||
from featureExtraction import combined, combined_loss, psnr, ssim
|
||||
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
||||
|
||||
import gc
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, Callback
|
||||
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, Callback, TensorBoard
|
||||
from tensorflow.keras import backend as K
|
||||
from tensorflow.summary import image as tf_image_summary
|
||||
|
||||
|
||||
gpus = tf.config.experimental.list_physical_devices('GPU')
|
||||
if gpus:
|
||||
|
@@ -30,7 +38,7 @@ if gpus:

from video_compression_model import VideoCompressionModel, create_dataset

from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER, clear_screen, load_video_metadata

# Constants
BATCH_SIZE = 25
@@ -41,50 +49,71 @@ DECAY_RATE = 0.9
MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints"
EARLY_STOP = 10
RANDOM_SEED = 4576
MODEL = None
LOG_DIR = './logs'


class ImageLoggingCallback(Callback):
    def __init__(self, validation_dataset, log_dir):
        super().__init__()
        self.validation_dataset = validation_dataset
        self.log_dir = log_dir
        self.writer = tf.summary.create_file_writer(self.log_dir)

    def convert_images(self, images):
        converted = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images]
        return np.stack(converted, axis=0)

    def on_epoch_end(self, epoch, logs=None):
        itter = iter(self.validation_dataset)
        random_idx = np.random.randint(0, BATCH_SIZE)

        # Loop through the dataset until the chosen index
        for i, data in enumerate(self.validation_dataset):
            if i == random_idx:
                validation_data = data
                break

        batch_input_images, batch_gt_labels = validation_data

        batch_input_images = np.clip(batch_input_images[:, :, :, :3] * 255.0, 0, 255).astype(np.uint8)
        batch_gt_labels = np.clip(batch_gt_labels * 255.0, 0, 255).astype(np.uint8)

        reconstructed_frame = MODEL.predict(validation_data[0])
        reconstructed_frame = np.clip(reconstructed_frame * 255.0, 0, 255).astype(np.uint8)

        batch_input_images = self.convert_images(batch_input_images)
        batch_gt_labels = self.convert_images(batch_gt_labels)
        reconstructed_frame = self.convert_images(reconstructed_frame)

        # Log images to TensorBoard
        with self.writer.as_default():
            tf.summary.image("Input Images", batch_input_images, step=epoch, max_outputs=1)
            tf.summary.image("Ground Truth Labels", batch_gt_labels, step=epoch, max_outputs=1)
            tf.summary.image("Reconstructed Frame", reconstructed_frame, step=epoch, max_outputs=3)
            self.writer.flush()


class GarbageCollectorCallback(Callback):
    def on_epoch_end(self, epoch, logs=None):
        LOGGER.debug(f"GC")
        gc.collect()

def save_model(model):
def save_model():
    try:
        LOGGER.debug("Attempting to save the model.")
        os.makedirs("models", exist_ok=True)
        model.save(MODEL_SAVE_FILE, save_format='tf')
        MODEL.save(MODEL_SAVE_FILE, save_format='tf')
        LOGGER.info("Model saved successfully!")
    except Exception as e:
        LOGGER.error(f"Error saving the model: {e}")
        raise


def load_video_metadata(list_path):
    """
    Load video metadata from a JSON file.

    Args:
    - json_path (str): Path to the JSON file containing video metadata.

    Returns:
    - list: List of dictionaries, each containing video details.
    """

    LOGGER.trace(f"Entering: load_video_metadata({list_path})")
    try:
        with open(list_path, "r") as json_file:
            file = json.load(json_file)
            LOGGER.trace(f"load_video_metadata returning: {file}")
            return file
    except FileNotFoundError:
        LOGGER.error(f"Metadata file {list_path} not found.")
        raise
    except json.JSONDecodeError:
        LOGGER.error(f"Error decoding JSON from {list_path}.")
        raise


def main():
    global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES, DECAY_STEPS, DECAY_RATE
    global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES, DECAY_STEPS, DECAY_RATE, MODEL
    # Argument parsing
    parser = argparse.ArgumentParser(description="Train the video compression model.")
    parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
@@ -119,11 +148,10 @@ def main():
    # Load all video metadata
    all_videos = load_video_metadata("test_data/validation/validation.json")

    # Specify the random seed
    random_seed = 4576 # You can change this to any desired value
    tf.random.set_seed(RANDOM_SEED)

    # Shuffle the data using the specified seed
    random.shuffle(all_videos, random.seed(random_seed))
    random.shuffle(all_videos, random.seed(RANDOM_SEED))

    # Split into training and validation
    split_index = int(0.6 * len(all_videos))
@@ -136,12 +164,14 @@ def main():

    training_dataset = create_dataset(training_videos, BATCH_SIZE, MAX_FRAMES)
    validation_dataset = create_dataset(validation_videos, BATCH_SIZE, MAX_FRAMES)

    tensorboard_callback = TensorBoard(log_dir=LOG_DIR, histogram_freq=1, profile_batch=0, write_graph=True, update_freq='epoch')


    if args.continue_training:
        model = tf.keras.models.load_model(args.continue_training)
        MODEL = tf.keras.models.load_model(args.continue_training)
    else:
        model = VideoCompressionModel()
        MODEL = VideoCompressionModel()


    # Define exponential decay schedule
@@ -149,13 +179,13 @@ def main():
        initial_learning_rate=LEARNING_RATE,
        decay_steps=DECAY_STEPS,
        decay_rate=DECAY_RATE,
        staircase=False
        staircase=True
    )

    # Set optimizer and compile the model
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
    model.compile(loss='mse', optimizer=optimizer, metrics=[psnr, ssim, combined])
    MODEL.compile(loss=combined_loss, optimizer=optimizer, metrics=[psnr, ssim, combined])

    # Define checkpoints and early stopping
    checkpoint_callback = ModelCheckpoint(
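
Switching staircase to True makes the schedule decay in discrete steps rather than continuously; per the documented ExponentialDecay behaviour the learning rate follows initial_learning_rate * decay_rate ** (step // decay_steps). A small illustrative sketch (the LEARNING_RATE and DECAY_STEPS values below are made up; only DECAY_RATE = 0.9 appears in this diff):

LEARNING_RATE, DECAY_RATE, DECAY_STEPS = 1e-4, 0.9, 100
for step in (0, 50, 100, 250):
    print(step, LEARNING_RATE * DECAY_RATE ** (step // DECAY_STEPS))   # steps down once per DECAY_STEPS
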
@@ -167,6 +197,9 @@ def main():
    )
    early_stop = EarlyStopping(monitor='val_combined', mode='max', patience=EARLY_STOP, verbose=1, restore_best_weights=True)

    ImageSnapshots = ImageLoggingCallback(validation_dataset, LOG_DIR)


    # Custom garbage collection callback
    gc_callback = GarbageCollectorCallback()
@@ -174,20 +207,37 @@ def main():

    # Train the model
    LOGGER.info("Starting model training.")
    model.fit(
    MODEL.fit(
        training_dataset,
        epochs=EPOCHS,
        validation_data=validation_dataset,
        callbacks=[early_stop, checkpoint_callback, gc_callback]
        callbacks=[early_stop, checkpoint_callback, gc_callback, tensorboard_callback, ImageSnapshots]
    )
    LOGGER.info("Model training completed.")

    save_model(model)

    save_model()

def preMain():
    # Delete the existing logs directory and create a new one
    if os.path.exists(LOG_DIR):
        shutil.rmtree(LOG_DIR)
    os.makedirs(LOG_DIR, exist_ok=True)

    # Start TensorBoard as a subprocess
    LOGGER.info("Running tensorboard at: http://localhost:6006/")
    tensorboard_process = subprocess.Popen(['tensorboard', '--logdir', './logs'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, preexec_fn=os.setsid)
    return tensorboard_process

if __name__ == "__main__":
    clear_screen()

    tensorboard_process = preMain()

    try:
        main()
    except Exception as e:
        LOGGER.error(f"Unexpected error during training: {e}")
        raise
    finally:
        # Ensure TensorBoard process is terminated when main script ends
        os.killpg(os.getpgid(tensorboard_process.pid), signal.SIGTERM)

video_compression_model.py

@@ -5,6 +5,7 @@ import os
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from featureExtraction import preprocess_frame, scale_crf, scale_speed_preset
from globalVars import HEIGHT, LOGGER, NUM_COLOUR_CHANNELS, NUM_PRESET_SPEEDS, PRESET_SPEED_CATEGORIES, WIDTH
@@ -28,36 +29,6 @@ def combine_batch(frame, crf, speed, include_controls=True, resize=True):
    return np.concatenate(combined, axis=-1)


def process_video(video):
    base_dir = os.path.dirname("test_data/validation/validation.json")

    cap_compressed = cv2.VideoCapture(os.path.join(base_dir, video["compressed_video_file"]))
    cap_uncompressed = cv2.VideoCapture(os.path.join(base_dir, video["original_video_file"]))

    compressed_frames = []
    uncompressed_frames = []

    while True:
        ret_compressed, compressed_frame = cap_compressed.read()
        ret_uncompressed, uncompressed_frame = cap_uncompressed.read()

        if not ret_compressed or not ret_uncompressed:
            break

        CRF = scale_crf(video["crf"])
        SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(video["preset_speed"]))

        compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
        uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))

        compressed_frames.append(compressed_combined)
        uncompressed_frames.append(uncompressed_combined)

    cap_compressed.release()
    cap_uncompressed.release()

    return uncompressed_frames, compressed_frames


def frame_generator(videos, max_frames=None):
    base_dir = "test_data/validation/"
@@ -76,10 +47,10 @@ def frame_generator(videos, max_frames=None):
        CRF = scale_crf(video["crf"])
        SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(video["preset_speed"]))

        compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
        uncompressed_combined = combine_batch(uncompressed_frame, 10, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
        validation = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
        training = combine_batch(uncompressed_frame, 10, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))

        yield uncompressed_combined, compressed_combined
        yield training, validation

        frame_count += 1
        if max_frames is not None and frame_count >= max_frames:
@@ -104,7 +75,7 @@ def create_dataset(videos, batch_size, max_frames=None):
        output_signature=output_signature
    )

    dataset = dataset.shuffle(100).batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(batch_size).shuffle(20).prefetch(1) #.prefetch(tf.data.experimental.AUTOTUNE)

    return dataset
@@ -113,29 +84,36 @@ def create_dataset(videos, batch_size, max_frames=None):
class VideoCompressionModel(tf.keras.Model):
    def __init__(self):
        super(VideoCompressionModel, self).__init__()
        LOGGER.debug("Initializing VideoCompressionModel.")

        # Input shape (includes channels for CRF and SPEED_PRESET)
        input_shape_with_histogram = (None, None, NUM_COLOUR_CHANNELS + 2)

        input_shape = (None, None, NUM_COLOUR_CHANNELS + 2)

        # Encoder part of the model
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=input_shape_with_histogram),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.MaxPooling2D((2, 2), padding='same'),
            tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.MaxPooling2D((2, 2), padding='same')
            layers.InputLayer(input_shape=input_shape),
            layers.Conv2D(64, (3, 3), padding='same'),
            #layers.BatchNormalization(),
            layers.LeakyReLU(),
            layers.MaxPooling2D((2, 2), padding='same'),
            layers.SeparableConv2D(32, (3, 3), padding='same'), # Using Separable Convolution
            #layers.BatchNormalization(),
            layers.LeakyReLU(),
            layers.MaxPooling2D((2, 2), padding='same')
        ])

        # Decoder part of the model
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.Conv2DTranspose(32, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.UpSampling2D((2, 2)),
            tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.UpSampling2D((2, 2)),
            tf.keras.layers.Conv2DTranspose(NUM_COLOUR_CHANNELS, (3, 3), activation='sigmoid', padding='same')
            layers.Conv2DTranspose(32, (3, 3), padding='same'),
            #layers.BatchNormalization(),
            layers.LeakyReLU(),
            layers.Conv2DTranspose(64, (3, 3), dilation_rate=2, padding='same'), # Using Dilated Convolution
            #layers.BatchNormalization(),
            layers.LeakyReLU(),
            # Use Sub-Pixel Convolutional Layer
            layers.Conv2DTranspose(NUM_COLOUR_CHANNELS * 16, (3, 3), padding='same'), # 16 times the number of color channels
            layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=4)) # Sub-Pixel Convolutional Layer with block_size=4
        ])

    def call(self, inputs):
        return self.decoder(self.encoder(inputs))
        encoded = self.encoder(inputs)
        return self.decoder(encoded)
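
The reworked decoder no longer relies on UpSampling2D: the encoder's two 2x2 max-pools shrink each spatial dimension by a factor of 4, and the final Lambda layer recovers that factor with a sub-pixel (pixel-shuffle) rearrangement. A minimal shape check for tf.nn.depth_to_space with block_size=4 (sizes chosen to match WIDTH = 640 and HEIGHT = 360 from globalVars.py):

import tensorflow as tf

x = tf.zeros((1, 90, 160, 3 * 16))          # (batch, H/4, W/4, NUM_COLOUR_CHANNELS * 16)
y = tf.nn.depth_to_space(x, block_size=4)   # trades channel blocks for a 4x larger spatial grid
print(y.shape)                              # (1, 360, 640, 3)
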