update
This commit is contained in:
parent
3ea1568ad3
commit
7787d0584e
4 changed files with 35 additions and 25 deletions
|
@ -22,7 +22,7 @@ SPEED = PRESET_SPEED_CATEGORIES.index("ultrafast")
|
|||
MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr})
|
||||
|
||||
# Load the uncompressed video
|
||||
UNCOMPRESSED_VIDEO_FILE = 'test_data/training_video.mkv'
|
||||
UNCOMPRESSED_VIDEO_FILE = 'test_data/B4_t02.mkv'
|
||||
|
||||
def load_frame_from_video(video_file, frame_num):
|
||||
cap = cv2.VideoCapture(video_file)
|
||||
|
@ -41,16 +41,15 @@ def predict_frame(uncompressed_frame):
|
|||
scaled_crf = scale_crf(CRF)
|
||||
scaled_speed = scale_speed_preset(SPEED)
|
||||
|
||||
frame = combine_batch(uncompressed_frame, scaled_crf, scaled_speed)
|
||||
frame = combine_batch(uncompressed_frame, scaled_crf, scaled_speed, resize=False)
|
||||
|
||||
compressed_frame = MODEL.predict([np.expand_dims(frame, axis=0)])[0]
|
||||
|
||||
compressed_frame = compressed_frame[:, :, :3] # Keep only the first 3 channels (BGR)
|
||||
|
||||
display_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
|
||||
compressed_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
|
||||
|
||||
cv2.imshow("comp", display_frame)
|
||||
cv2.waitKey(1)
|
||||
cv2.imshow("comp", compressed_frame)
|
||||
|
||||
return compressed_frame
|
||||
|
||||
|
@ -75,9 +74,9 @@ for i in range(total_frames):
|
|||
|
||||
compressed_frame = cv2.resize(compressed_frame, (width, height))
|
||||
|
||||
compressed_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
|
||||
#compressed_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
|
||||
|
||||
compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_RGB2BGR)
|
||||
#compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_RGB2BGR)
|
||||
|
||||
out.write(compressed_frame)
|
||||
|
||||
|
|
|
@ -54,11 +54,11 @@ def psnr(y_true, y_pred):
|
|||
return 10.0 * K.log((max_pixel ** 2) / (K.mean(K.square(y_pred - y_true)))) / K.log(10.0)
|
||||
|
||||
|
||||
def preprocess_frame(frame):
|
||||
def preprocess_frame(frame, resize=True):
|
||||
# Preprocesses a single frame, resizing it if needed
|
||||
|
||||
# Check frame dimensions and resize if necessary
|
||||
if frame.shape[:2] != (HEIGHT, WIDTH):
|
||||
if resize and frame.shape[:2] != (HEIGHT, WIDTH):
|
||||
frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
# Scale frame to [0, 1]
|
||||
|
|
|
@ -32,15 +32,20 @@ from video_compression_model import VideoCompressionModel, data_generator
|
|||
from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
|
||||
|
||||
# Constants
|
||||
BATCH_SIZE = 16
|
||||
BATCH_SIZE = 25
|
||||
EPOCHS = 100
|
||||
LEARNING_RATE = 0.001
|
||||
DECAY_STEPS = 40
|
||||
LEARNING_RATE = 0.0001
|
||||
DECAY_STEPS = 160
|
||||
DECAY_RATE = 0.9
|
||||
MODEL_SAVE_FILE = "models/model.tf"
|
||||
MODEL_CHECKPOINT_DIR = "checkpoints"
|
||||
EARLY_STOP = 5
|
||||
|
||||
class GarbageCollectorCallback(Callback):
    """Training callback that runs a garbage-collection pass after each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Log a debug message, then call ``gc.collect()``.

        Args:
            epoch: Epoch argument supplied by the training loop (unused here).
            logs: Optional logs argument supplied by the training loop
                (unused here).
        """
        # The message has no placeholders, so a plain string literal is used
        # instead of an f-string.
        LOGGER.debug("Collecting garbage")
        gc.collect()
