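"""Training script for the video compression model.

Loads paired compressed/uncompressed frames from the videos listed in JSON
manifests, then trains VideoCompressionModel to reconstruct the uncompressed
frame from the compressed one (MSE loss).
"""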
import argparse
import json
import os

import cv2
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

from video_compression_model import PRESET_SPEED_CATEGORIES, VideoCompressionModel, VideoDataGenerator

print("GPUs Detected:", tf.config.list_physical_devices('GPU'))

# Constants
BATCH_SIZE = 4
EPOCHS = 100
LEARNING_RATE = 1e-6
TRAIN_SAMPLES = 500
MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints"
CONTINUE_TRAINING = None


def load_list(list_path):
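    """Read a JSON manifest describing the training/validation videos."""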
    with open(list_path, "r") as json_file:
        video_details_list = json.load(json_file)
    return video_details_list


def load_video_from_list(list_path, samples=TRAIN_SAMPLES):
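    """Decode paired (uncompressed, compressed) frames from each listed video.

    The requested sample budget is split evenly across the videos in the
    manifest.
    """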
    details_list = load_list(list_path)
    all_details = []

    num_videos = len(details_list)
    frames_per_video = int(samples / num_videos)

    print(f"Loading {frames_per_video} frames across {num_videos} videos")

    for video_details in details_list:
        video_file = video_details["video_file"]
        uncompressed_video_file = video_details["uncompressed_video_file"]
        # Normalise CRF to [0, 1] against the maximum value of 63 used here.
        crf = video_details['crf'] / 63.0
        preset_speed = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed'])
        video_details['preset_speed'] = preset_speed

        frames = []
        frames_compressed = []

        cap = cv2.VideoCapture(os.path.join("test_data/", video_file))
        cap_uncompressed = cv2.VideoCapture(os.path.join("test_data/", uncompressed_video_file))

        for _ in range(frames_per_video):
            ret, frame_compressed = cap.read()
            ret_uncompressed, frame = cap_uncompressed.read()

            if not ret or not ret_uncompressed:
                # One of the videos ran out of frames; further reads would also fail.
                break

            # OpenCV decodes to BGR; convert to RGB for the model.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_compressed = cv2.cvtColor(frame_compressed, cv2.COLOR_BGR2RGB)

            frames.append(preprocess(frame))
            frames_compressed.append(preprocess(frame_compressed))

        for uncompressed_frame, compressed_frame in zip(frames, frames_compressed):
            all_details.append({
                "frame": uncompressed_frame,
                "compressed_frame": compressed_frame,
                "crf": crf,
                "preset_speed": preset_speed,
                "video_file": video_file
            })

        cap.release()
        cap_uncompressed.release()

    return all_details


def preprocess(frame):
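    """Scale 8-bit pixel values to the [0, 1] range."""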
    return frame / 255.0


def save_model(model):
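    """Save the trained model in TensorFlow's SavedModel format."""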
    os.makedirs("models", exist_ok=True)
    model.save(MODEL_SAVE_FILE, save_format='tf')
    print("Model saved successfully!")


def main():
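    """Parse CLI arguments, build the data pipeline, and train the model."""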
    global BATCH_SIZE, EPOCHS, TRAIN_SAMPLES, LEARNING_RATE, CONTINUE_TRAINING

    # Argument parsing
    parser = argparse.ArgumentParser(description="Train the video compression model.")
    parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
    parser.add_argument('-e', '--epochs', type=int, default=EPOCHS, help='Number of epochs for training.')
    parser.add_argument('-s', '--training_samples', type=int, default=TRAIN_SAMPLES, help='Number of training samples.')
    parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
    parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None,
                        help='Path to the saved model to continue training. '
                             'If used without a value, defaults to MODEL_SAVE_FILE.')

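    # Example invocation (the script filename is assumed; adjust to the actual name):
    #   python train.py -b 8 -e 50 -s 1000 -l 1e-5 -c models/model.tf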
    args = parser.parse_args()

    # Use the parsed arguments in the rest of the script
    BATCH_SIZE = args.batch_size
    EPOCHS = args.epochs
    TRAIN_SAMPLES = args.training_samples
    LEARNING_RATE = args.learning_rate
    CONTINUE_TRAINING = args.continue_training

print("Training configuration:")
|
|
print(f"Batch size: {BATCH_SIZE}")
|
|
print(f"Epochs: {EPOCHS}")
|
|
print(f"Training samples: {TRAIN_SAMPLES}")
|
|
print(f"Learning rate: {LEARNING_RATE}")
|
|
print(f"Continue training from: {CONTINUE_TRAINING}")
|
|
|
|
    all_video_details_train = load_video_from_list("test_data/training.json")
    # Use half as many samples for validation; integer division keeps the count whole.
    all_video_details_val = load_video_from_list("test_data/validation.json", TRAIN_SAMPLES // 2)

    train_generator = VideoDataGenerator(all_video_details_train, BATCH_SIZE)
    val_generator = VideoDataGenerator(all_video_details_val, BATCH_SIZE)

    if CONTINUE_TRAINING:
        print("Loading model:", CONTINUE_TRAINING)
        model = tf.keras.models.load_model(CONTINUE_TRAINING)  # Load from the specified file
    else:
        model = VideoCompressionModel()

    # Define the optimizer with a specific learning rate
    optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

    os.makedirs(MODEL_CHECKPOINT_DIR, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        # A non-.h5/.keras filepath makes Keras write checkpoints in the SavedModel
        # format; ModelCheckpoint does not accept a save_format argument, so passing
        # one would raise an error.
        filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch-{epoch:02d}.tf"),
        save_weights_only=False,
        save_best_only=False,
        verbose=1
    )

    # tf.config.run_functions_eagerly(True)

    model.compile(loss='mean_squared_error', optimizer=optimizer)
    early_stop = EarlyStopping(monitor='val_loss', patience=5, verbose=1, restore_best_weights=True)

print("\nTraining the model...")
|
|
model.fit(
|
|
train_generator,
|
|
steps_per_epoch=len(train_generator),
|
|
epochs=EPOCHS,
|
|
validation_data=val_generator,
|
|
validation_steps=len(val_generator),
|
|
callbacks=[early_stop, checkpoint_callback]
|
|
)
|
|
print("\nTraining completed!")
|
|
|
|
    save_model(model)


if __name__ == "__main__":
    main()