This commit is contained in:
Jordon Brooks 2023-09-10 01:20:10 +01:00
parent 9cecaeb9d6
commit 4d29fffba1
No known key found for this signature in database
GPG key ID: 83964894E5D98D57
4 changed files with 126 additions and 67 deletions

View file

@ -23,6 +23,7 @@ def psnr(y_true, y_pred):
#LOGGER.info(f"[psnr function] y_true: {y_true.shape}, y_pred: {y_pred.shape}")
max_pixel = 1.0
mse = K.mean(K.square(y_pred - y_true))
mse = tf.cast(mse, tf.float32) # Cast mse to tf.float32
return 20.0 * K.log(max_pixel / K.sqrt(mse)) / K.log(10.0)
@ -37,6 +38,14 @@ def combined(y_true, y_pred):
def combined_loss(y_true, y_pred):
return -combined(y_true, y_pred) # The goal is to maximize the combined value
# Option 1: Weight more towards PSNR
def combined_loss_weighted_psnr(y_true, y_pred):
return -0.7 * psnr(y_true, y_pred) - 0.3 * ssim(y_true, y_pred)
# Option 2: Weight more towards SSIM
def combined_loss_weighted_ssim(y_true, y_pred):
return -0.3 * psnr(y_true, y_pred) - 0.7 * ssim(y_true, y_pred)
def detect_noise(image, threshold=15):
# Convert to grayscale if it's a color image
@ -66,6 +75,8 @@ def preprocess_frame(frame, resize=True, scale=True):
# Check frame dimensions and resize if necessary
if resize and frame.shape[:2] != (HEIGHT, WIDTH):
frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_LINEAR)
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
if scale:
# Scale frame to [0, 1]