Jordon Brooks 2023-08-17 01:57:53 +01:00
parent 3ea1568ad3
commit 7787d0584e
4 changed files with 35 additions and 25 deletions

View file

@@ -22,7 +22,7 @@ SPEED = PRESET_SPEED_CATEGORIES.index("ultrafast")
MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr})
# Load the uncompressed video
UNCOMPRESSED_VIDEO_FILE = 'test_data/training_video.mkv'
UNCOMPRESSED_VIDEO_FILE = 'test_data/B4_t02.mkv'
def load_frame_from_video(video_file, frame_num):
cap = cv2.VideoCapture(video_file)
@@ -41,16 +41,15 @@ def predict_frame(uncompressed_frame):
scaled_crf = scale_crf(CRF)
scaled_speed = scale_speed_preset(SPEED)
frame = combine_batch(uncompressed_frame, scaled_crf, scaled_speed)
frame = combine_batch(uncompressed_frame, scaled_crf, scaled_speed, resize=False)
compressed_frame = MODEL.predict([np.expand_dims(frame, axis=0)])[0]
compressed_frame = compressed_frame[:, :, :3] # Keep only the first 3 channels (BGR)
display_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
compressed_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
cv2.imshow("comp", display_frame)
cv2.waitKey(1)
cv2.imshow("comp", compressed_frame)
return compressed_frame
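A minimal sketch of the updated predict_frame path, assuming MODEL, CRF, SPEED and the helper functions behave as shown in this commit (the model is taken to output [0, 1]-scaled frames with possible extra channels):

import cv2
import numpy as np

def predict_frame(uncompressed_frame):
    # Build the model input at native resolution (resize=False skips the training-size resize).
    frame = combine_batch(uncompressed_frame, scale_crf(CRF), scale_speed_preset(SPEED), resize=False)

    # Predict on a batch of one and take the single result.
    compressed_frame = MODEL.predict(np.expand_dims(frame, axis=0))[0]

    # Keep only the first 3 channels (BGR) and map [0, 1] floats back to uint8.
    compressed_frame = compressed_frame[:, :, :3]
    compressed_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)

    cv2.imshow("comp", compressed_frame)
    cv2.waitKey(1)
    return compressed_frame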
@@ -75,9 +74,9 @@ for i in range(total_frames):
compressed_frame = cv2.resize(compressed_frame, (width, height))
compressed_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
#compressed_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_RGB2BGR)
#compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_RGB2BGR)
out.write(compressed_frame)
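Because predict_frame now returns an 8-bit BGR frame, the write loop only needs a resize before handing the frame to the writer; a rough sketch, assuming a cv2.VideoWriter named out and a per-index frame loader as in the rest of the script:

for i in range(total_frames):
    uncompressed_frame = load_frame_from_video(UNCOMPRESSED_VIDEO_FILE, i)
    compressed_frame = predict_frame(uncompressed_frame)

    # Match the writer's frame size; no further scaling or colour conversion is needed.
    compressed_frame = cv2.resize(compressed_frame, (width, height))
    out.write(compressed_frame)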

View file

@@ -54,11 +54,11 @@ def psnr(y_true, y_pred):
return 10.0 * K.log((max_pixel ** 2) / (K.mean(K.square(y_pred - y_true)))) / K.log(10.0)
def preprocess_frame(frame):
def preprocess_frame(frame, resize=True):
#Preprocesses a single frame, cropping it if needed
# Check frame dimensions and resize if necessary
if frame.shape[:2] != (HEIGHT, WIDTH):
if resize and frame.shape[:2] != (HEIGHT, WIDTH):
frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_LINEAR)
# Scale frame to [0, 1]
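A sketch of the adjusted helper, assuming HEIGHT and WIDTH come from globalVars and that the final step is a plain divide-by-255 scaling (only the comment is visible in this hunk):

import cv2
import numpy as np

def preprocess_frame(frame, resize=True):
    # Resize to the training resolution only when requested.
    if resize and frame.shape[:2] != (HEIGHT, WIDTH):
        frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_LINEAR)

    # Scale pixel values to [0, 1].
    return frame.astype(np.float32) / 255.0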

View file

@@ -32,15 +32,20 @@ from video_compression_model import VideoCompressionModel, data_generator
from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
# Constants
BATCH_SIZE = 16
BATCH_SIZE = 25
EPOCHS = 100
LEARNING_RATE = 0.001
DECAY_STEPS = 40
LEARNING_RATE = 0.0001
DECAY_STEPS = 160
DECAY_RATE = 0.9
MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints"
EARLY_STOP = 5
class GarbageCollectorCallback(Callback):
def on_epoch_end(self, epoch, logs=None):
LOGGER.debug(f"Collecting garbage")
gc.collect()
def save_model(model):
try:
LOGGER.debug("Attempting to save the model.")
@@ -145,6 +150,9 @@ def main():
)
early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
# Custom garbage collection callback
gc_callback = GarbageCollectorCallback()
# Calculate steps per epoch for training and validation
if MAX_FRAMES <= 0:
average_frames_per_video = 2880 # Given 2 minutes @ 24 fps
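The steps-per-epoch arithmetic is only partly visible in this hunk; a plausible sketch, assuming each video contributes roughly average_frames_per_video frames when MAX_FRAMES is unlimited and the video lists are named as below:

# 2 minutes at 24 fps.
average_frames_per_video = 2880
frames_per_video = average_frames_per_video if MAX_FRAMES <= 0 else MAX_FRAMES

steps_per_epoch_train = (len(training_videos) * frames_per_video) // BATCH_SIZE
steps_per_epoch_validation = (len(validation_videos) * frames_per_video) // BATCH_SIZE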
@@ -164,7 +172,7 @@
steps_per_epoch=steps_per_epoch_train,
validation_data=data_generator(validation_videos, BATCH_SIZE), # Add validation data here
validation_steps=steps_per_epoch_validation, # Add validation steps here
callbacks=[early_stop, checkpoint_callback]
callbacks=[early_stop, checkpoint_callback, gc_callback]
)
LOGGER.info("Model training completed.")
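Putting it together, the fit call wires the early-stopping, checkpoint and garbage-collection callbacks into training; a sketch assuming the model object and the training/validation video lists are defined as elsewhere in the script:

model.fit(
    data_generator(training_videos, BATCH_SIZE),
    epochs=EPOCHS,
    steps_per_epoch=steps_per_epoch_train,
    validation_data=data_generator(validation_videos, BATCH_SIZE),
    validation_steps=steps_per_epoch_validation,
    callbacks=[early_stop, checkpoint_callback, gc_callback],
)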

View file

@@ -13,17 +13,18 @@ from globalVars import HEIGHT, LOGGER, NUM_COLOUR_CHANNELS, NUM_PRESET_SPEEDS, P
#policy = Policy('mixed_float16')
#tf.keras.mixed_precision.set_global_policy(policy)
def combine_batch(frame, crf, speed):
# Preprocess the compressed frame (target)
processed_frame = preprocess_frame(frame)
def combine_batch(frame, crf, speed, include_controls=True, resize=True):
processed_frame = preprocess_frame(frame, resize)
height, width, _ = processed_frame.shape
crf_array = np.full((HEIGHT, WIDTH, 1), crf) # Note the added dimension
speed_array = np.full((HEIGHT, WIDTH, 1), speed) # Note the added dimension
combined = [processed_frame]
if include_controls:
crf_array = np.full((height, width, 1), crf)
speed_array = np.full((height, width, 1), speed)
combined.extend([crf_array, speed_array])
# Combine the frames with the CRF and SPEED images
combined = np.concatenate([processed_frame, crf_array, speed_array], axis=-1)
return np.concatenate(combined, axis=-1)
return combined
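A sketch of the reworked combine_batch, which now appends the CRF and speed planes only when include_controls is set and can skip resizing entirely; preprocess_frame is the helper changed above:

import numpy as np

def combine_batch(frame, crf, speed, include_controls=True, resize=True):
    # Colour frame scaled to [0, 1], optionally resized to the training resolution.
    processed_frame = preprocess_frame(frame, resize)
    height, width, _ = processed_frame.shape

    combined = [processed_frame]
    if include_controls:
        # Broadcast the scalar CRF and speed values into full-size single-channel planes.
        crf_array = np.full((height, width, 1), crf)
        speed_array = np.full((height, width, 1), speed)
        combined.extend([crf_array, speed_array])

    # 3 channels for targets, 5 (colour + CRF + speed) for model inputs.
    return np.concatenate(combined, axis=-1)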
def data_generator(videos, batch_size):
# Infinite loop to keep generating batches
@@ -53,8 +54,10 @@ def data_generator(videos, batch_size):
if not ret_compressed or not ret_uncompressed:
break
compressed_combined = combine_batch(compressed_frame, CRF, SPEED)
# Target data
compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
# Input data
uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
# Append processed frames to batches
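Inside the generator each training pair is therefore a 5-channel input (colour plus CRF and speed planes) and a 3-channel target; a sketch of the batching step, with the batch list names assumed:

# Input: colour + control planes; target: colour only.
uncompressed_combined = combine_batch(
    uncompressed_frame, 0,
    scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)

uncompressed_batch.append(uncompressed_combined)
compressed_batch.append(compressed_combined)

if len(uncompressed_batch) == batch_size:
    # Yield (inputs, targets) as arrays and start a new batch.
    yield np.array(uncompressed_batch), np.array(compressed_batch)
    uncompressed_batch, compressed_batch = [], []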
@@ -80,8 +83,8 @@ class VideoCompressionModel(tf.keras.Model):
super(VideoCompressionModel, self).__init__()
LOGGER.debug("Initializing VideoCompressionModel.")
# Input shape (includes channels for edges and histogram)
input_shape_with_histogram = (HEIGHT, WIDTH, NUM_COLOUR_CHANNELS + 2)
# Input shape (includes channels for CRF and SPEED_PRESET)
input_shape_with_histogram = (None, None, NUM_COLOUR_CHANNELS + 2)
# Encoder part of the model
self.encoder = tf.keras.Sequential([
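Declaring the input shape as (None, None, NUM_COLOUR_CHANNELS + 2) only works if every layer is convolutional, which is what lets predict_frame pass resize=False and feed frames at their native resolution; a rough sketch of such an encoder/decoder pair, since the actual layer stack is not shown in this hunk:

import tensorflow as tf

NUM_COLOUR_CHANNELS = 3
input_shape_with_controls = (None, None, NUM_COLOUR_CHANNELS + 2)  # hypothetical name

# Only convolutional layers, so spatial dimensions stay flexible.
encoder = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=input_shape_with_controls),
    tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
    tf.keras.layers.Conv2D(64, 3, strides=2, padding='same', activation='relu'),
])
decoder = tf.keras.Sequential([
    tf.keras.layers.Conv2DTranspose(32, 3, strides=2, padding='same', activation='relu'),
    tf.keras.layers.Conv2D(NUM_COLOUR_CHANNELS, 3, padding='same', activation='sigmoid'),
])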