209 lines
No EOL
7.4 KiB
Python
209 lines
No EOL
7.4 KiB
Python
import argparse
import json
import os

# Set before TensorFlow is imported so the native log-level setting is in place when it loads.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

from global_train import LOGGER
from train_model_V2 import VideoCompressionModel
|
|
|
|
# Constants
# Training hyper-parameters; main() uses these as command-line defaults and
# overwrites them with whatever was passed on the command line.
BATCH_SIZE = 16  # frames per batch (-b/--batch_size)
EPOCHS = 5  # training epochs (-e/--epochs)
LEARNING_RATE = 0.01  # Adam learning rate (-l/--learning_rate)
# main() reads this as the default for -s/--training_samples; it previously had
# no definition, which made main() fail with NameError. None means "not set".
TRAIN_SAMPLES = None
MODEL_SAVE_FILE = "models/model.tf"  # where save_model() writes the final model
MODEL_CHECKPOINT_DIR = "checkpoints"  # directory for the per-epoch checkpoints
EARLY_STOP = 10  # EarlyStopping patience, in epochs

NUM_CHANNELS = 3  # NOTE(review): not referenced anywhere else in this file
# Frame size, in pixels, that data_generator() resizes every frame to.
WIDTH = 640
HEIGHT = 360
|
|
|
|
def save_model(model):
    """
    Save the model to MODEL_SAVE_FILE using the 'tf' save format.

    Args:
    - model: Object exposing a Keras-style `save(path, save_format=...)` method.

    Raises:
    - Exception: Any error raised while creating the directory or saving is
      logged and then re-raised unchanged.
    """
    try:
        LOGGER.debug("Attempting to save the model.")
        # Create the directory that MODEL_SAVE_FILE actually points into, so the
        # two cannot drift apart if the constant is changed.
        save_dir = os.path.dirname(MODEL_SAVE_FILE)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        model.save(MODEL_SAVE_FILE, save_format='tf')
        LOGGER.info("Model saved successfully!")
    except Exception as e:
        LOGGER.error(f"Error saving the model: {e}")
        raise
|
|
|
|
|
|
def extract_edge_features(frame):
    """
    Build an edge map for a frame with the Canny detector.

    Args:
    - frame (ndarray): Image frame.

    Returns:
    - ndarray: float32 edge map, i.e. the Canny output divided by 255.
    """
    # Hysteresis thresholds: 100 (lower) and 200 (upper).
    canny_output = cv2.Canny(frame, threshold1=100, threshold2=200)
    edge_map = canny_output.astype(np.float32)
    return edge_map / 255.0
|
|
|
|
def extract_histogram_features(frame, bins=64):
    """
    Summarize the value distribution of a frame as a normalized histogram.

    Args:
    - frame (ndarray): Image frame; every element is counted, whatever its shape.
    - bins (int): Number of equal-width bins spanning the range [0, 255].

    Returns:
    - ndarray: float32 vector with each bin's count divided by the number of
      elements in the frame.
    """
    pixel_values = frame.ravel()
    counts = np.histogram(pixel_values, bins=bins, range=[0, 255])[0]
    fractions = counts.astype(np.float32)
    return fractions / frame.size
|
|
|
|
|
|
def load_video_metadata(list_path):
    """
    Load video metadata from a JSON file.

    Args:
    - list_path (str): Path to the JSON file containing video metadata.

    Returns:
    - The parsed JSON content. Callers in this file treat it as a list of
      dictionaries, one per video.

    Raises:
    - FileNotFoundError: If `list_path` does not exist (logged, then re-raised).
    - json.JSONDecodeError: If the file is not valid JSON (logged, then re-raised).
    """
    LOGGER.trace(f"Entering: load_video_metadata({list_path})")
    try:
        # Read as UTF-8 explicitly instead of relying on the platform's
        # locale-dependent default encoding.
        with open(list_path, "r", encoding="utf-8") as json_file:
            metadata = json.load(json_file)
    except FileNotFoundError:
        LOGGER.error(f"Metadata file {list_path} not found.")
        raise
    except json.JSONDecodeError:
        LOGGER.error(f"Error decoding JSON from {list_path}.")
        raise

    LOGGER.trace(f"load_video_metadata returning: {metadata}")
    return metadata
