76 lines
2.9 KiB
Python
76 lines
2.9 KiB
Python
# video_compression_model.py
|
|
|
|
import os
|
|
import cv2
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
from featureExtraction import preprocess_frame
|
|
|
|
from globalVars import HEIGHT, LOGGER, WIDTH
|
|
|
|
#PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
|
|
#NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
|
|
|
|
#from tensorflow.keras.mixed_precision import Policy
|
|
|
|
#policy = Policy('mixed_float16')
|
|
#tf.keras.mixed_precision.set_global_policy(policy)
|
|
|
|
def data_generator(videos, batch_size, *, base_dir="test_data/validation"):
    """Yield ``(features, compressed_frames)`` batches, cycling over ``videos`` forever.

    Each pass walks ``videos`` in order. Every frame of each video is read
    with ``cv2.VideoCapture`` and passed through ``preprocess_frame``, and
    the two outputs are stacked into NumPy arrays. Batches never span two
    videos: frames left over at the end of a video are yielded as a final,
    smaller batch. After the last video the generator starts over.

    Args:
        videos: Iterable of mappings with a ``"compressed_video_file"`` key
            naming a video file relative to ``base_dir``. It is iterated once
            per pass, so it should be re-iterable (e.g. a list), not a
            one-shot iterator.
        batch_size: Maximum number of frames per yielded batch; must be >= 1.
        base_dir: Directory the ``"compressed_video_file"`` entries are
            joined onto. The default is the directory that used to be
            hard-coded ("test_data/validation").

    Yields:
        Tuple ``(np.array(feature_batch), np.array(compressed_frame_batch))``
        holding up to ``batch_size`` consecutive frames of a single video.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
        RuntimeError: If a complete pass over ``videos`` produces no batch
            (empty ``videos``, an exhausted iterator, or no readable frames),
            instead of looping forever without yielding.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

    while True:
        yielded_any = False

        for video_details in videos:
            video_path = os.path.join(base_dir, video_details["compressed_video_file"])
            feature_batch = []
            compressed_frame_batch = []

            cap = cv2.VideoCapture(video_path)
            try:
                if not cap.isOpened():
                    LOGGER.warning("Skipping video that could not be opened: %s", video_path)
                    continue

                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break

                    combined_feature, compressed_frame = preprocess_frame(frame)
                    feature_batch.append(combined_feature)
                    compressed_frame_batch.append(compressed_frame)

                    if len(feature_batch) == batch_size:
                        yielded_any = True
                        yield (np.array(feature_batch), np.array(compressed_frame_batch))
                        feature_batch = []
                        compressed_frame_batch = []
            finally:
                # Runs even if the consumer closes the generator at a yield
                # above, so the capture handle is never leaked.
                cap.release()

            # If there are frames left that don't fill a whole batch, send them anyway
            if feature_batch:
                yielded_any = True
                yield (np.array(feature_batch), np.array(compressed_frame_batch))

        if not yielded_any:
            # Without this guard an empty/unreadable video list makes the
            # outer `while True` spin forever without ever yielding.
            raise RuntimeError(
                "data_generator produced no frames in a full pass over videos; "
                "check that the list is non-empty and the files are readable"
            )
|
|
|
|
class VideoCompressionModel(tf.keras.Model):
    """Convolutional encoder/decoder over 2-channel per-frame feature maps.

    The encoder takes ``(HEIGHT, WIDTH, 2)`` inputs and applies two stages of
    3x3 Conv2D (ReLU) followed by 2x2 ``padding="same"`` max-pooling, ending
    with 32 channels. The decoder mirrors this with Conv2DTranspose (ReLU) and
    2x UpSampling2D stages, and finishes with a 1-channel sigmoid
    Conv2DTranspose, so outputs lie in (0, 1).
    """

    def __init__(self, **kwargs):
        """Build the encoder and decoder sub-models.

        Args:
            **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``). The
                original constructor took no arguments, so existing
                ``VideoCompressionModel()`` calls behave the same.
        """
        super().__init__(**kwargs)
        LOGGER.debug("Initializing VideoCompressionModel.")

        # Add an additional channel for the histogram features
        input_shape_with_histogram = (HEIGHT, WIDTH, 2)  # 1 channel for edges, 1 for histogram

        # "same"-padded pooling rounds odd sizes up (e.g. 270 -> 135 -> 68), so
        # after two 2x upsamplings the decoder can overshoot HEIGHT/WIDTH when
        # they are not multiples of 4 (68 * 4 = 272). call() crops back to these.
        self._output_height = HEIGHT
        self._output_width = WIDTH

        self.encoder = tf.keras.Sequential([
            # tf.keras.Input(shape=...) works in both Keras 2 and Keras 3;
            # InputLayer(input_shape=...) is deprecated in Keras 3.
            tf.keras.Input(shape=input_shape_with_histogram),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.MaxPooling2D((2, 2), padding='same'),
            tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.MaxPooling2D((2, 2), padding='same'),
        ])

        self.decoder = tf.keras.Sequential([
            tf.keras.layers.Conv2DTranspose(32, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.UpSampling2D((2, 2)),
            tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'),
            tf.keras.layers.UpSampling2D((2, 2)),
            tf.keras.layers.Conv2DTranspose(1, (3, 3), activation='sigmoid', padding='same'),
        ])

    def call(self, inputs):
        """Encode then decode ``inputs``.

        Args:
            inputs: Batch of feature maps matching the encoder input shape
                ``(HEIGHT, WIDTH, 2)``.

        Returns:
            Single-channel sigmoid output with spatial size ``HEIGHT`` x
            ``WIDTH``, i.e. shape ``(batch, HEIGHT, WIDTH, 1)``.
        """
        encoded = self.encoder(inputs)
        decoded = self.decoder(encoded)
        # No-op when HEIGHT and WIDTH are multiples of 4; otherwise drops the
        # extra rows/columns introduced by the rounded-up pooling sizes.
        return decoded[:, :self._output_height, :self._output_width, :]
|