Working GPU model

This commit is contained in:
Jordon Brooks 2023-07-30 11:49:19 +01:00
parent 5085c87300
commit dea59068fb
3 changed files with 190 additions and 108 deletions

View file

@ -1,61 +1,67 @@
# video_compression_model.py
import tensorflow as tf
PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
NUM_CHANNELS = 3 # Number of color channels in the video frames (RGB images have 3 channels)
#policy = tf.keras.mixed_precision.Policy('mixed_float16')
#tf.keras.mixed_precision.set_global_policy(policy)
NUM_CHANNELS = 3
class VideoCompressionModel(tf.keras.Model):
def __init__(self, NUM_CHANNELS=3, NUM_FRAMES=5, regularization_factor=1e-4):
def __init__(self):
super(VideoCompressionModel, self).__init__()
self.NUM_CHANNELS = NUM_CHANNELS
# Inputs
self.crf_input = tf.keras.layers.InputLayer(name='crf', input_shape=(1,))
self.preset_speed_input = tf.keras.layers.InputLayer(name='preset_speed', input_shape=(1,))
self.uncompressed_frame_input = tf.keras.layers.InputLayer(name='uncompressed_frame', input_shape=(None, None, NUM_CHANNELS))
self.compressed_frame_input = tf.keras.layers.InputLayer(name='compressed_frame', input_shape=(None, None, NUM_CHANNELS))
# Regularization
self.regularizer = tf.keras.regularizers.l2(regularization_factor)
# Embedding layer for preset_speed
self.preset_embedding = tf.keras.layers.Embedding(NUM_PRESET_SPEEDS, 16, embeddings_regularizer=self.regularizer)
# Embedding for speed preset and FC layer for CRF and preset speed
self.embedding = tf.keras.layers.Embedding(NUM_PRESET_SPEEDS, 16)
self.fc = tf.keras.layers.Dense(32, activation='relu')
# Encoder layers
self.encoder = tf.keras.Sequential([
tf.keras.layers.ZeroPadding2D(padding=((1, 1), (1, 1))), # Padding to preserve spatial dimensions
tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same', kernel_regularizer=self.regularizer),
tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(None, None, 2 * NUM_CHANNELS + 32)),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.MaxPooling2D((2, 2)),
# Add more encoder layers as needed
tf.keras.layers.Dropout(0.3)
])
# Decoder layers
self.decoder = tf.keras.Sequential([
tf.keras.layers.Conv2DTranspose(32, (3, 3), activation='relu', padding='same', kernel_regularizer=self.regularizer),
tf.keras.layers.Conv2DTranspose(128, (3, 3), activation='relu', padding='same'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.UpSampling2D((2, 2)),
# Add more decoder layers as needed
tf.keras.layers.Conv2D(NUM_CHANNELS, (3, 3), activation='sigmoid', padding='same', kernel_regularizer=self.regularizer), # Output layer for video frames
tf.keras.layers.Cropping2D(cropping=((1, 1), (1, 1))) # Adjust cropping to ensure dimensions match
tf.keras.layers.Dropout(0.3),
tf.keras.layers.Conv2D(NUM_CHANNELS, (3, 3), activation='sigmoid', padding='same') # Output layer for video frames
])
def call(self, inputs):
frame = inputs["frame"]
crf = tf.expand_dims(inputs["crf"], -1)
preset_speed = inputs["preset_speed"]
uncompressed_frame, compressed_frame, crf, preset_speed = inputs['uncompressed_frame'], inputs['compressed_frame'], inputs['crf'], inputs['preset_speed']
# Convert preset_speed to embeddings
preset_embedding = self.preset_embedding(preset_speed)
preset_embedding = tf.keras.layers.Flatten()(preset_embedding)
# Concatenate crf and preset_embedding to frames
frame_shape = tf.shape(frame)
repeated_crf = tf.tile(tf.reshape(crf, (-1, 1, 1, 1)), [1, frame_shape[1], frame_shape[2], 1])
repeated_preset = tf.tile(tf.reshape(preset_embedding, (-1, 1, 1, 16)), [1, frame_shape[1], frame_shape[2], 1])
frame = tf.concat([tf.cast(frame, tf.float32), repeated_crf, repeated_preset], axis=-1)
# Convert frames to float32
uncompressed_frame = tf.cast(uncompressed_frame, tf.float32)
compressed_frame = tf.cast(compressed_frame, tf.float32)
# Encoding the frame
compressed_representation = self.encoder(frame)
# Integrate CRF and preset speed into the network
preset_speed_embedded = self.embedding(preset_speed)
crf_expanded = tf.expand_dims(crf, -1)
integrated_info = tf.keras.layers.Concatenate(axis=-1)([crf_expanded, tf.keras.layers.Flatten()(preset_speed_embedded)])
integrated_info = self.fc(integrated_info)
# Decoding to generate compressed frame
# Integrate the CRF and preset speed information into the frames as additional channels (features)
_, height, width, _ = uncompressed_frame.shape
integrated_info_repeated = tf.tile(tf.reshape(integrated_info, [-1, 1, 1, 32]), [1, height, width, 1])
# Merge uncompressed and compressed frames
frames_merged = tf.keras.layers.Concatenate(axis=-1)([uncompressed_frame, compressed_frame, integrated_info_repeated])
compressed_representation = self.encoder(frames_merged)
reconstructed_frame = self.decoder(compressed_representation)
return reconstructed_frame