update
This commit is contained in:
parent
1d98bc84a2
commit
fde856f3ec
6 changed files with 107 additions and 109 deletions
|
@ -4,12 +4,14 @@ import os
|
|||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from train_model_V2 import VideoCompressionModel
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
|
||||
|
||||
from global_train import LOGGER
|
||||
from video_compression_model import VideoCompressionModel, data_generator
|
||||
|
||||
from globalVars import HEIGHT, WIDTH, LOGGER
|
||||
|
||||
# Constants
|
||||
BATCH_SIZE = 16
|
||||
|
@ -18,10 +20,6 @@ LEARNING_RATE = 0.01
|
|||
MODEL_SAVE_FILE = "models/model.tf"
|
||||
MODEL_CHECKPOINT_DIR = "checkpoints"
|
||||
EARLY_STOP = 10
|
||||
|
||||
NUM_CHANNELS = 3
|
||||
WIDTH = 640
|
||||
HEIGHT = 360
|
||||
|
||||
def save_model(model):
|
||||
try:
|
||||
|
@ -33,34 +31,6 @@ def save_model(model):
|
|||
LOGGER.error(f"Error saving the model: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def extract_edge_features(frame):
|
||||
"""
|
||||
Extract edge features using Canny edge detection.
|
||||
|
||||
Args:
|
||||
- frame (ndarray): Image frame.
|
||||
|
||||
Returns:
|
||||
- ndarray: Edge feature map.
|
||||
"""
|
||||
edges = cv2.Canny(frame, threshold1=100, threshold2=200)
|
||||
return edges.astype(np.float32) / 255.0
|
||||
|
||||
def extract_histogram_features(frame, bins=64):
|
||||
"""
|
||||
Extract histogram features from a frame.
|
||||
|
||||
Args:
|
||||
- frame (ndarray): Image frame.
|
||||
- bins (int): Number of bins for the histogram.
|
||||
|
||||
Returns:
|
||||
- ndarray: Normalized histogram feature vector.
|
||||
"""
|
||||
histogram, _ = np.histogram(frame.flatten(), bins=bins, range=[0, 255])
|
||||
return histogram.astype(np.float32) / frame.size
|
||||
|
||||
|
||||
def load_video_metadata(list_path):
|
||||
"""
|
||||
|
@ -85,57 +55,16 @@ def load_video_metadata(list_path):
|
|||
except json.JSONDecodeError:
|
||||
LOGGER.error(f"Error decoding JSON from {list_path}.")
|
||||
raise
|
||||
|
||||
def data_generator(videos, batch_size):
|
||||
while True:
|
||||
for video_details in videos:
|
||||
video_path = os.path.join(os.path.dirname("test_data/validation/validation.json"), video_details["compressed_video_file"])
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
|
||||
feature_batch = []
|
||||
compressed_frame_batch = []
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
# Check frame dimensions and resize if necessary
|
||||
if frame.shape[:2] != (HEIGHT, WIDTH):
|
||||
frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
# Extract features
|
||||
edge_feature = extract_edge_features(frame)
|
||||
histogram_feature = extract_histogram_features(frame)
|
||||
histogram_feature_image = np.full((HEIGHT, WIDTH), histogram_feature.mean()) # Convert histogram feature to image-like shape
|
||||
combined_feature = np.stack([edge_feature, histogram_feature_image], axis=-1)
|
||||
|
||||
compressed_frame = frame / 255.0 # Assuming the frame is uint8, scale to [0, 1]
|
||||
|
||||
feature_batch.append(combined_feature)
|
||||
compressed_frame_batch.append(compressed_frame)
|
||||
|
||||
if len(feature_batch) == batch_size:
|
||||
yield (np.array(feature_batch), np.array(compressed_frame_batch))
|
||||
feature_batch = []
|
||||
compressed_frame_batch = []
|
||||
|
||||
cap.release()
|
||||
|
||||
# If there are frames left that don't fill a whole batch, send them anyway
|
||||
if len(feature_batch) > 0:
|
||||
yield (np.array(feature_batch), np.array(compressed_frame_batch))
|
||||
|
||||
|
||||
def main():
|
||||
global BATCH_SIZE, EPOCHS, TRAIN_SAMPLES, LEARNING_RATE, MODEL_SAVE_FILE
|
||||
global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE
|
||||
# Argument parsing
|
||||
parser = argparse.ArgumentParser(description="Train the video compression model.")
|
||||
parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
|
||||
parser.add_argument('-e', '--epochs', type=int, default=EPOCHS, help='Number of epochs for training.')
|
||||
parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
|
||||
parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
BATCH_SIZE = args.batch_size
|
||||
|
|
Reference in a new issue