Update
This commit is contained in:
parent
9cecaeb9d6
commit
4d29fffba1
4 changed files with 126 additions and 67 deletions
|
@ -7,28 +7,23 @@ import numpy as np
|
|||
import tensorflow as tf
|
||||
from tensorflow.keras import layers
|
||||
from featureExtraction import preprocess_frame, scale_crf, scale_speed_preset
|
||||
from globalVars import LOGGER, NUM_COLOUR_CHANNELS, PRESET_SPEED_CATEGORIES
|
||||
from globalVars import DATATYPE, LOGGER, NUM_COLOUR_CHANNELS, PRESET_SPEED_CATEGORIES
|
||||
|
||||
if DATATYPE == tf.float16:
|
||||
from tensorflow.keras.mixed_precision import Policy
|
||||
|
||||
#from tensorflow.keras.mixed_precision import Policy
|
||||
|
||||
#policy = Policy('mixed_float16')
|
||||
#tf.keras.mixed_precision.set_global_policy(policy)
|
||||
|
||||
def combine_batch(frame, crf, speed, include_controls=True, resize=True):
|
||||
processed_frame = preprocess_frame(frame, resize)
|
||||
height, width, _ = processed_frame.shape
|
||||
policy = Policy('mixed_float16')
|
||||
tf.keras.mixed_precision.set_global_policy(policy)
|
||||
|
||||
combined = [processed_frame]
|
||||
|
||||
if include_controls:
|
||||
crf_array = np.full((height, width, 1), crf)
|
||||
speed_array = np.full((height, width, 1), speed)
|
||||
combined.extend([crf_array, speed_array])
|
||||
|
||||
return np.concatenate(combined, axis=-1)
|
||||
def is_black(frame, threshold=10):
|
||||
"""Check if a frame is mostly black."""
|
||||
return np.mean(frame) < threshold
|
||||
|
||||
|
||||
def combine_batch(frame, resize=True):
|
||||
return preprocess_frame(frame, resize)
|
||||
|
||||
|
||||
def frame_generator(videos, max_frames=None):
|
||||
base_dir = "test_data/validation/"
|
||||
|
@ -44,13 +39,17 @@ def frame_generator(videos, max_frames=None):
|
|||
if not ret_compressed or not ret_uncompressed:
|
||||
break
|
||||
|
||||
# Skip black frames
|
||||
if is_black(compressed_frame) or is_black(uncompressed_frame):
|
||||
continue
|
||||
|
||||
CRF = scale_crf(video["crf"])
|
||||
SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(video["preset_speed"]))
|
||||
|
||||
validation = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
|
||||
training = combine_batch(uncompressed_frame, 10, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
|
||||
validation_image = combine_batch(compressed_frame)
|
||||
training_image = combine_batch(uncompressed_frame)
|
||||
|
||||
yield training, validation
|
||||
yield ({'image': training_image, 'CRF': CRF, 'Speed': SPEED}, validation_image)
|
||||
|
||||
frame_count += 1
|
||||
if max_frames is not None and frame_count >= max_frames:
|
||||
|
@ -60,62 +59,120 @@ def frame_generator(videos, max_frames=None):
|
|||
cap_uncompressed.release()
|
||||
|
||||
|
||||
|
||||
def create_dataset(videos, batch_size, max_frames=None):
|
||||
# Determine the output signature by processing a single video to obtain its shape
|
||||
video_generator_instance = frame_generator(videos, max_frames)
|
||||
sample_uncompressed, sample_compressed = next(video_generator_instance)
|
||||
|
||||
output_signature = (
|
||||
tf.TensorSpec(shape=tf.shape(sample_uncompressed), dtype=tf.float32),
|
||||
tf.TensorSpec(shape=tf.shape(sample_compressed), dtype=tf.float32)
|
||||
{
|
||||
'image': tf.TensorSpec(shape=tf.shape(sample_uncompressed['image']), dtype=DATATYPE),
|
||||
'CRF': tf.TensorSpec(shape=(), dtype=DATATYPE),
|
||||
'Speed': tf.TensorSpec(shape=(), dtype=DATATYPE),
|
||||
},
|
||||
tf.TensorSpec(shape=tf.shape(sample_compressed), dtype=DATATYPE)
|
||||
)
|
||||
|
||||
dataset = tf.data.Dataset.from_generator(
|
||||
lambda: frame_generator(videos, max_frames), # Include max_frames argument through lambda
|
||||
lambda: frame_generator(videos, max_frames),
|
||||
output_signature=output_signature
|
||||
)
|
||||
|
||||
dataset = dataset.batch(batch_size).shuffle(20).prefetch(1) #.prefetch(tf.data.experimental.AUTOTUNE)
|
||||
dataset = dataset.batch(batch_size).shuffle(20).prefetch(1)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
|
||||
class SeparableTranspose2D(layers.Layer):
|
||||
def __init__(self, filters, kernel_size, strides=(1, 1), padding='same', **kwargs):
|
||||
super(SeparableTranspose2D, self).__init__(**kwargs)
|
||||
self.filters = filters
|
||||
self.kernel_size = kernel_size
|
||||
self.strides = strides
|
||||
self.padding = padding
|
||||
|
||||
# Use UpSampling2D for resizing
|
||||
self.upsample = layers.UpSampling2D(size=strides)
|
||||
|
||||
# Depthwise convolution
|
||||
self.depthwise_conv = layers.DepthwiseConv2D(kernel_size=kernel_size, padding=padding)
|
||||
|
||||
# Pointwise convolution
|
||||
self.pointwise_conv = layers.Conv2D(filters, kernel_size=(1, 1), padding=padding)
|
||||
|
||||
def call(self, inputs):
|
||||
x = self.upsample(inputs)
|
||||
x = self.depthwise_conv(x)
|
||||
x = self.pointwise_conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class VideoCompressionModel(tf.keras.Model):
|
||||
def __init__(self):
|
||||
super(VideoCompressionModel, self).__init__()
|
||||
input_shape = (None, None, NUM_COLOUR_CHANNELS + 2)
|
||||
input_shape = (None, None, NUM_COLOUR_CHANNELS)
|
||||
|
||||
# Encoder part of the model
|
||||
self.encoder = tf.keras.Sequential([
|
||||
layers.InputLayer(input_shape=input_shape),
|
||||
layers.Conv2D(32, (3, 3), padding='same'),
|
||||
layers.SeparableConv2D(64, (3, 3), padding='same'),
|
||||
layers.LeakyReLU(),
|
||||
layers.MaxPooling2D((2, 2), padding='same'),
|
||||
layers.Dropout(0.4),
|
||||
layers.SeparableConv2D(16, (3, 3), padding='same'),
|
||||
layers.SeparableConv2D(128, (3, 3), padding='same'),
|
||||
layers.LeakyReLU(),
|
||||
layers.MaxPooling2D((2, 2), padding='same'),
|
||||
layers.Dropout(0.4),
|
||||
])
|
||||
|
||||
# Decoder part of the model using Transposed Convolutions for upsampling
|
||||
# Fully connected layers for processing CRF and Speed
|
||||
self.dense_crf_speed = tf.keras.Sequential([
|
||||
layers.Dense(64, activation='relu'),
|
||||
layers.Dense(128, activation='relu'),
|
||||
])
|
||||
|
||||
# Decoder part of the model
|
||||
self.decoder = tf.keras.Sequential([
|
||||
layers.Conv2DTranspose(16, (3, 3), padding='same'),
|
||||
SeparableTranspose2D(128, (3, 3), padding='same'),
|
||||
layers.LeakyReLU(),
|
||||
layers.Dropout(0.4),
|
||||
layers.Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same'),
|
||||
SeparableTranspose2D(64, (3, 3), padding='same'),
|
||||
layers.LeakyReLU(),
|
||||
layers.Dropout(0.4),
|
||||
layers.UpSampling2D((2, 2)),
|
||||
layers.UpSampling2D((2, 2)),
|
||||
layers.Conv2D(NUM_COLOUR_CHANNELS, (3, 3), padding='same', activation='sigmoid')
|
||||
])
|
||||
|
||||
|
||||
def call(self, inputs):
|
||||
#print(f"Input: {inputs.shape}")
|
||||
encoded = self.encoder(inputs)
|
||||
#print(f"encoded: {encoded.shape}")
|
||||
decoded = self.decoder(encoded)
|
||||
#print(f"decoded: {decoded.shape}")
|
||||
# Extract the image, CRF, and Speed values from the inputs dictionary
|
||||
image = inputs['image']
|
||||
crf = inputs['CRF']
|
||||
speed = inputs['Speed']
|
||||
|
||||
# CRF and Speed are 1D tensors with shape [batch_size]
|
||||
# Concatenate them to create a [batch_size, 2] tensor
|
||||
crf_speed_vector = tf.concat([tf.expand_dims(crf, -1), tf.expand_dims(speed, -1)], axis=-1)
|
||||
|
||||
# Process the combined crf_speed_vector through your dense layers
|
||||
# This will produce a tensor with shape [batch_size, 128]
|
||||
crf_speed_features = self.dense_crf_speed(crf_speed_vector)
|
||||
|
||||
# Reshape the tensor to match spatial dimensions
|
||||
# New shape: [batch_size, 1, 1, 128]
|
||||
crf_speed_features = tf.reshape(crf_speed_features, [-1, 1, 1, 128])
|
||||
|
||||
# Tile the tensor to match spatial dimensions of encoded tensor
|
||||
# Tiled shape: [batch_size, 90, 160, 128]
|
||||
crf_speed_features = tf.tile(crf_speed_features, [1, 90, 160, 1])
|
||||
|
||||
# Pass the image through the encoder
|
||||
encoded = self.encoder(image)
|
||||
|
||||
# Concatenate the encoded tensor with the crf_speed_features tensor
|
||||
combined_features = tf.concat([encoded, crf_speed_features], axis=-1)
|
||||
|
||||
# Pass the combined features through the decoder
|
||||
decoded = self.decoder(combined_features)
|
||||
|
||||
return decoded
|
||||
|
||||
|
||||
|
|
Reference in a new issue