Model now uses a TensorFlow dataset generator

Jordon Brooks 2023-08-18 00:42:17 +01:00
parent ba6c132c67
commit f06d3ae504
2 changed files with 84 additions and 48 deletions
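
Context for the diffs below: the commit replaces a hand-rolled Python batch generator (previously passed straight to model.fit) with a tf.data pipeline built via tf.data.Dataset.from_generator. A minimal sketch of that pattern, with an illustrative generator name and shapes that are not taken from this repository:

import numpy as np
import tensorflow as tf

def toy_frame_generator():
    # Stand-in for the new frame_generator(): yields one (input, target)
    # frame pair at a time instead of pre-built batches.
    for _ in range(10):
        frame = np.random.rand(64, 64, 3).astype(np.float32)
        yield frame, frame

dataset = tf.data.Dataset.from_generator(
    toy_frame_generator,
    output_signature=(
        tf.TensorSpec(shape=(64, 64, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(64, 64, 3), dtype=tf.float32),
    ),
)
# Shuffling, batching and prefetching move into the pipeline itself,
# mirroring create_dataset() in this commit.
dataset = dataset.shuffle(100).batch(4).prefetch(tf.data.experimental.AUTOTUNE)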

File 1 (training script)

@@ -27,7 +27,7 @@ if gpus:
         print(e)
 
-from video_compression_model import VideoCompressionModel, data_generator
+from video_compression_model import VideoCompressionModel, create_dataset
 
 from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
 
@@ -43,7 +43,7 @@ EARLY_STOP = 5
 
 class GarbageCollectorCallback(Callback):
     def on_epoch_end(self, epoch, logs=None):
-        LOGGER.debug(f"Collecting garbage")
+        LOGGER.debug(f"GC")
         gc.collect()
 
 def save_model(model):
@@ -120,6 +120,10 @@ def main():
     split_index = int(0.8 * len(all_videos))
     training_videos = all_videos[:split_index]
     validation_videos = all_videos[split_index:]
+
+    training_dataset = create_dataset(training_videos, BATCH_SIZE, MAX_FRAMES)
+    validation_dataset = create_dataset(validation_videos, BATCH_SIZE, MAX_FRAMES)
 
     if args.continue_training:
         model = tf.keras.models.load_model(args.continue_training)
@@ -154,26 +158,24 @@ def main():
     gc_callback = GarbageCollectorCallback()
 
     # Calculate steps per epoch for training and validation
-    if MAX_FRAMES <= 0:
-        average_frames_per_video = 2880  # Given 2 minutes @ 24 fps
-    else:
-        average_frames_per_video = max(MAX_FRAMES, 0)
-    total_frames_train = average_frames_per_video * len(training_videos)
-    total_frames_validation = average_frames_per_video * len(validation_videos)
-    steps_per_epoch_train = total_frames_train // BATCH_SIZE
-    steps_per_epoch_validation = total_frames_validation // BATCH_SIZE
+    #if MAX_FRAMES <= 0:
+    #    average_frames_per_video = 2880  # Given 2 minutes @ 24 fps
+    #else:
+    #    average_frames_per_video = max(MAX_FRAMES, 0)
+    #total_frames_train = average_frames_per_video * len(training_videos)
+    #total_frames_validation = average_frames_per_video * len(validation_videos)
+    #steps_per_epoch_train = total_frames_train // BATCH_SIZE
+    #steps_per_epoch_validation = total_frames_validation // BATCH_SIZE
 
     gc.collect()
 
     # Train the model
     LOGGER.info("Starting model training.")
     model.fit(
-        data_generator(training_videos, BATCH_SIZE),
+        training_dataset,
         epochs=EPOCHS,
-        steps_per_epoch=steps_per_epoch_train,
-        validation_data=data_generator(validation_videos, BATCH_SIZE),  # Add validation data here
-        validation_steps=steps_per_epoch_validation,  # Add validation steps here
+        validation_data=validation_dataset,  # Add validation data here
         callbacks=[early_stop, checkpoint_callback, gc_callback]
     )
     LOGGER.info("Model training completed.")

File 2 (video_compression_model.py)

@@ -28,52 +28,86 @@ def combine_batch(frame, crf, speed, include_controls=True, resize=True):
     return np.concatenate(combined, axis=-1)
 
-def data_generator(videos, batch_size):
-    base_dir = os.path.dirname("test_data/validation/validation.json")
-
-    while True:
-        # Lists to store the processed frames
-        compressed_frame_batch = []    # Input data (Target)
-        uncompressed_frame_batch = []  # Target data (Training)
-
-        # Get a list of video capture objects for all videos
-        caps_compressed = [cv2.VideoCapture(os.path.join(base_dir, video["compressed_video_file"])) for video in videos]
-        caps_uncompressed = [cv2.VideoCapture(os.path.join(base_dir, video["original_video_file"])) for video in videos]
-
-        # As long as any video can provide frames, keep running
-        while any(cap.isOpened() for cap in caps_compressed):
-            for idx, (cap_compressed, cap_uncompressed) in enumerate(zip(caps_compressed, caps_uncompressed)):
-                #print(f"(Video Change) Processing video {idx}")  # Print statement to indicate video change
-                ret_compressed, compressed_frame = cap_compressed.read()
-                ret_uncompressed, uncompressed_frame = cap_uncompressed.read()
-
-                if not ret_compressed or not ret_uncompressed:
-                    continue
-
-                CRF = scale_crf(videos[idx]["crf"])
-                SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(videos[idx]["preset_speed"]))
-
-                compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
-                uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
-
-                compressed_frame_batch.append(compressed_combined)
-                uncompressed_frame_batch.append(uncompressed_combined)
-
-                if len(compressed_frame_batch) == batch_size:
-                    yield (np.array(uncompressed_frame_batch), np.array(compressed_frame_batch))
-                    compressed_frame_batch.clear()
-                    uncompressed_frame_batch.clear()
-
-        # Close all video captures at the end
-        for cap in caps_compressed + caps_uncompressed:
-            cap.release()
-
-        cv2.destroyAllWindows()
-
-        # If there are frames left that don't fill a whole batch, send them anyway
-        if len(compressed_frame_batch) > 0:
-            yield (np.array(uncompressed_frame_batch), np.array(compressed_frame_batch))
+def process_video(video):
+    base_dir = os.path.dirname("test_data/validation/validation.json")
+    cap_compressed = cv2.VideoCapture(os.path.join(base_dir, video["compressed_video_file"]))
+    cap_uncompressed = cv2.VideoCapture(os.path.join(base_dir, video["original_video_file"]))
+
+    compressed_frames = []
+    uncompressed_frames = []
+
+    while True:
+        ret_compressed, compressed_frame = cap_compressed.read()
+        ret_uncompressed, uncompressed_frame = cap_uncompressed.read()
+
+        if not ret_compressed or not ret_uncompressed:
+            break
+
+        CRF = scale_crf(video["crf"])
+        SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(video["preset_speed"]))
+
+        compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
+        uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
+
+        compressed_frames.append(compressed_combined)
+        uncompressed_frames.append(uncompressed_combined)
+
+    cap_compressed.release()
+    cap_uncompressed.release()
+
+    return uncompressed_frames, compressed_frames
+
+
+def frame_generator(videos, max_frames=None):
+    base_dir = "test_data/validation/"
+    for video in videos:
+        cap_compressed = cv2.VideoCapture(os.path.join(base_dir, video["compressed_video_file"]))
+        cap_uncompressed = cv2.VideoCapture(os.path.join(base_dir, video["original_video_file"]))
+
+        frame_count = 0
+        while True:
+            ret_compressed, compressed_frame = cap_compressed.read()
+            ret_uncompressed, uncompressed_frame = cap_uncompressed.read()
+
+            if not ret_compressed or not ret_uncompressed:
+                break
+
+            CRF = scale_crf(video["crf"])
+            SPEED = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(video["preset_speed"]))
+
+            compressed_combined = combine_batch(compressed_frame, CRF, SPEED, include_controls=False)
+            uncompressed_combined = combine_batch(uncompressed_frame, 0, scale_speed_preset(PRESET_SPEED_CATEGORIES.index("veryslow")))
+
+            yield uncompressed_combined, compressed_combined
+
+            frame_count += 1
+            if max_frames is not None and frame_count >= max_frames:
+                break
+
+        cap_compressed.release()
+        cap_uncompressed.release()
+
+
+def create_dataset(videos, batch_size, max_frames=None):
+    # Determine the output signature by processing a single video to obtain its shape
+    video_generator_instance = frame_generator(videos, max_frames)
+    sample_uncompressed, sample_compressed = next(video_generator_instance)
+    output_signature = (
+        tf.TensorSpec(shape=tf.shape(sample_uncompressed), dtype=tf.float32),
+        tf.TensorSpec(shape=tf.shape(sample_compressed), dtype=tf.float32)
+    )
+
+    dataset = tf.data.Dataset.from_generator(
+        lambda: frame_generator(videos, max_frames),  # Include max_frames argument through lambda
+        output_signature=output_signature
+    )
+
+    dataset = dataset.shuffle(100).batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
+
+    return dataset
 
 class VideoCompressionModel(tf.keras.Model):
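
Two small observations on the new video_compression_model.py code. First, process_video is added but nothing in this diff calls it; create_dataset only uses frame_generator. Second, create_dataset builds its output_signature from tf.shape(sample), which returns a tensor; a TensorSpec conventionally takes a static shape, and since combine_batch returns NumPy arrays, the plain .shape tuple provides that directly. A hedged alternative sketch of just the signature probe (same variable names as the committed code):

# Probe one sample and use its static NumPy shape for the TensorSpec,
# rather than the tensor returned by tf.shape().
sample_uncompressed, sample_compressed = next(frame_generator(videos, max_frames))
output_signature = (
    tf.TensorSpec(shape=sample_uncompressed.shape, dtype=tf.float32),
    tf.TensorSpec(shape=sample_compressed.shape, dtype=tf.float32),
)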