This commit is contained in:
Jordon Brooks 2023-09-10 19:05:52 +01:00
parent 4d29fffba1
commit 8df4df7972
No known key found for this signature in database
GPG key ID: 83964894E5D98D57
3 changed files with 51 additions and 20 deletions

View file

@ -7,13 +7,13 @@ import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
import tensorflow as tf
from featureExtraction import combined, combined_loss, psnr, scale_crf, scale_speed_preset, ssim
from featureExtraction import combined, combined_loss, combined_loss_weighted_psnr, psnr, scale_crf, scale_speed_preset, ssim
from globalVars import PRESET_SPEED_CATEGORIES, clear_screen
from video_compression_model import VideoCompressionModel, combine_batch
# Constants
COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
# Upper bound on the number of frames processed. A value of 0 (or any
# non-positive number) disables the cap: main() only applies the limit
# when MAX_FRAMES > 0 and otherwise uses the video's full frame count.
MAX_FRAMES = 0
# Encoder settings handed to predict_frame() for every frame.
CRF = 10
SPEED = "ultrafast"
MODEL_PATH = 'models/model.tf'
@ -33,7 +33,7 @@ def parse_arguments():
parser.add_argument('-p', '--model_path', default=MODEL_PATH, help='Path to the trained model')
parser.add_argument('-i', '--uncompressed_video_file', default=UNCOMPRESSED_VIDEO_FILE, help='Path to the uncompressed video file')
parser.add_argument('-d', '--display_output', action='store_true', default=DISPLAY_OUTPUT, help='Display real-time output to screen')
parser.add_argument('--keep_black_bars', action='store_true', help='Keep black bars from the video', default=False)
parser.add_argument('--keep_black_bars', action='store_false', help='Keep black bars from the video', default=True)
args = parser.parse_args()
@ -95,18 +95,35 @@ def load_frame_from_video(video_file, frame_num):
def predict_frame(uncompressed_frame, model, crf, speed):
    """Run a single frame through the model and return it as a uint8 image.

    Args:
        uncompressed_frame: Source frame; handed unchanged to
            ``combine_batch`` (with ``resize=False``) for preprocessing.
        model: Object whose ``predict`` accepts a dict keyed by
            ``'image'``, ``'CRF'`` and ``'Speed'``.
        crf: CRF value; normalised with ``scale_crf`` before prediction.
        speed: Preset name looked up in ``PRESET_SPEED_CATEGORIES``; its
            index is normalised with ``scale_speed_preset``.

    Returns:
        The first element of the model's prediction, multiplied by 255,
        clipped to ``[0, 255]`` and cast to ``np.uint8``.
        NOTE(review): the ``* 255.0`` presumably means the model emits
        values in ``[0, 1]`` -- confirm against the training pipeline.

    Raises:
        ValueError: If ``speed`` is not found by
            ``PRESET_SPEED_CATEGORIES.index`` (assuming that container is
            a list or tuple -- confirm in ``globalVars``).
    """
    # Scale the CRF and Speed values
    scaled_crf = scale_crf(crf)
    scaled_speed = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(speed))

    # Preprocess the frame; CRF and speed are no longer merged into the
    # image here, they travel as separate named inputs below.
    frame = combine_batch(uncompressed_frame, resize=False)

    # Predict using the model: every input gets a leading batch axis of 1.
    inputs = {
        'image': np.expand_dims(frame, axis=0),
        'CRF': np.array([scaled_crf]),
        'Speed': np.array([scaled_speed]),
    }
    compressed_frame = model.predict(inputs)[0]

    # Post-process the output frame
    return np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
def main():
model = tf.keras.models.load_model(MODEL_PATH, custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr, 'ssim': ssim, 'combined': combined, 'combined_loss': combined_loss})
model = tf.keras.models.load_model(MODEL_PATH, custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr, 'ssim': ssim, 'combined': combined, 'combined_loss': combined_loss, 'combined_loss_weighted_psnr': combined_loss_weighted_psnr})
cap = cv2.VideoCapture(UNCOMPRESSED_VIDEO_FILE)
total_frames = min(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), MAX_FRAMES)
if MAX_FRAMES > 0:
total_frames = min(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), MAX_FRAMES)
else:
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
height, width, fps = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FPS))
cap.release()
@ -127,6 +144,7 @@ def main():
compressed_frame = predict_frame(uncompressed_frame, model, CRF, SPEED)
compressed_frame = cv2.resize(compressed_frame, (width, height))
compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_RGB2BGR)
out.write(compressed_frame)