# video_compression_model.py

import gc
import os

import cv2
import numpy as np
import tensorflow as tf

from featureExtraction import preprocess_frame, scale_crf, scale_speed_preset
from globalVars import HEIGHT, LOGGER, NUM_COLOUR_CHANNELS, NUM_PRESET_SPEEDS, PRESET_SPEED_CATEGORIES, WIDTH

# Mixed-precision training is currently disabled; uncomment to enable it.
#from tensorflow.keras.mixed_precision import Policy
#policy = Policy('mixed_float16')
#tf.keras.mixed_precision.set_global_policy(policy)
def combine_batch(frame, crf, speed, include_controls=True, resize=True):
    """Preprocess a frame and optionally append constant CRF / speed-preset channels."""
    processed_frame = preprocess_frame(frame, resize)
    height, width, _ = processed_frame.shape

    combined = [processed_frame]

    if include_controls:
        # Broadcast the scalar encoder controls into full-resolution,
        # single-channel planes so they can be concatenated with the colour channels.
        crf_array = np.full((height, width, 1), crf)
        speed_array = np.full((height, width, 1), speed)
        combined.extend([crf_array, speed_array])

    return np.concatenate(combined, axis=-1)
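
# Illustrative shape check for combine_batch (hypothetical frame, not part of
# the pipeline; assumes preprocess_frame returns a resized
# (HEIGHT, WIDTH, NUM_COLOUR_CHANNELS) array):
#
#   frame = np.zeros((1080, 1920, NUM_COLOUR_CHANNELS), dtype=np.uint8)
#   combine_batch(frame, crf=0.5, speed=0.5).shape
#   # -> (HEIGHT, WIDTH, NUM_COLOUR_CHANNELS + 2): the colour channels plus
#   #    one constant plane each for CRF and speed preset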
def data_generator(videos, batch_size):
    # validation.json sits alongside the video files, so its directory is the base path
    base_dir = os.path.dirname("test_data/validation/validation.json")
    while True:
        # Lists to accumulate the processed frames
        uncompressed_frame_batch = []  # Input data: original frames plus control channels
        compressed_frame_batch = []    # Target data: the corresponding compressed frames

        # Open a capture pair (compressed / original) for every video
        caps_compressed = [cv2.VideoCapture(os.path.join(base_dir, video["compressed_video_file"])) for video in videos]
        caps_uncompressed = [cv2.VideoCapture(os.path.join(base_dir, video["original_video_file"])) for video in videos]

        # As long as any video can still provide frames, keep running
        while any(cap.isOpened() for cap in caps_compressed):
            for idx, (cap_compressed, cap_uncompressed) in enumerate(zip(caps_compressed, caps_uncompressed)):
                #print(f"(Video Change) Processing video {idx}") # Debug: indicates video change
                ret_compressed, compressed_frame = cap_compressed.read()
                ret_uncompressed, uncompressed_frame = cap_uncompressed.read()
                if not ret_compressed or not ret_uncompressed:
                    # Release exhausted captures so isOpened() becomes False;
                    # otherwise the outer while-loop would never terminate.
                    cap_compressed.release()
                    cap_uncompressed.release()
                    continue

                CRF = scale_crf(videos[idx]["crf"])
                SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(videos[idx]["preset_speed"]))

                compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
                uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))

                compressed_frame_batch.append(compressed_combined)
                uncompressed_frame_batch.append(uncompressed_combined)

                if len(compressed_frame_batch) == batch_size:
                    yield (np.array(uncompressed_frame_batch), np.array(compressed_frame_batch))
                    compressed_frame_batch.clear()
                    uncompressed_frame_batch.clear()

        # Close all video captures at the end (release() is safe on already-released captures)
        for cap in caps_compressed + caps_uncompressed:
            cap.release()
        cv2.destroyAllWindows()
        gc.collect()  # Free the capture buffers before the next pass over the videos

        # If there are frames left that don't fill a whole batch, yield them anyway
        if len(compressed_frame_batch) > 0:
            yield (np.array(uncompressed_frame_batch), np.array(compressed_frame_batch))
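
# Optional: the generator can be wrapped in a tf.data pipeline so TensorFlow
# handles prefetching. A minimal sketch (assumes preprocess_frame yields
# float32 data; adjust dtypes and shapes to match the actual pipeline):
#
#   dataset = tf.data.Dataset.from_generator(
#       lambda: data_generator(videos, batch_size),
#       output_signature=(
#           tf.TensorSpec(shape=(None, None, None, NUM_COLOUR_CHANNELS + 2), dtype=tf.float32),
#           tf.TensorSpec(shape=(None, None, None, NUM_COLOUR_CHANNELS), dtype=tf.float32),
#       ),
#   ).prefetch(tf.data.AUTOTUNE)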
class VideoCompressionModel(tf.keras.Model):
    def __init__(self):
        super(VideoCompressionModel, self).__init__()
        LOGGER.debug("Initializing VideoCompressionModel.")

        # Input shape: colour channels plus one channel each for CRF and SPEED_PRESET
        input_shape_with_controls = (None, None, NUM_COLOUR_CHANNELS + 2)

        # Encoder part of the model
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=input_shape_with_controls),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.MaxPooling2D((2, 2), padding='same'),
            tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.MaxPooling2D((2, 2), padding='same')
        ])

        # Decoder part of the model
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.Conv2DTranspose(32, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.UpSampling2D((2, 2)),
            tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.UpSampling2D((2, 2)),
            tf.keras.layers.Conv2DTranspose(NUM_COLOUR_CHANNELS, (3, 3), activation='sigmoid', padding='same')
        ])

    def call(self, inputs):
        return self.decoder(self.encoder(inputs))
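

if __name__ == "__main__":
    # Minimal usage sketch, not part of the training pipeline proper. It assumes
    # test_data/validation/validation.json is a JSON list of entries with the
    # keys used above ("compressed_video_file", "original_video_file", "crf",
    # "preset_speed"); adjust the path, batch size, and step counts as needed.
    import json

    with open("test_data/validation/validation.json", "r") as f:
        videos = json.load(f)

    model = VideoCompressionModel()
    model.compile(optimizer="adam", loss="mse")

    # steps_per_epoch must be given explicitly because the generator loops forever.
    model.fit(data_generator(videos, batch_size=4), steps_per_epoch=100, epochs=1)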