|
|
|
|
def _frame_to_sample(frame):
    """Convert one decoded frame into a (feature, target) pair for training."""
    # Check frame dimensions and resize if necessary
    if frame.shape[:2] != (HEIGHT, WIDTH):
        frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_NEAREST)

    # Extract features
    edge_feature = extract_edge_features(frame)
    histogram_feature = extract_histogram_features(frame)
    # Broadcast the histogram's mean into an image-shaped plane so it can be
    # stacked with the edge map as a second feature channel.
    histogram_feature_image = np.full(
        (HEIGHT, WIDTH), histogram_feature.mean(), dtype=np.float32
    )
    combined_feature = np.stack([edge_feature, histogram_feature_image], axis=-1)

    # Assuming the frame is uint8, scale to [0, 1]. float32 halves the batch
    # memory compared with the float64 a plain division would produce.
    compressed_frame = frame.astype(np.float32) / 255.0
    return combined_feature, compressed_frame


def data_generator(videos, batch_size, base_dir="test_data/validation"):
    """
    Endlessly yield (features, targets) batches decoded from the given videos.

    Each video is read frame by frame; every frame is resized to WIDTH x HEIGHT
    if needed and turned into a 2-channel feature map (edge map + histogram
    plane) and a target frame scaled by 1/255. Batches never mix frames from
    different videos: leftover frames at the end of a video are yielded as a
    smaller batch.

    Args:
    - videos (list): Dictionaries that each contain a "compressed_video_file" entry.
    - batch_size (int): Number of frames per full batch.
    - base_dir (str): Directory that the "compressed_video_file" paths are
      relative to. Defaults to the directory used by the original implementation.

    Yields:
    - tuple: (features, targets) as float32 arrays; features have shape
      (n, HEIGHT, WIDTH, 2) and targets have shape (n, HEIGHT, WIDTH, channels).

    Raises:
    - RuntimeError: If a full pass over `videos` produces no frames at all
      (empty list or no readable video), instead of spinning forever.
    """
    while True:
        produced_any = False

        for video_details in videos:
            video_path = os.path.join(base_dir, video_details["compressed_video_file"])
            cap = cv2.VideoCapture(video_path)

            feature_batch = []
            compressed_frame_batch = []

            # try/finally so the capture is released even when the consumer
            # closes the generator while it is suspended at a yield.
            try:
                if not cap.isOpened():
                    LOGGER.error(f"Could not open video file {video_path}; skipping it.")

                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break

                    combined_feature, compressed_frame = _frame_to_sample(frame)
                    feature_batch.append(combined_feature)
                    compressed_frame_batch.append(compressed_frame)

                    if len(feature_batch) == batch_size:
                        produced_any = True
                        yield (np.array(feature_batch), np.array(compressed_frame_batch))
                        feature_batch = []
                        compressed_frame_batch = []
            finally:
                cap.release()

            # If there are frames left that don't fill a whole batch, send them anyway
            if feature_batch:
                produced_any = True
                yield (np.array(feature_batch), np.array(compressed_frame_batch))

        if not produced_any:
            raise RuntimeError(
                "data_generator produced no batches: the video list is empty "
                "or none of the videos could be read."
            )
