updated
This commit is contained in:
parent
60c6c97071
commit
93ccce5ec1
6 changed files with 46 additions and 34 deletions
|
@ -18,10 +18,12 @@ from global_train import LOGGER
|
|||
BATCH_SIZE = 4
|
||||
EPOCHS = 100
|
||||
LEARNING_RATE = 0.000001
|
||||
TRAIN_SAMPLES = 50
|
||||
TRAIN_SAMPLES = 100
|
||||
MODEL_SAVE_FILE = "models/model.tf"
|
||||
MODEL_CHECKPOINT_DIR = "checkpoints"
|
||||
EARLY_STOP = 10
|
||||
WIDTH = 638
|
||||
HEIGHT = 360
|
||||
|
||||
def load_video_metadata(list_path):
|
||||
LOGGER.trace(f"Entering: load_video_metadata({list_path})")
|
||||
|
@ -67,11 +69,11 @@ def load_video_samples(list_path, samples=TRAIN_SAMPLES):
|
|||
compressed_frames, uncompressed_frames = [], []
|
||||
|
||||
try:
|
||||
cap = cv2.VideoCapture(os.path.join("test_data/", video_file))
|
||||
cap_uncompressed = cv2.VideoCapture(os.path.join("test_data/", uncompressed_video_file))
|
||||
cap = cv2.VideoCapture(os.path.join(os.path.dirname(list_path), video_file))
|
||||
cap_uncompressed = cv2.VideoCapture(os.path.join(os.path.dirname(list_path), uncompressed_video_file))
|
||||
|
||||
if not cap.isOpened() or not cap_uncompressed.isOpened():
|
||||
raise RuntimeError(f"Could not open video files {video_file} or {uncompressed_video_file}")
|
||||
raise RuntimeError(f"Could not open video files {video_file} or {uncompressed_video_file}, searched under: {os.path.dirname(list_path)}")
|
||||
|
||||
for _ in range(frames_per_video):
|
||||
ret, frame_compressed = cap.read()
|
||||
|
@ -79,6 +81,14 @@ def load_video_samples(list_path, samples=TRAIN_SAMPLES):
|
|||
|
||||
if not ret or not ret_uncompressed:
|
||||
continue
|
||||
|
||||
# Check frame dimensions and resize if necessary
|
||||
if frame.shape[:2] != (WIDTH, HEIGHT):
|
||||
LOGGER.warn(f"Resizing video: {video_file}")
|
||||
frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA)
|
||||
if frame_compressed.shape[:2] != (WIDTH, HEIGHT):
|
||||
LOGGER.warn(f"Resizing video: {uncompressed_video_file}")
|
||||
frame_compressed = cv2.resize(frame_compressed, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA)
|
||||
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
frame_compressed = cv2.cvtColor(frame_compressed, cv2.COLOR_BGR2RGB)
|
||||
|
@ -149,8 +159,8 @@ def main():
|
|||
|
||||
# Load training and validation samples
|
||||
LOGGER.debug("Loading training and validation samples.")
|
||||
training_samples = load_video_samples("test_data/training.json")
|
||||
validation_samples = load_video_samples("test_data/validation.json", args.training_samples // 2)
|
||||
training_samples = load_video_samples("test_data/training/training.json")
|
||||
validation_samples = load_video_samples("test_data/validation/validation.json", args.training_samples // 2)
|
||||
|
||||
train_generator = VideoDataGenerator(training_samples, args.batch_size)
|
||||
val_generator = VideoDataGenerator(validation_samples, args.batch_size)
|
||||
|
|
Reference in a new issue