|
||||
|
||||
def save_model(model):
|
||||
try:
|
||||
LOGGER.debug("Attempting to save the model.")
|
||||
|
@ -145,6 +150,9 @@ def main():
|
|||
)
|
||||
early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True)
|
||||
|
||||
# Custom garbage collection callback
|
||||
gc_callback = GarbageCollectorCallback()
|
||||
|
||||
# Calculate steps per epoch for training and validation
|
||||
if MAX_FRAMES <= 0:
|
||||
average_frames_per_video = 2880 # Given 2 minutes @ 24 fps
|
||||
|
@ -164,7 +172,7 @@ def main():
|
|||
steps_per_epoch=steps_per_epoch_train,
|
||||
validation_data=data_generator(validation_videos, BATCH_SIZE), # Add validation data here
|
||||
validation_steps=steps_per_epoch_validation, # Add validation steps here
|
||||
callbacks=[early_stop, checkpoint_callback]
|
||||
callbacks=[early_stop, checkpoint_callback, gc_callback]
|
||||
)
|
||||
LOGGER.info("Model training completed.")
|
||||
|
||||
|
|
|
@ -13,17 +13,18 @@ from globalVars import HEIGHT, LOGGER, NUM_COLOUR_CHANNELS, NUM_PRESET_SPEEDS, P
|
|||
#policy = Policy('mixed_float16')
|
||||
#tf.keras.mixed_precision.set_global_policy(policy)
|
||||
|
||||
def combine_batch(frame, crf, speed, include_controls=True, resize=True):
    """Preprocess a frame and optionally append constant CRF/speed channels.

    Args:
        frame: Frame handed to ``preprocess_frame``.
        crf: Value used to fill the CRF control channel.
        speed: Value used to fill the speed-preset control channel.
        include_controls: When true (the default), one CRF channel and one
            speed channel are appended after the frame's own channels. When
            false, only the preprocessed frame is concatenated.
        resize: Forwarded to ``preprocess_frame``.

    Returns:
        The preprocessed frame, plus the two control channels when requested,
        concatenated along the last axis.
    """
    processed_frame = preprocess_frame(frame, resize)

    # Size the control planes from the processed frame rather than from the
    # HEIGHT/WIDTH constants so they still line up when resizing is disabled.
    height, width, _ = processed_frame.shape

    combined = [processed_frame]
    if include_controls:
        crf_array = np.full((height, width, 1), crf)
        speed_array = np.full((height, width, 1), speed)
        combined.extend([crf_array, speed_array])

    return np.concatenate(combined, axis=-1)
|
||||
|
||||
def data_generator(videos, batch_size):
|
||||
# Infinite loop to keep generating batches
|
||||
|
@ -53,8 +54,10 @@ def data_generator(videos, batch_size):
|
|||
if not ret_compressed or not ret_uncompressed:
|
||||
break
|
||||
|
||||
compressed_combined = combine_batch(compressed_frame, CRF, SPEED)
|
||||
# Target data
|
||||
compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
|
||||
|
||||
# Input data
|
||||
uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
|
||||
|
||||
# Append processed frames to batches
|
||||
|
@ -80,8 +83,8 @@ class VideoCompressionModel(tf.keras.Model):
|
|||
super(VideoCompressionModel, self).__init__()
|
||||
LOGGER.debug("Initializing VideoCompressionModel.")
|
||||
|
||||
# Input shape (includes channels for edges and histogram)
|
||||
input_shape_with_histogram = (HEIGHT, WIDTH, NUM_COLOUR_CHANNELS + 2)
|
||||
# Input shape (includes channels for CRF and SPEED_PRESET)
|
||||
input_shape_with_histogram = (None, None, NUM_COLOUR_CHANNELS + 2)
|
||||
|
||||
# Encoder part of the model
|
||||
self.encoder = tf.keras.Sequential([
|
||||
|
|
Reference in a new issue