Jordon Brooks 2023-09-10 19:05:52 +01:00
parent 4d29fffba1
commit 8df4df7972
No known key found for this signature in database
GPG key ID: 83964894E5D98D57
3 changed files with 51 additions and 20 deletions

@@ -7,13 +7,13 @@ import numpy as np
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
 import tensorflow as tf
-from featureExtraction import combined, combined_loss, psnr, scale_crf, scale_speed_preset, ssim
+from featureExtraction import combined, combined_loss, combined_loss_weighted_psnr, psnr, scale_crf, scale_speed_preset, ssim
 from globalVars import PRESET_SPEED_CATEGORIES, clear_screen
 from video_compression_model import VideoCompressionModel, combine_batch
 
 # Constants
 COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
-MAX_FRAMES = 200 # Limit the number of frames processed
+MAX_FRAMES = 0 # Limit the number of frames processed
 CRF = 10
 SPEED = "ultrafast"
 MODEL_PATH = 'models/model.tf'
@@ -33,7 +33,7 @@ def parse_arguments():
     parser.add_argument('-p', '--model_path', default=MODEL_PATH, help='Path to the trained model')
     parser.add_argument('-i', '--uncompressed_video_file', default=UNCOMPRESSED_VIDEO_FILE, help='Path to the uncompressed video file')
     parser.add_argument('-d', '--display_output', action='store_true', default=DISPLAY_OUTPUT, help='Display real-time output to screen')
-    parser.add_argument('--keep_black_bars', action='store_true', help='Keep black bars from the video', default=False)
+    parser.add_argument('--keep_black_bars', action='store_false', help='Keep black bars from the video', default=True)
 
     args = parser.parse_args()
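
A side effect worth noting: with `action='store_false'` and `default=True`, passing `--keep_black_bars` now sets the flag to False, so the option's name reads opposite to its effect. A minimal standalone sketch of `store_false` semantics (not from this repo):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--keep_black_bars', action='store_false', default=True)

print(parser.parse_args([]).keep_black_bars)                     # True  (flag omitted)
print(parser.parse_args(['--keep_black_bars']).keep_black_bars)  # False (flag passed)
```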
@@ -95,18 +95,35 @@ def load_frame_from_video(video_file, frame_num):
 def predict_frame(uncompressed_frame, model, crf, speed):
+    # Scale the CRF and Speed values
     scaled_crf = scale_crf(crf)
     scaled_speed = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(speed))
-    frame = combine_batch(uncompressed_frame, scaled_crf, scaled_speed, resize=False)
-    compressed_frame = model.predict([np.expand_dims(frame, axis=0)])[0]
-    return np.clip(compressed_frame[:, :, :3] * 255.0, 0, 255).astype(np.uint8)
+
+    # Preprocess the frame
+    frame = combine_batch(uncompressed_frame, resize=False)
+
+    # Predict using the model
+    inputs = {
+        'image': np.expand_dims(frame, axis=0),
+        'CRF': np.array([scaled_crf]),
+        'Speed': np.array([scaled_speed])
+    }
+    compressed_frame = model.predict(inputs)[0]
+
+    # Post-process the output frame
+    return np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
 def main():
-    model = tf.keras.models.load_model(MODEL_PATH, custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr, 'ssim': ssim, 'combined': combined, 'combined_loss': combined_loss})
+    model = tf.keras.models.load_model(MODEL_PATH, custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr, 'ssim': ssim, 'combined': combined, 'combined_loss': combined_loss, 'combined_loss_weighted_psnr': combined_loss_weighted_psnr})
 
     cap = cv2.VideoCapture(UNCOMPRESSED_VIDEO_FILE)
-    total_frames = min(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), MAX_FRAMES)
+    if MAX_FRAMES > 0:
+        total_frames = min(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), MAX_FRAMES)
+    else:
+        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
     height, width, fps = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FPS))
     cap.release()
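
The dict passed to `model.predict` must match the model's named inputs. A minimal standalone sketch of a Keras model wired with the three input names used above (the names 'image', 'CRF', and 'Speed' come from the diff; the layers and shapes are illustrative, not the repo's actual architecture):

```python
import numpy as np
import tensorflow as tf

image = tf.keras.Input(shape=(64, 64, 3), name='image')
crf = tf.keras.Input(shape=(1,), name='CRF')
speed = tf.keras.Input(shape=(1,), name='Speed')

# Broadcast the two conditioning scalars over the spatial grid, then concat.
cond = tf.keras.layers.Concatenate()([crf, speed])        # [batch, 2]
cond = tf.keras.layers.Reshape((1, 1, 2))(cond)           # [batch, 1, 1, 2]
cond = tf.keras.layers.UpSampling2D(size=(64, 64))(cond)  # [batch, 64, 64, 2]
x = tf.keras.layers.Concatenate()([image, cond])          # [batch, 64, 64, 5]
out = tf.keras.layers.Conv2D(3, 3, padding='same', activation='sigmoid')(x)

model = tf.keras.Model(inputs={'image': image, 'CRF': crf, 'Speed': speed}, outputs=out)

# predict() then accepts a dict keyed by the input names:
pred = model.predict({'image': np.zeros((1, 64, 64, 3), np.float32),
                      'CRF': np.array([[0.5]], np.float32),
                      'Speed': np.array([[0.1]], np.float32)})
```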
@@ -127,6 +144,7 @@ def main():
         compressed_frame = predict_frame(uncompressed_frame, model, CRF, SPEED)
         compressed_frame = cv2.resize(compressed_frame, (width, height))
+        compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_RGB2BGR)
 
         out.write(compressed_frame)
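
The added `cvtColor` matters because OpenCV's `VideoWriter` expects BGR channel order, while the model presumably works on RGB frames. A tiny standalone sketch of the boundary conversions:

```python
import cv2
import numpy as np

# OpenCV captures and writes BGR; convert at the model boundaries if the
# network is trained on RGB frames.
bgr = np.zeros((4, 4, 3), dtype=np.uint8)
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)      # into model space
bgr_out = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)  # back before out.write(...)
```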

@@ -42,8 +42,8 @@ from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER, clear_screen, load_vid
# Constants
 BATCH_SIZE = 25
-EPOCHS = 100
-LEARNING_RATE = 0.0001
+EPOCHS = 1000
+LEARNING_RATE = 0.005
 DECAY_STEPS = 160
 DECAY_RATE = 0.9
 MODEL_SAVE_FILE = "models/model.tf"
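
For context, if these constants feed a `tf.keras.optimizers.schedules.ExponentialDecay` schedule (the optimizer setup is not visible in this hunk, so this is an assumption), the learning rate decays as `lr(step) = LEARNING_RATE * DECAY_RATE ** (step / DECAY_STEPS)`:

```python
import tensorflow as tf

# Assuming the script feeds these constants into an ExponentialDecay schedule
# (the schedule itself is not shown in this hunk).
LEARNING_RATE = 0.005
DECAY_STEPS = 160
DECAY_RATE = 0.9

schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=LEARNING_RATE,
    decay_steps=DECAY_STEPS,
    decay_rate=DECAY_RATE)

# lr(step) = 0.005 * 0.9 ** (step / 160); e.g. after 1600 steps: ~0.00174
print(float(schedule(1600)))
```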
@@ -66,8 +66,10 @@ class ImageLoggingCallback(Callback):
         return np.stack(converted, axis=0)
 
     def on_epoch_end(self, epoch, logs=None):
+        # where total_batches is the number of batches in the validation dataset
+        skip_batches = np.random.randint(0, 100)
         # Get the first batch from the validation dataset
-        validation_data = next(iter(self.validation_dataset.take(1)))
+        validation_data = next(iter(self.validation_dataset.skip(skip_batches).take(1)))
 
         # Extract the inputs from the batch_input_images dictionary
         actual_images = validation_data[0]['image']
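
`skip(n).take(1)` samples a pseudo-random batch instead of always logging the first one. A standalone sketch of the pattern, including the edge case that the hard-coded bound of 100 can hit:

```python
import numpy as np
import tensorflow as tf

# Standalone sketch of sampling a random batch via skip().take(1).
dataset = tf.data.Dataset.range(500).batch(25)   # 20 batches

skip_batches = np.random.randint(0, 20)          # upper bound = number of batches
batch = next(iter(dataset.skip(skip_batches).take(1)))
print(batch.numpy())

# Caveat: if skip_batches >= the number of batches, the iterator is empty and
# next() raises StopIteration; a fixed bound like 100 can overshoot the dataset.
```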
@@ -82,7 +84,7 @@ class ImageLoggingCallback(Callback):
# Save the reconstructed frame to the specified folder
         reconstructed_path = os.path.join(self.log_dir, f"epoch_{epoch}.png")
-        cv2.imwrite(reconstructed_path, reconstructed_frame[0]) # Saving only the first image as an example
+        cv2.imwrite(reconstructed_path, cv2.cvtColor(reconstructed_frame[0], cv2.COLOR_RGB2BGR)) # Saving only the first image as an example
 
         # Log images to TensorBoard
         with self.writer.as_default():
@@ -145,13 +147,13 @@ def main():
# Load all video metadata
     all_videos = load_video_metadata("test_data/validation/validation.json")
 
-    tf.random.set_seed(RANDOM_SEED)
+    #tf.random.set_seed(RANDOM_SEED)
 
     # Shuffle the data using the specified seed
     random.shuffle(all_videos, random.seed(RANDOM_SEED))
 
     # Split into training and validation
-    split_index = int(0.6 * len(all_videos))
+    split_index = int(0.7 * len(all_videos))
     training_videos = all_videos[:split_index]
     validation_videos = all_videos[split_index:]
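
Note that `random.shuffle(all_videos, random.seed(RANDOM_SEED))` in the unchanged context is fragile: `random.shuffle`'s second argument was deprecated in Python 3.9 and removed in 3.11, and `random.seed(RANDOM_SEED)` returns None, so the call is only deterministic through the side effect of seeding the global RNG. An equivalent, version-proof form (values illustrative):

```python
import random

RANDOM_SEED = 42              # illustrative; the real constant is defined elsewhere
all_videos = list(range(10))  # stand-in for the loaded video metadata

random.seed(RANDOM_SEED)      # seed the global RNG once...
random.shuffle(all_videos)    # ...then shuffle; works on all Python 3 versions
print(all_videos)
```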
@@ -166,7 +168,14 @@ def main():
     if args.continue_training:
-        MODEL = tf.keras.models.load_model(args.continue_training)
+        MODEL = tf.keras.models.load_model(args.continue_training, custom_objects={
+            'VideoCompressionModel': VideoCompressionModel,
+            'psnr': psnr,
+            'ssim': ssim,
+            'combined': combined,
+            'combined_loss': combined_loss,
+            'combined_loss_weighted_psnr': combined_loss_weighted_psnr
+        })
     else:
         MODEL = VideoCompressionModel()
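
Repeating the `custom_objects` dict at every `load_model` call site is easy to get out of sync. One alternative (a sketch, assuming the decorator available in recent TF 2.x) is to register the custom pieces once:

```python
import tensorflow as tf

# Sketch: registering custom objects once avoids repeating the custom_objects
# dict at every tf.keras.models.load_model() call site.
@tf.keras.utils.register_keras_serializable(package='videocompression')
def psnr(y_true, y_pred):
    # Illustrative body; the real metric lives in featureExtraction.py.
    return tf.image.psnr(y_true, y_pred, max_val=1.0)

# After registration, load_model() can resolve 'psnr' without custom_objects.
```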

@@ -78,7 +78,7 @@ def create_dataset(videos, batch_size, max_frames=None):
         output_signature=output_signature
     )
-    dataset = dataset.batch(batch_size).shuffle(20).prefetch(1)
+    dataset = dataset.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
 
     return dataset
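
The reordering is significant: shuffling before batching mixes individual examples, while the old `batch(...).shuffle(20)` only reordered whole batches. A standalone sketch of the difference:

```python
import tensorflow as tf

data = tf.data.Dataset.range(8)

# batch(4).shuffle(2): elements stay grouped, only batch order can change.
for b in data.batch(4).shuffle(2, seed=1):
    print(b.numpy())   # e.g. [4 5 6 7] then [0 1 2 3]

# shuffle(8).batch(4): elements are mixed across batches.
for b in data.shuffle(8, seed=1).batch(4):
    print(b.numpy())   # e.g. [3 0 7 2] then [5 1 4 6]
```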
@@ -160,13 +160,16 @@ class VideoCompressionModel(tf.keras.Model):
# New shape: [batch_size, 1, 1, 128]
         crf_speed_features = tf.reshape(crf_speed_features, [-1, 1, 1, 128])
 
-        # Tile the tensor to match spatial dimensions of encoded tensor
-        # Tiled shape: [batch_size, 90, 160, 128]
-        crf_speed_features = tf.tile(crf_speed_features, [1, 90, 160, 1])
-
         # Pass the image through the encoder
         encoded = self.encoder(image)
 
+        # Dynamically compute the spatial dimensions of the encoded tensor
+        encoded_shape = tf.shape(encoded)
+        height, width = encoded_shape[1], encoded_shape[2]
+
+        # Tile the crf_speed_features tensor to match the spatial dimensions of the encoded tensor
+        crf_speed_features = tf.tile(crf_speed_features, [1, height, width, 1])
+
         # Concatenate the encoded tensor with the crf_speed_features tensor
         combined_features = tf.concat([encoded, crf_speed_features], axis=-1)
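
Tiling by `tf.shape(encoded)` rather than the hard-coded 90x160 makes the conditioning resolution-independent. A minimal standalone sketch of the pattern:

```python
import tensorflow as tf

# Standalone sketch: broadcast per-sample conditioning features over whatever
# spatial grid the encoder happens to produce.
encoded = tf.zeros([2, 45, 80, 64])   # any [batch, H, W, C]
cond = tf.zeros([2, 1, 1, 128])       # [batch, 1, 1, 128]

shape = tf.shape(encoded)             # dynamic, works for unknown H/W
h, w = shape[1], shape[2]

tiled = tf.tile(cond, [1, h, w, 1])   # [batch, H, W, 128]
combined = tf.concat([encoded, tiled], axis=-1)
print(combined.shape)                 # (2, 45, 80, 192)
```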
@@ -176,3 +179,4 @@ class VideoCompressionModel(tf.keras.Model):
         return decoded