Update
parent 9cecaeb9d6
commit 4d29fffba1

4 changed files with 126 additions and 67 deletions
@@ -23,6 +23,7 @@ def psnr(y_true, y_pred):
     #LOGGER.info(f"[psnr function] y_true: {y_true.shape}, y_pred: {y_pred.shape}")
     max_pixel = 1.0
     mse = K.mean(K.square(y_pred - y_true))
+    mse = tf.cast(mse, tf.float32) # Cast mse to tf.float32
     return 20.0 * K.log(max_pixel / K.sqrt(mse)) / K.log(10.0)


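Why the new cast is needed: globalVars.py in this commit sets DATATYPE = tf.float16 and the model file enables the mixed_float16 policy, so y_pred (and therefore mse) comes out as float16 while K.log(10.0) below is float32, and the final division would likely fail with a dtype mismatch. A minimal sketch of the fixed computation on dummy half-precision tensors (values are illustrative):

    import tensorflow as tf
    from tensorflow.keras import backend as K

    y_true = tf.zeros((1, 8, 8, 3), dtype=tf.float16)
    y_pred = tf.fill((1, 8, 8, 3), tf.constant(0.5, dtype=tf.float16))

    mse = K.mean(K.square(y_pred - y_true))  # float16 under mixed precision
    mse = tf.cast(mse, tf.float32)           # align with the float32 log/sqrt maths below
    psnr_db = 20.0 * K.log(1.0 / K.sqrt(mse)) / K.log(10.0)
    print(float(psnr_db))                    # roughly 6 dB for this dummy pair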
@@ -37,6 +38,14 @@ def combined(y_true, y_pred):
 def combined_loss(y_true, y_pred):
     return -combined(y_true, y_pred) # The goal is to maximize the combined value
 
+# Option 1: Weight more towards PSNR
+def combined_loss_weighted_psnr(y_true, y_pred):
+    return -0.7 * psnr(y_true, y_pred) - 0.3 * ssim(y_true, y_pred)
+
+# Option 2: Weight more towards SSIM
+def combined_loss_weighted_ssim(y_true, y_pred):
+    return -0.3 * psnr(y_true, y_pred) - 0.7 * ssim(y_true, y_pred)
+
 
 def detect_noise(image, threshold=15):
     # Convert to grayscale if it's a color image
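To compare the two new weightings, a quick sanity check on dummy tensors can help (a sketch, not part of the commit; the noise level is arbitrary). Note that psnr is measured in tens of dB while ssim is bounded by 1.0, so the PSNR term dominates both variants:

    import tensorflow as tf
    from featureExtraction import combined_loss_weighted_psnr, combined_loss_weighted_ssim

    y_true = tf.random.uniform((1, 360, 640, 3))
    y_pred = tf.clip_by_value(y_true + tf.random.normal(y_true.shape, stddev=0.05), 0.0, 1.0)

    # Both losses are negated so that minimising them maximises quality.
    print(float(combined_loss_weighted_psnr(y_true, y_pred)))
    print(float(combined_loss_weighted_ssim(y_true, y_pred)))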
@@ -67,6 +76,8 @@ def preprocess_frame(frame, resize=True, scale=True):
     if resize and frame.shape[:2] != (HEIGHT, WIDTH):
         frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_LINEAR)
 
+    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
     if scale:
         # Scale frame to [0, 1]
         frame = frame / 255.0
@@ -1,9 +1,13 @@
 # gobalVars.py
 
 import json
+import os
+
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
+
+import tensorflow as tf
 import log
 import platform
-import os
 
 LOGGER = log.Logger(level="TRACE", logfile="training.log", reset_logfile=True)
 
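The import reshuffle matters because TF_CPP_MIN_LOG_LEVEL is read when TensorFlow's native backend first loads, i.e. on the first import tensorflow; setting it afterwards has no effect. The intended pattern, in short:

    import os
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'  # must be set before TensorFlow is first imported

    import tensorflow as tf  # C++ INFO messages are now suppressed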
@@ -13,6 +17,7 @@ NUM_COLOUR_CHANNELS = 3
 WIDTH = 640
 HEIGHT = 360
 MAX_FRAMES = 0
+DATATYPE = tf.float16
 
 def clear_screen():
     system_name = platform.system()
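The new DATATYPE flag is what the other files in this commit key off: the model module switches Keras to the mixed_float16 policy when it sees tf.float16, and the dataset builds its TensorSpecs with this dtype. A condensed sketch of that downstream usage (not part of this file):

    import tensorflow as tf
    from globalVars import DATATYPE, NUM_COLOUR_CHANNELS

    if DATATYPE == tf.float16:
        tf.keras.mixed_precision.set_global_policy('mixed_float16')

    image_spec = tf.TensorSpec(shape=(None, None, NUM_COLOUR_CHANNELS), dtype=DATATYPE)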
@@ -16,7 +16,7 @@ import signal
 
 import numpy as np
 
-from featureExtraction import combined, combined_loss, psnr, ssim
+from featureExtraction import combined, combined_loss, combined_loss_weighted_psnr, psnr, ssim
 
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
 
@@ -49,7 +49,7 @@ DECAY_RATE = 0.9
 MODEL_SAVE_FILE = "models/model.tf"
 MODEL_CHECKPOINT_DIR = "checkpoints"
 EARLY_STOP = 10
-RANDOM_SEED = 4576
+RANDOM_SEED = 3545
 MODEL = None
 LOG_DIR = './logs'
 
@@ -66,27 +66,17 @@ class ImageLoggingCallback(Callback):
         return np.stack(converted, axis=0)
 
     def on_epoch_end(self, epoch, logs=None):
-        random_idx = np.random.randint(0, MAX_FRAMES - 1)
+        # Get the first batch from the validation dataset
+        validation_data = next(iter(self.validation_dataset.take(1)))
 
-        validation_data = None
-        dataset_size = 0 # to keep track of the dataset size
+        # Extract the inputs from the batch_input_images dictionary
+        actual_images = validation_data[0]['image']
+        batch_gt_labels = validation_data[1]
 
-        # Loop through the dataset until the chosen index
-        for i, data in enumerate(self.validation_dataset):
-            if i == random_idx:
-                validation_data = data
-                break
-            dataset_size += 1
-
-        if validation_data is None:
-            print(f"Random index exceeds validation dataset size: {dataset_size}. Using last available data.")
-            validation_data = data # assigning the last data seen in the loop to validation_data
-
-        batch_input_images, batch_gt_labels = validation_data
-
-        batch_input_images = np.clip(batch_input_images[:, :, :, :3] * 255.0, 0, 255).astype(np.uint8)
+        actual_images = np.clip(actual_images[:, :, :, :3] * 255.0, 0, 255).astype(np.uint8)
         batch_gt_labels = np.clip(batch_gt_labels * 255.0, 0, 255).astype(np.uint8)
 
+        # Providing all three inputs to the model for prediction
         reconstructed_frame = MODEL.predict(validation_data[0])
         reconstructed_frame = np.clip(reconstructed_frame * 255.0, 0, 255).astype(np.uint8)
 
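Context for the hunk above: with the reworked frame_generator later in this commit, each dataset element is a ({'image', 'CRF', 'Speed'} dict, target frame) pair, so validation_data[0] is that dict and passing it straight to MODEL.predict supplies all three model inputs. Roughly, assuming the 640x360 frame size from globalVars.py:

    inputs, targets = next(iter(self.validation_dataset.take(1)))
    # inputs['image'] -> (batch, 360, 640, 3)
    # inputs['CRF']   -> (batch,)
    # inputs['Speed'] -> (batch,)
    # targets         -> (batch, 360, 640, 3)
    reconstructed = MODEL.predict(inputs)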
@@ -94,13 +84,9 @@ class ImageLoggingCallback(Callback):
         reconstructed_path = os.path.join(self.log_dir, f"epoch_{epoch}.png")
         cv2.imwrite(reconstructed_path, reconstructed_frame[0]) # Saving only the first image as an example
 
-        batch_input_images = self.convert_images(batch_input_images)
-        batch_gt_labels = self.convert_images(batch_gt_labels)
-        reconstructed_frame = self.convert_images(reconstructed_frame)
-
         # Log images to TensorBoard
         with self.writer.as_default():
-            tf.summary.image("Input Images", batch_input_images, step=epoch, max_outputs=1)
+            tf.summary.image("Input Images", actual_images, step=epoch, max_outputs=1)
             tf.summary.image("Ground Truth Labels", batch_gt_labels, step=epoch, max_outputs=1)
             tf.summary.image("Reconstructed Frame", reconstructed_frame, step=epoch, max_outputs=3)
             self.writer.flush()
@@ -196,7 +182,7 @@ def main():
 
     # Set optimizer and compile the model
     optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
-    MODEL.compile(loss=combined_loss, optimizer=optimizer, metrics=[psnr, ssim, combined])
+    MODEL.compile(loss=combined_loss_weighted_psnr, optimizer=optimizer, metrics=[psnr, ssim, combined])
 
     # Define checkpoints and early stopping
     checkpoint_callback = ModelCheckpoint(
@@ -7,27 +7,22 @@ import numpy as np
 import tensorflow as tf
 from tensorflow.keras import layers
 from featureExtraction import preprocess_frame, scale_crf, scale_speed_preset
-from globalVars import LOGGER, NUM_COLOUR_CHANNELS, PRESET_SPEED_CATEGORIES
+from globalVars import DATATYPE, LOGGER, NUM_COLOUR_CHANNELS, PRESET_SPEED_CATEGORIES
 
+if DATATYPE == tf.float16:
+    from tensorflow.keras.mixed_precision import Policy
+
+    policy = Policy('mixed_float16')
+    tf.keras.mixed_precision.set_global_policy(policy)
 
-#from tensorflow.keras.mixed_precision import Policy
 
-#policy = Policy('mixed_float16')
-#tf.keras.mixed_precision.set_global_policy(policy)
+def is_black(frame, threshold=10):
+    """Check if a frame is mostly black."""
+    return np.mean(frame) < threshold
 
-def combine_batch(frame, crf, speed, include_controls=True, resize=True):
-    processed_frame = preprocess_frame(frame, resize)
-    height, width, _ = processed_frame.shape
-
-    combined = [processed_frame]
-
-    if include_controls:
-        crf_array = np.full((height, width, 1), crf)
-        speed_array = np.full((height, width, 1), speed)
-        combined.extend([crf_array, speed_array])
-
-    return np.concatenate(combined, axis=-1)
 
+def combine_batch(frame, resize=True):
+    return preprocess_frame(frame, resize)
 
 
 def frame_generator(videos, max_frames=None):
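Two quick checks of the additions above (sketches under the assumption that DATATYPE is tf.float16, as globalVars.py now sets):

    import numpy as np
    import tensorflow as tf

    # After importing this module, the global Keras policy should reflect the guard:
    print(tf.keras.mixed_precision.global_policy().name)  # expected: 'mixed_float16'

    # is_black() flags frames whose mean intensity falls below the threshold:
    print(is_black(np.zeros((360, 640, 3), dtype=np.uint8)))  # True, so such frames get skipped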
@@ -44,13 +39,17 @@ def frame_generator(videos, max_frames=None):
         if not ret_compressed or not ret_uncompressed:
             break
 
+        # Skip black frames
+        if is_black(compressed_frame) or is_black(uncompressed_frame):
+            continue
+
         CRF = scale_crf(video["crf"])
         SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(video["preset_speed"]))
 
-        validation = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
-        training = combine_batch(uncompressed_frame, 10, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
+        validation_image = combine_batch(compressed_frame)
+        training_image = combine_batch(uncompressed_frame)
 
-        yield training, validation
+        yield ({'image': training_image, 'CRF': CRF, 'Speed': SPEED}, validation_image)
 
         frame_count += 1
         if max_frames is not None and frame_count >= max_frames:
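To make the new element structure concrete, roughly what one generator item now looks like (videos is the usual list of clip descriptions with "crf" and "preset_speed" keys; shapes assume the 640x360 frames from globalVars.py):

    inputs, target = next(frame_generator(videos, max_frames=1))
    # inputs == {'image': uncompressed frame, (360, 640, 3) floats in [0, 1],
    #            'CRF': scaled CRF value, 'Speed': scaled preset-speed value}
    # target == compressed frame, (360, 640, 3) floats in [0, 1]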
@@ -60,62 +59,120 @@ def frame_generator(videos, max_frames=None):
     cap_uncompressed.release()
 
 
 
 def create_dataset(videos, batch_size, max_frames=None):
     # Determine the output signature by processing a single video to obtain its shape
     video_generator_instance = frame_generator(videos, max_frames)
     sample_uncompressed, sample_compressed = next(video_generator_instance)
 
     output_signature = (
-        tf.TensorSpec(shape=tf.shape(sample_uncompressed), dtype=tf.float32),
-        tf.TensorSpec(shape=tf.shape(sample_compressed), dtype=tf.float32)
+        {
+            'image': tf.TensorSpec(shape=tf.shape(sample_uncompressed['image']), dtype=DATATYPE),
+            'CRF': tf.TensorSpec(shape=(), dtype=DATATYPE),
+            'Speed': tf.TensorSpec(shape=(), dtype=DATATYPE),
+        },
+        tf.TensorSpec(shape=tf.shape(sample_compressed), dtype=DATATYPE)
     )
 
     dataset = tf.data.Dataset.from_generator(
-        lambda: frame_generator(videos, max_frames), # Include max_frames argument through lambda
+        lambda: frame_generator(videos, max_frames),
        output_signature=output_signature
     )
 
-    dataset = dataset.batch(batch_size).shuffle(20).prefetch(1) #.prefetch(tf.data.experimental.AUTOTUNE)
+    dataset = dataset.batch(batch_size).shuffle(20).prefetch(1)
 
     return dataset
 
 
+class SeparableTranspose2D(layers.Layer):
+    def __init__(self, filters, kernel_size, strides=(1, 1), padding='same', **kwargs):
+        super(SeparableTranspose2D, self).__init__(**kwargs)
+        self.filters = filters
+        self.kernel_size = kernel_size
+        self.strides = strides
+        self.padding = padding
+
+        # Use UpSampling2D for resizing
+        self.upsample = layers.UpSampling2D(size=strides)
+
+        # Depthwise convolution
+        self.depthwise_conv = layers.DepthwiseConv2D(kernel_size=kernel_size, padding=padding)
+
+        # Pointwise convolution
+        self.pointwise_conv = layers.Conv2D(filters, kernel_size=(1, 1), padding=padding)
+
+    def call(self, inputs):
+        x = self.upsample(inputs)
+        x = self.depthwise_conv(x)
+        x = self.pointwise_conv(x)
+        return x
+
+
 class VideoCompressionModel(tf.keras.Model):
     def __init__(self):
         super(VideoCompressionModel, self).__init__()
-        input_shape = (None, None, NUM_COLOUR_CHANNELS + 2)
+        input_shape = (None, None, NUM_COLOUR_CHANNELS)
 
         # Encoder part of the model
         self.encoder = tf.keras.Sequential([
             layers.InputLayer(input_shape=input_shape),
-            layers.Conv2D(32, (3, 3), padding='same'),
+            layers.SeparableConv2D(64, (3, 3), padding='same'),
             layers.LeakyReLU(),
             layers.MaxPooling2D((2, 2), padding='same'),
-            layers.Dropout(0.4),
-            layers.SeparableConv2D(16, (3, 3), padding='same'),
+            layers.SeparableConv2D(128, (3, 3), padding='same'),
             layers.LeakyReLU(),
             layers.MaxPooling2D((2, 2), padding='same'),
-            layers.Dropout(0.4),
         ])
 
-        # Decoder part of the model using Transposed Convolutions for upsampling
+        # Fully connected layers for processing CRF and Speed
+        self.dense_crf_speed = tf.keras.Sequential([
+            layers.Dense(64, activation='relu'),
+            layers.Dense(128, activation='relu'),
+        ])
+
+        # Decoder part of the model
         self.decoder = tf.keras.Sequential([
-            layers.Conv2DTranspose(16, (3, 3), padding='same'),
+            SeparableTranspose2D(128, (3, 3), padding='same'),
             layers.LeakyReLU(),
-            layers.Dropout(0.4),
-            layers.Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same'),
+            SeparableTranspose2D(64, (3, 3), padding='same'),
             layers.LeakyReLU(),
-            layers.Dropout(0.4),
+            layers.UpSampling2D((2, 2)),
             layers.UpSampling2D((2, 2)),
             layers.Conv2D(NUM_COLOUR_CHANNELS, (3, 3), padding='same', activation='sigmoid')
         ])
 
     def call(self, inputs):
-        #print(f"Input: {inputs.shape}")
-        encoded = self.encoder(inputs)
-        #print(f"encoded: {encoded.shape}")
-        decoded = self.decoder(encoded)
-        #print(f"decoded: {decoded.shape}")
+        # Extract the image, CRF, and Speed values from the inputs dictionary
+        image = inputs['image']
+        crf = inputs['CRF']
+        speed = inputs['Speed']
+
+        # CRF and Speed are 1D tensors with shape [batch_size]
+        # Concatenate them to create a [batch_size, 2] tensor
+        crf_speed_vector = tf.concat([tf.expand_dims(crf, -1), tf.expand_dims(speed, -1)], axis=-1)
+
+        # Process the combined crf_speed_vector through your dense layers
+        # This will produce a tensor with shape [batch_size, 128]
+        crf_speed_features = self.dense_crf_speed(crf_speed_vector)
+
+        # Reshape the tensor to match spatial dimensions
+        # New shape: [batch_size, 1, 1, 128]
+        crf_speed_features = tf.reshape(crf_speed_features, [-1, 1, 1, 128])
+
+        # Tile the tensor to match spatial dimensions of encoded tensor
+        # Tiled shape: [batch_size, 90, 160, 128]
+        crf_speed_features = tf.tile(crf_speed_features, [1, 90, 160, 1])
+
+        # Pass the image through the encoder
+        encoded = self.encoder(image)
+
+        # Concatenate the encoded tensor with the crf_speed_features tensor
+        combined_features = tf.concat([encoded, crf_speed_features], axis=-1)
+
+        # Pass the combined features through the decoder
+        decoded = self.decoder(combined_features)
+
         return decoded
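Finally, a minimal forward-pass sketch of the reworked model (a check, not part of the commit). The hard-coded tf.tile target of 90x160 assumes 360x640 input frames reduced by the encoder's two 2x2 poolings, which matches HEIGHT and WIDTH in globalVars.py, so the dummy batch below follows that assumption:

    import tensorflow as tf
    # VideoCompressionModel imported from the module patched above

    model = VideoCompressionModel()
    dummy = {
        'image': tf.zeros((2, 360, 640, 3)),
        'CRF':   tf.constant([0.5, 0.5]),
        'Speed': tf.constant([0.2, 0.2]),
    }
    out = model(dummy)  # encoder -> concat with tiled CRF/Speed features -> decoder
    print(out.shape)    # expected: (2, 360, 640, 3)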