update
This commit is contained in:
parent
1d98bc84a2
commit
fde856f3ec
6 changed files with 107 additions and 109 deletions
|
@ -1,23 +1,50 @@
|
|||
# video_compression_model.py
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from featureExtraction import preprocess_frame
|
||||
|
||||
from global_train import LOGGER
|
||||
from globalVars import HEIGHT, LOGGER, WIDTH
|
||||
|
||||
PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
|
||||
NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
|
||||
NUM_CHANNELS = 3
|
||||
WIDTH = 640
|
||||
HEIGHT = 360
|
||||
#PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
|
||||
#NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
|
||||
|
||||
#from tensorflow.keras.mixed_precision import Policy
|
||||
|
||||
#policy = Policy('mixed_float16')
|
||||
#tf.keras.mixed_precision.set_global_policy(policy)
|
||||
|
||||
def data_generator(videos, batch_size):
|
||||
while True:
|
||||
for video_details in videos:
|
||||
video_path = os.path.join(os.path.dirname("test_data/validation/validation.json"), video_details["compressed_video_file"])
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
|
||||
feature_batch = []
|
||||
compressed_frame_batch = []
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
combined_feature, compressed_frame = preprocess_frame(frame)
|
||||
|
||||
feature_batch.append(combined_feature)
|
||||
compressed_frame_batch.append(compressed_frame)
|
||||
|
||||
if len(feature_batch) == batch_size:
|
||||
yield (np.array(feature_batch), np.array(compressed_frame_batch))
|
||||
feature_batch = []
|
||||
compressed_frame_batch = []
|
||||
|
||||
cap.release()
|
||||
|
||||
# If there are frames left that don't fill a whole batch, send them anyway
|
||||
if len(feature_batch) > 0:
|
||||
yield (np.array(feature_batch), np.array(compressed_frame_batch))
|
||||
|
||||
class VideoCompressionModel(tf.keras.Model):
|
||||
def __init__(self):
|
||||
|
|
Reference in a new issue