Jordon Brooks 2023-08-23 00:54:06 +01:00
parent f4512bba99
commit db43239b3d
5 changed files with 311 additions and 197 deletions

DeepEncode.py (View file)

@@ -1,90 +1,145 @@
# DeepEncode.py

import os
import argparse

# Set TensorFlow log level before any other imports
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

import tensorflow as tf
import numpy as np
import cv2

from featureExtraction import combined, combined_loss, psnr, scale_crf, scale_speed_preset, ssim
from globalVars import PRESET_SPEED_CATEGORIES, clear_screen
from video_compression_model import VideoCompressionModel, combine_batch
# Constants
COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
MAX_FRAMES = 200  # Limit the number of frames processed
CRF = 10
SPEED = "ultrafast"
MODEL_PATH = 'models/model.tf'
UNCOMPRESSED_VIDEO_FILE = 'test_data/x264_crf-5_preset-veryslow.mkv'
DISPLAY_OUTPUT = False
CROP_DIMENSIONS = None
def parse_arguments():
    global COMPRESSED_VIDEO_FILE, MAX_FRAMES, CRF, SPEED, MODEL_PATH, UNCOMPRESSED_VIDEO_FILE, DISPLAY_OUTPUT, CROP_DIMENSIONS

    parser = argparse.ArgumentParser(description='Deep Encoding of Videos')
    parser.add_argument('-o', '--compressed_video_file', default=COMPRESSED_VIDEO_FILE, help='Path to the compressed video file')
    parser.add_argument('-m', '--max_frames', type=int, default=MAX_FRAMES, help='Maximum number of frames to process')
    parser.add_argument('-c', '--crf', type=int, default=CRF, help='CRF value for video compression')
    parser.add_argument('-s', '--speed', default=SPEED, choices=PRESET_SPEED_CATEGORIES, help='Video compression speed category')
    parser.add_argument('-p', '--model_path', default=MODEL_PATH, help='Path to the trained model')
    parser.add_argument('-i', '--uncompressed_video_file', default=UNCOMPRESSED_VIDEO_FILE, help='Path to the uncompressed video file')
    parser.add_argument('-d', '--display_output', action='store_true', default=DISPLAY_OUTPUT, help='Display real-time output to screen')
    parser.add_argument('--keep_black_bars', action='store_true', default=False, help='Keep black bars from the video')

    args = parser.parse_args()
    COMPRESSED_VIDEO_FILE = args.compressed_video_file
    MAX_FRAMES = args.max_frames
    CRF = args.crf
    SPEED = args.speed
    MODEL_PATH = args.model_path
    UNCOMPRESSED_VIDEO_FILE = args.uncompressed_video_file
    DISPLAY_OUTPUT = args.display_output

    if not args.keep_black_bars:
        CROP_DIMENSIONS = find_crop_dimensions(UNCOMPRESSED_VIDEO_FILE)
def crop_black_bars(frame):
    # Convert to grayscale for easier processing. Callers must pass a frame with
    # at least one non-black pixel; otherwise no contours are found and the
    # bounding box below stays at its initial values.
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    # Threshold the image: everything at or below a grey value of 1 becomes black,
    # everything else white
    _, thresh = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)

    # Find the contours of the white regions
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Find the bounding box that contains all the contours
    x_min = y_min = float('inf')
    x_max = y_max = 0
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        x_min = min(x_min, x)
        y_min = min(y_min, y)
        x_max = max(x_max, x + w)
        y_max = max(y_max, y + h)

    return x_min, y_min, x_max, y_max
def find_crop_dimensions(video_file):
    cap = cv2.VideoCapture(video_file)

    while True:
        ret, frame = cap.read()
        if not ret:
            print("Error: Unable to find a non-black frame.")
            cap.release()
            exit()

        # Skip frames that are entirely black
        if np.any(frame > 0):
            x_min, y_min, x_max, y_max = crop_black_bars(frame)
            cap.release()
            return x_min, y_min, x_max, y_max
def load_frame_from_video(video_file, frame_num):
    cap = cv2.VideoCapture(video_file)
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
    ret, frame = cap.read()
    cap.release()
    return frame if ret else None
def predict_frame(uncompressed_frame, model, crf, speed):
    scaled_crf = scale_crf(crf)
    scaled_speed = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(speed))

    frame = combine_batch(uncompressed_frame, scaled_crf, scaled_speed, resize=False)
    compressed_frame = model.predict([np.expand_dims(frame, axis=0)])[0]

    # Keep only the first 3 channels (BGR) and convert back to 8-bit pixels
    return np.clip(compressed_frame[:, :, :3] * 255.0, 0, 255).astype(np.uint8)
def main():
    model = tf.keras.models.load_model(MODEL_PATH, custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr, 'ssim': ssim, 'combined': combined, 'combined_loss': combined_loss})

    cap = cv2.VideoCapture(UNCOMPRESSED_VIDEO_FILE)
    total_frames = min(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), MAX_FRAMES)
    height, width, fps = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FPS))
    cap.release()

    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(COMPRESSED_VIDEO_FILE, fourcc, fps, (width, height), True)

    if not out.isOpened():
        print("Error: VideoWriter could not be opened.")
        exit()

    for i in range(total_frames):
        uncompressed_frame = load_frame_from_video(UNCOMPRESSED_VIDEO_FILE, frame_num=i)

        if CROP_DIMENSIONS:
            x_min, y_min, x_max, y_max = CROP_DIMENSIONS
            uncompressed_frame = uncompressed_frame[y_min:y_max, x_min:x_max]

        compressed_frame = predict_frame(uncompressed_frame, model, CRF, SPEED)
        compressed_frame = cv2.resize(compressed_frame, (width, height))
        out.write(compressed_frame)

        if DISPLAY_OUTPUT:
            cv2.imshow('Compressed Video', compressed_frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    out.release()
    print("Compression completed.")
if __name__ == '__main__':
    clear_screen()
    parse_arguments()
    main()
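For a quick smoke test of the model on a single frame, without the VideoWriter loop, the pieces above can be driven directly. A minimal sketch, assuming the default model and test video paths exist locally:

# Sketch: run one frame through the model and save it as an image.
# Assumes models/model.tf and the default test video are present.
import cv2
import tensorflow as tf
from DeepEncode import load_frame_from_video, predict_frame, MODEL_PATH, UNCOMPRESSED_VIDEO_FILE
from featureExtraction import combined, combined_loss, psnr, ssim
from video_compression_model import VideoCompressionModel

model = tf.keras.models.load_model(MODEL_PATH, custom_objects={
    'VideoCompressionModel': VideoCompressionModel,
    'psnr': psnr, 'ssim': ssim, 'combined': combined, 'combined_loss': combined_loss})
frame = load_frame_from_video(UNCOMPRESSED_VIDEO_FILE, frame_num=0)
out = predict_frame(frame, model, crf=10, speed="ultrafast")
cv2.imwrite('single_frame_test.png', out)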

featureExtraction.py (View file)

@@ -9,51 +9,21 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

import tensorflow as tf
from tensorflow.keras import backend as K

from globalVars import HEIGHT, LOGGER, NUM_PRESET_SPEEDS, WIDTH

def scale_crf(crf):
    return crf / 51

def scale_speed_preset(speed):
    return speed / NUM_PRESET_SPEEDS

def psnr(y_true, y_pred):
    max_pixel = 1.0
    mse = K.mean(K.square(y_pred - y_true))
    return 20.0 * K.log(max_pixel / K.sqrt(mse)) / K.log(10.0)
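The rewritten PSNR is algebraically identical to the old 10 * log10(max_pixel**2 / mse) form, just expressed through the square root. A quick NumPy sanity check (a sketch, standing in for the Keras backend):

# Both PSNR forms agree: 10*log10(max^2/mse) == 20*log10(max/sqrt(mse))
import numpy as np
mse = 0.01
old = 10.0 * np.log10(1.0 ** 2 / mse)
new = 20.0 * np.log10(1.0 / np.sqrt(mse))
assert np.isclose(old, new)  # both are 20.0 dB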
def ssim(y_true, y_pred):
@@ -64,14 +34,41 @@ def combined(y_true, y_pred):
    return (psnr(y_true, y_pred) + ssim(y_true, y_pred)) / 2

def combined_loss(y_true, y_pred):
    return -combined(y_true, y_pred)  # The goal is to maximize the combined metric
def detect_noise(image, threshold=15):
    # Convert to grayscale if it's a colour image
    if len(image.shape) == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # A standard deviation above the threshold suggests a noisy frame
    std_dev = np.std(image)
    return std_dev > threshold

def frame_difference(frame1, frame2):
    # Ensure both frames are of the same size and type
    if frame1.shape != frame2.shape:
        raise ValueError("Frames must have the same dimensions and number of channels")

    # Calculate the absolute difference between the frames
    return cv2.absdiff(frame1, frame2)
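Neither helper is wired into the pipeline elsewhere in this commit; their intended use is presumably along these lines (a sketch on synthetic frames):

# Sketch: exercising detect_noise / frame_difference on made-up data.
import numpy as np
from featureExtraction import detect_noise, frame_difference

noisy = np.random.randint(0, 256, (360, 640, 3), dtype=np.uint8)
flat = np.full((360, 640, 3), 128, dtype=np.uint8)
print(detect_noise(noisy))            # True: high standard deviation
print(detect_noise(flat))             # False: zero standard deviation
diff = frame_difference(noisy, flat)  # per-pixel absolute difference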
def preprocess_frame(frame, resize=True, scale=True):
    # Check frame dimensions and resize if necessary
    if resize and frame.shape[:2] != (HEIGHT, WIDTH):
        frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_LINEAR)

    if scale:
        # Scale pixel values to [0, 1]
        frame = frame / 255.0

    return frame

globalVars.py (View file)

@@ -1,6 +1,9 @@
# globalVars.py

import json
import log
import platform
import os

LOGGER = log.Logger(level="TRACE", logfile="training.log", reset_logfile=True)
@@ -9,4 +12,35 @@ NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
NUM_COLOUR_CHANNELS = 3
WIDTH = 640
HEIGHT = 360
MAX_FRAMES = 0
def clear_screen():
    system_name = platform.system()
    if system_name == "Windows":
        os.system('cls')
    else:
        os.system('clear')
def load_video_metadata(list_path):
    """
    Load video metadata from a JSON file.

    Args:
    - list_path (str): Path to the JSON file containing video metadata.

    Returns:
    - list: List of dictionaries, each containing video details.
    """
    LOGGER.trace(f"Entering: load_video_metadata({list_path})")
    try:
        with open(list_path, "r") as json_file:
            file = json.load(json_file)
        LOGGER.trace(f"load_video_metadata returning: {file}")
        return file
    except FileNotFoundError:
        LOGGER.error(f"Metadata file {list_path} not found.")
        raise
    except json.JSONDecodeError:
        LOGGER.error(f"Error decoding JSON from {list_path}.")
        raise
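For reference, a metadata list of the shape this loader and the training pipeline consume might look as follows; the keys are inferred from the fields read in video_compression_model.py, and the file names are made up:

# Hypothetical validation.json contents (shown as a Python literal); the
# real schema may carry additional keys.
example_metadata = [
    {
        "original_video_file": "original_0.mkv",
        "compressed_video_file": "compressed_0.mkv",
        "crf": 25,
        "preset_speed": "veryslow",
    },
]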

Training script (View file)

@@ -7,17 +7,25 @@ TODO:
"""

import argparse
import json
import os
import random
import shutil
import cv2
import subprocess
import signal

import numpy as np

from featureExtraction import combined, combined_loss, psnr, ssim

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

import gc
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, Callback, TensorBoard
from tensorflow.keras import backend as K
from tensorflow.summary import image as tf_image_summary
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
@@ -30,7 +38,7 @@ if gpus:

from video_compression_model import VideoCompressionModel, create_dataset
from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER, clear_screen, load_video_metadata

# Constants
BATCH_SIZE = 25
@@ -41,50 +49,71 @@ DECAY_RATE = 0.9
MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints"
EARLY_STOP = 10
RANDOM_SEED = 4576
MODEL = None
LOG_DIR = './logs'
class ImageLoggingCallback(Callback):
    def __init__(self, validation_dataset, log_dir):
        super().__init__()
        self.validation_dataset = validation_dataset
        self.log_dir = log_dir
        self.writer = tf.summary.create_file_writer(self.log_dir)

    def convert_images(self, images):
        converted = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images]
        return np.stack(converted, axis=0)

    def on_epoch_end(self, epoch, logs=None):
        # Walk the validation dataset up to a randomly chosen batch index
        # (assumes the dataset yields at least BATCH_SIZE batches)
        random_idx = np.random.randint(0, BATCH_SIZE)
        validation_data = None
        for i, data in enumerate(self.validation_dataset):
            if i == random_idx:
                validation_data = data
                break

        batch_input_images, batch_gt_labels = validation_data
        batch_input_images = np.clip(batch_input_images[:, :, :, :3] * 255.0, 0, 255).astype(np.uint8)
        batch_gt_labels = np.clip(batch_gt_labels * 255.0, 0, 255).astype(np.uint8)

        reconstructed_frame = MODEL.predict(validation_data[0])
        reconstructed_frame = np.clip(reconstructed_frame * 255.0, 0, 255).astype(np.uint8)

        batch_input_images = self.convert_images(batch_input_images)
        batch_gt_labels = self.convert_images(batch_gt_labels)
        reconstructed_frame = self.convert_images(reconstructed_frame)

        # Log images to TensorBoard
        with self.writer.as_default():
            tf.summary.image("Input Images", batch_input_images, step=epoch, max_outputs=1)
            tf.summary.image("Ground Truth Labels", batch_gt_labels, step=epoch, max_outputs=1)
            tf.summary.image("Reconstructed Frame", reconstructed_frame, step=epoch, max_outputs=3)
        self.writer.flush()
class GarbageCollectorCallback(Callback):
    def on_epoch_end(self, epoch, logs=None):
        LOGGER.debug("GC")
        gc.collect()
def save_model():
    try:
        LOGGER.debug("Attempting to save the model.")
        os.makedirs("models", exist_ok=True)
        MODEL.save(MODEL_SAVE_FILE, save_format='tf')
        LOGGER.info("Model saved successfully!")
    except Exception as e:
        LOGGER.error(f"Error saving the model: {e}")
        raise
def main():
    global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE, MAX_FRAMES, DECAY_STEPS, DECAY_RATE, MODEL

    # Argument parsing
    parser = argparse.ArgumentParser(description="Train the video compression model.")
    parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
@@ -119,11 +148,10 @@ def main():

    # Load all video metadata
    all_videos = load_video_metadata("test_data/validation/validation.json")

    tf.random.set_seed(RANDOM_SEED)

    # Shuffle the data deterministically; seed the RNG first rather than
    # passing random.seed()'s None return value into shuffle()
    random.seed(RANDOM_SEED)
    random.shuffle(all_videos)

    # Split into training and validation
    split_index = int(0.6 * len(all_videos))
@@ -136,12 +164,14 @@ def main():

    training_dataset = create_dataset(training_videos, BATCH_SIZE, MAX_FRAMES)
    validation_dataset = create_dataset(validation_videos, BATCH_SIZE, MAX_FRAMES)

    tensorboard_callback = TensorBoard(log_dir=LOG_DIR, histogram_freq=1, profile_batch=0, write_graph=True, update_freq='epoch')

    if args.continue_training:
        MODEL = tf.keras.models.load_model(args.continue_training)
    else:
        MODEL = VideoCompressionModel()
    # Define exponential decay schedule
@@ -149,13 +179,13 @@ def main():
        initial_learning_rate=LEARNING_RATE,
        decay_steps=DECAY_STEPS,
        decay_rate=DECAY_RATE,
        staircase=True
    )

    # Set optimizer and compile the model
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
    MODEL.compile(loss=combined_loss, optimizer=optimizer, metrics=[psnr, ssim, combined])
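With staircase=True the schedule now decays in discrete steps, lr(step) = LEARNING_RATE * DECAY_RATE ** floor(step / DECAY_STEPS), rather than continuously; and with the compile change, training now directly maximizes the averaged PSNR/SSIM metric via combined_loss instead of optimizing plain MSE.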
    # Define checkpoints and early stopping
    checkpoint_callback = ModelCheckpoint(
@@ -167,6 +197,9 @@ def main():
    )

    early_stop = EarlyStopping(monitor='val_combined', mode='max', patience=EARLY_STOP, verbose=1, restore_best_weights=True)

    image_snapshots = ImageLoggingCallback(validation_dataset, LOG_DIR)

    # Custom garbage collection callback
    gc_callback = GarbageCollectorCallback()
@@ -174,20 +207,37 @@ def main():

    # Train the model
    LOGGER.info("Starting model training.")
    MODEL.fit(
        training_dataset,
        epochs=EPOCHS,
        validation_data=validation_dataset,
        callbacks=[early_stop, checkpoint_callback, gc_callback, tensorboard_callback, image_snapshots]
    )
    LOGGER.info("Model training completed.")

    save_model()
def preMain():
    # Delete the existing logs directory and create a new one
    if os.path.exists(LOG_DIR):
        shutil.rmtree(LOG_DIR)
    os.makedirs(LOG_DIR, exist_ok=True)

    # Start TensorBoard as a subprocess, in its own process group so it can
    # be terminated cleanly later
    LOGGER.info("Running tensorboard at: http://localhost:6006/")
    tensorboard_process = subprocess.Popen(['tensorboard', '--logdir', LOG_DIR], stdout=subprocess.PIPE, stderr=subprocess.PIPE, preexec_fn=os.setsid)

    return tensorboard_process
if __name__ == "__main__":
clear_screen()
tensorboard_process = preMain()
try:
main()
except Exception as e:
LOGGER.error(f"Unexpected error during training: {e}")
raise
raise
finally:
# Ensure TensorBoard process is terminated when main script ends
os.killpg(os.getpgid(tensorboard_process.pid), signal.SIGTERM)
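One caveat: preexec_fn=os.setsid and os.killpg are POSIX-only, so preMain() as written will fail on Windows even though clear_screen() handles both platforms. A possible cross-platform variant (a sketch; the Windows branch is an assumption, not part of this commit):

# Sketch of a cross-platform TensorBoard launcher/terminator. The POSIX
# branch mirrors the code above; Windows uses a new process group instead.
import os
import signal
import subprocess
import sys

def start_tensorboard(logdir='./logs'):
    kwargs = {'stdout': subprocess.PIPE, 'stderr': subprocess.PIPE}
    if sys.platform == 'win32':
        kwargs['creationflags'] = subprocess.CREATE_NEW_PROCESS_GROUP
    else:
        kwargs['preexec_fn'] = os.setsid
    return subprocess.Popen(['tensorboard', '--logdir', logdir], **kwargs)

def stop_tensorboard(proc):
    if sys.platform == 'win32':
        proc.send_signal(signal.CTRL_BREAK_EVENT)  # requires the new process group
    else:
        os.killpg(os.getpgid(proc.pid), signal.SIGTERM)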

video_compression_model.py (View file)

@@ -5,6 +5,7 @@ import os
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

from featureExtraction import preprocess_frame, scale_crf, scale_speed_preset
from globalVars import HEIGHT, LOGGER, NUM_COLOUR_CHANNELS, NUM_PRESET_SPEEDS, PRESET_SPEED_CATEGORIES, WIDTH
@@ -28,36 +29,6 @@ def combine_batch(frame, crf, speed, include_controls=True, resize=True):
    return np.concatenate(combined, axis=-1)

def frame_generator(videos, max_frames=None):
    base_dir = "test_data/validation/"
@@ -76,10 +47,10 @@ def frame_generator(videos, max_frames=None):
            CRF = scale_crf(video["crf"])
            SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(video["preset_speed"]))

            validation = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
            training = combine_batch(uncompressed_frame, 10, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))

            yield training, validation

            frame_count += 1
            if max_frames is not None and frame_count >= max_frames:
@@ -104,7 +75,7 @@ def create_dataset(videos, batch_size, max_frames=None):
        output_signature=output_signature
    )

    dataset = dataset.batch(batch_size).shuffle(20).prefetch(1)  # .prefetch(tf.data.experimental.AUTOTUNE)

    return dataset
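Note the reordering here: the old shuffle(100).batch(...) shuffled individual samples before batching, while the new batch(...).shuffle(20) shuffles whole batches, which is cheaper but mixes the data more coarsely. A toy illustration:

# Toy illustration of shuffle-then-batch vs batch-then-shuffle.
import tensorflow as tf

ds = tf.data.Dataset.range(8)
per_element = ds.shuffle(8).batch(4)  # elements get mixed across batches
per_batch = ds.batch(4).shuffle(2)    # batches [0..3] and [4..7] stay intact
for b in per_batch:
    print(b.numpy())  # e.g. [4 5 6 7] then [0 1 2 3]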
@@ -113,29 +84,36 @@ def create_dataset(videos, batch_size, max_frames=None):

class VideoCompressionModel(tf.keras.Model):
    def __init__(self):
        super(VideoCompressionModel, self).__init__()
        LOGGER.debug("Initializing VideoCompressionModel.")

        # Input shape (includes channels for CRF and SPEED_PRESET)
        input_shape = (None, None, NUM_COLOUR_CHANNELS + 2)

        # Encoder part of the model
        self.encoder = tf.keras.Sequential([
            layers.InputLayer(input_shape=input_shape),
            layers.Conv2D(64, (3, 3), padding='same'),
            #layers.BatchNormalization(),
            layers.LeakyReLU(),
            layers.MaxPooling2D((2, 2), padding='same'),
            layers.SeparableConv2D(32, (3, 3), padding='same'),  # Separable convolution
            #layers.BatchNormalization(),
            layers.LeakyReLU(),
            layers.MaxPooling2D((2, 2), padding='same')
        ])

        # Decoder part of the model
        self.decoder = tf.keras.Sequential([
            layers.Conv2DTranspose(32, (3, 3), padding='same'),
            #layers.BatchNormalization(),
            layers.LeakyReLU(),
            layers.Conv2DTranspose(64, (3, 3), dilation_rate=2, padding='same'),  # Dilated convolution
            #layers.BatchNormalization(),
            layers.LeakyReLU(),
            # Sub-pixel convolution: expand channels, then rearrange depth into space
            layers.Conv2DTranspose(NUM_COLOUR_CHANNELS * 16, (3, 3), padding='same'),  # 16 = block_size**2 channels per output channel
            layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=4))  # upsamples H and W by 4
        ])

    def call(self, inputs):
        encoded = self.encoder(inputs)
        return self.decoder(encoded)
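The sub-pixel output undoes the encoder's two 2x2 max-pools in one step: the encoder shrinks height and width by 4x, and depth_to_space with block_size=4 multiplies them by 4 again, consuming 16 channels per output channel. A quick shape sanity check (a sketch, assuming this module imports cleanly):

# Shape check: input and output spatial dims should match, with
# NUM_COLOUR_CHANNELS output channels (48 / 16 = 3 after depth_to_space).
import numpy as np
from video_compression_model import VideoCompressionModel
from globalVars import NUM_COLOUR_CHANNELS

model = VideoCompressionModel()
dummy = np.zeros((1, 64, 64, NUM_COLOUR_CHANNELS + 2), dtype=np.float32)
out = model(dummy)
print(out.shape)  # expected: (1, 64, 64, NUM_COLOUR_CHANNELS)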