This commit is contained in:
Jordon Brooks 2023-08-13 13:33:03 +01:00
parent 1d98bc84a2
commit fde856f3ec
6 changed files with 107 additions and 109 deletions

View file

@ -1,23 +1,50 @@
# video_compression_model.py
import os
import cv2
import numpy as np
import tensorflow as tf
from featureExtraction import preprocess_frame
from global_train import LOGGER
from globalVars import HEIGHT, LOGGER, WIDTH
PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
NUM_CHANNELS = 3
WIDTH = 640
HEIGHT = 360
#PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
#NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
#from tensorflow.keras.mixed_precision import Policy
#policy = Policy('mixed_float16')
#tf.keras.mixed_precision.set_global_policy(policy)
def data_generator(videos, batch_size):
while True:
for video_details in videos:
video_path = os.path.join(os.path.dirname("test_data/validation/validation.json"), video_details["compressed_video_file"])
cap = cv2.VideoCapture(video_path)
feature_batch = []
compressed_frame_batch = []
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
combined_feature, compressed_frame = preprocess_frame(frame)
feature_batch.append(combined_feature)
compressed_frame_batch.append(compressed_frame)
if len(feature_batch) == batch_size:
yield (np.array(feature_batch), np.array(compressed_frame_batch))
feature_batch = []
compressed_frame_batch = []
cap.release()
# If there are frames left that don't fill a whole batch, send them anyway
if len(feature_batch) > 0:
yield (np.array(feature_batch), np.array(compressed_frame_batch))
class VideoCompressionModel(tf.keras.Model):
def __init__(self):