Batches
This commit is contained in:
parent
5bca78e687
commit
9167ff27d4
2 changed files with 46 additions and 31 deletions
|
@ -4,16 +4,16 @@ import numpy as np
|
|||
import cv2
|
||||
import argparse
|
||||
import tensorflow as tf
|
||||
from video_compression_model import NUM_CHANNELS, VideoCompressionModel, PRESET_SPEED_CATEGORIES
|
||||
from video_compression_model import NUM_CHANNELS, VideoCompressionModel, PRESET_SPEED_CATEGORIES, VideoDataGenerator
|
||||
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
|
||||
|
||||
print("GPUs Detected:", tf.config.list_physical_devices('GPU'))
|
||||
|
||||
# Constants
|
||||
BATCH_SIZE = 16
|
||||
EPOCHS = 40
|
||||
LEARNING_RATE = 0.00001
|
||||
TRAIN_SAMPLES = 100
|
||||
BATCH_SIZE = 4
|
||||
EPOCHS = 100
|
||||
LEARNING_RATE = 0.000001
|
||||
TRAIN_SAMPLES = 500
|
||||
MODEL_SAVE_FILE = "models/model.tf"
|
||||
MODEL_CHECKPOINT_DIR = "checkpoints"
|
||||
CONTINUE_TRAINING = None
|
||||
|
@ -23,14 +23,14 @@ def load_list(list_path):
|
|||
video_details_list = json.load(json_file)
|
||||
return video_details_list
|
||||
|
||||
def load_video_from_list(list_path):
|
||||
def load_video_from_list(list_path, samples = TRAIN_SAMPLES):
|
||||
details_list = load_list(list_path)
|
||||
all_details = []
|
||||
|
||||
num_videos = len(details_list)
|
||||
frames_per_video = int(TRAIN_SAMPLES / num_videos)
|
||||
frames_per_video = int(samples / num_videos)
|
||||
|
||||
print(f"Loading {frames_per_video} across {num_videos} videos")
|
||||
print(f"Loading {frames_per_video} frames across {num_videos} videos")
|
||||
|
||||
for video_details in details_list:
|
||||
VIDEO_FILE = video_details["video_file"]
|
||||
|
@ -108,26 +108,10 @@ def main():
|
|||
print(f"Continue training from: {CONTINUE_TRAINING}")
|
||||
|
||||
all_video_details_train = load_video_from_list("test_data/training.json")
|
||||
all_video_details_val = load_video_from_list("test_data/validation.json")
|
||||
all_video_details_val = load_video_from_list("test_data/validation.json", TRAIN_SAMPLES / 2)
|
||||
|
||||
all_train_frames = [video_details["frame"] for video_details in all_video_details_train]
|
||||
all_train_compressed_frames = [video_details["compressed_frame"] for video_details in all_video_details_train]
|
||||
all_val_frames = [video_details["frame"] for video_details in all_video_details_val]
|
||||
all_val_compressed_frames = [video_details["compressed_frame"] for video_details in all_video_details_val]
|
||||
all_crf_train = [video_details['crf'] for video_details in all_video_details_train]
|
||||
all_crf_val = [video_details['crf'] for video_details in all_video_details_val]
|
||||
all_preset_speed_train = [video_details['preset_speed'] for video_details in all_video_details_train]
|
||||
all_preset_speed_val = [video_details['preset_speed'] for video_details in all_video_details_val]
|
||||
|
||||
# Convert lists to numpy arrays
|
||||
all_train_frames = np.array(all_train_frames)
|
||||
all_train_compressed_frames = np.array(all_train_compressed_frames)
|
||||
all_val_frames = np.array(all_val_frames)
|
||||
all_val_compressed_frames = np.array(all_val_compressed_frames)
|
||||
all_crf_train = np.array(all_crf_train)
|
||||
all_crf_val = np.array(all_crf_val)
|
||||
all_preset_speed_train = np.array(all_preset_speed_train)
|
||||
all_preset_speed_val = np.array(all_preset_speed_val)
|
||||
train_generator = VideoDataGenerator(all_video_details_train, BATCH_SIZE)
|
||||
val_generator = VideoDataGenerator(all_video_details_val, BATCH_SIZE)
|
||||
|
||||
if CONTINUE_TRAINING:
|
||||
print("loading model:", CONTINUE_TRAINING)
|
||||
|
@ -154,11 +138,11 @@ def main():
|
|||
|
||||
print("\nTraining the model...")
|
||||
model.fit(
|
||||
{"uncompressed_frame": all_train_frames, "compressed_frame": all_train_compressed_frames, "crf": all_crf_train, "preset_speed": all_preset_speed_train},
|
||||
all_train_compressed_frames, # Target is the compressed frame
|
||||
batch_size=BATCH_SIZE,
|
||||
train_generator,
|
||||
steps_per_epoch=len(train_generator),
|
||||
epochs=EPOCHS,
|
||||
validation_data=({"uncompressed_frame": all_val_frames, "compressed_frame": all_val_compressed_frames, "crf": all_crf_val, "preset_speed": all_preset_speed_val}, all_val_compressed_frames),
|
||||
validation_data=val_generator,
|
||||
validation_steps=len(val_generator),
|
||||
callbacks=[early_stop, checkpoint_callback]
|
||||
)
|
||||
print("\nTraining completed!")
|
||||
|
|
Reference in a new issue