|
|
|
|
|
|
def main():
    """
    Parse command-line options, train the video compression model and save it.

    Command-line values overwrite the module-level BATCH_SIZE, EPOCHS,
    TRAIN_SAMPLES and LEARNING_RATE. Video metadata is loaded from
    "test_data/validation/validation.json" and split 80/20, in file order,
    into training and validation sets. The model is trained with Adam and
    mean-squared-error loss, checkpointed after every epoch, and finally saved
    through save_model().

    Raises:
    - ValueError: If the metadata does not leave at least one video in both
      the training and the validation split.
    """
    global BATCH_SIZE, EPOCHS, TRAIN_SAMPLES, LEARNING_RATE, MODEL_SAVE_FILE

    # Argument parsing
    parser = argparse.ArgumentParser(description="Train the video compression model.")
    parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
    parser.add_argument('-e', '--epochs', type=int, default=EPOCHS, help='Number of epochs for training.')
    # TRAIN_SAMPLES may have no module-level definition; look it up defensively
    # so a missing constant means "not set" (None) instead of a NameError.
    parser.add_argument('-s', '--training_samples', type=int, default=globals().get('TRAIN_SAMPLES'), help='Number of training samples.')
    parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
    parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')

    args = parser.parse_args()

    # A non-positive batch size would later divide by zero when computing steps.
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer.")

    BATCH_SIZE = args.batch_size
    EPOCHS = args.epochs
    # NOTE(review): TRAIN_SAMPLES is only logged; the number of steps per epoch
    # is derived from the frame-count estimate further down.
    TRAIN_SAMPLES = args.training_samples
    LEARNING_RATE = args.learning_rate

    # Display training configuration
    LOGGER.info("Starting the training with the given configuration.")
    LOGGER.info("Training configuration:")
    LOGGER.info(f"Batch size: {BATCH_SIZE}")
    LOGGER.info(f"Epochs: {EPOCHS}")
    LOGGER.info(f"Training samples: {TRAIN_SAMPLES}")
    LOGGER.info(f"Learning rate: {LEARNING_RATE}")
    # Report what was actually requested rather than always printing MODEL_SAVE_FILE.
    if args.continue_training:
        LOGGER.info(f"Continue training from: {args.continue_training}")
    else:
        LOGGER.info("Continue training from: (none, starting with a new model)")

    LOGGER.debug(f"Max video resolution: {WIDTH}x{HEIGHT}")
    LOGGER.trace("Hello, World!")

    # Load all video metadata
    all_videos = load_video_metadata("test_data/validation/validation.json")

    # Split into training and validation
    split_index = int(0.8 * len(all_videos))
    training_videos = all_videos[:split_index]
    validation_videos = all_videos[split_index:]

    # An empty split would leave its generator with nothing to yield.
    if not training_videos or not validation_videos:
        raise ValueError(
            f"Need at least one training and one validation video, got "
            f"{len(training_videos)} and {len(validation_videos)} "
            f"from {len(all_videos)} metadata entries."
        )

    # Honour --continue_training: resume from the saved model when requested,
    # otherwise start from a freshly constructed one.
    if args.continue_training:
        model = tf.keras.models.load_model(args.continue_training)
    else:
        model = VideoCompressionModel()

    # Set optimizer and compile the model
    optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
    model.compile(loss='mean_squared_error', optimizer=optimizer)

    # Define checkpoints and early stopping
    os.makedirs(MODEL_CHECKPOINT_DIR, exist_ok=True)
    # `save_format` is not a documented ModelCheckpoint argument, so it is not passed.
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch-{epoch:02d}.tf"),
        save_weights_only=False,
        save_best_only=False,
        verbose=1,
    )
    early_stop = EarlyStopping(monitor='val_loss', patience=EARLY_STOP, verbose=1, restore_best_weights=True)

    # Calculate steps per epoch for training and validation
    average_frames_per_video = 2880  # Given 2 minutes @ 24 fps
    total_frames_train = average_frames_per_video * len(training_videos)
    total_frames_validation = average_frames_per_video * len(validation_videos)
    # Keep at least one step so a very large batch size cannot yield zero steps.
    steps_per_epoch_train = max(1, total_frames_train // BATCH_SIZE)
    steps_per_epoch_validation = max(1, total_frames_validation // BATCH_SIZE)

    # Train the model
    LOGGER.info("Starting model training.")
    model.fit(
        data_generator(training_videos, BATCH_SIZE),
        epochs=EPOCHS,
        steps_per_epoch=steps_per_epoch_train,
        validation_data=data_generator(validation_videos, BATCH_SIZE),
        validation_steps=steps_per_epoch_validation,
        callbacks=[early_stop, checkpoint_callback]
    )
    LOGGER.info("Model training completed.")

    save_model(model)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the training and make sure any failure is
    # written to the log before it propagates and terminates the process.
    try:
        main()
    except Exception as e:
        LOGGER.error(f"Unexpected error during training: {e}")
        # Re-raise so the original traceback and a non-zero exit are preserved.
        raise