diff --git a/train_model.py b/train_model.py index f08b84e..9127ec8 100644 --- a/train_model.py +++ b/train_model.py @@ -55,6 +55,7 @@ def load_video_from_list(list_path): video_details['preset_speed'] = PRESET_SPEED train_frames, w, h = load_frames_from_video(os.path.join("test_data/", VIDEO_FILE), NUM_FRAMES * TRAIN_SAMPLES) + all_frames.extend(train_frames) all_details.append({ "frames": train_frames, @@ -69,7 +70,7 @@ def load_video_from_list(list_path): def generate_frame_sequences(frames): sequences = [] labels = [] - for i in range(len(frames) - NUM_FRAMES + 2): + for i in range(len(frames) - NUM_FRAMES + 1): sequence = frames[i:i+NUM_FRAMES-1] sequences.append(sequence) labels.append(sequence[-1]) @@ -87,39 +88,73 @@ def main(): model = VideoCompressionModel(NUM_CHANNELS, NUM_FRAMES) model.compile(loss='mean_squared_error', optimizer='adam') - early_stop = EarlyStopping(monitor='val_loss', patience=3, verbose=1, restore_best_weights=True) + # Load and concatenate all sequences and labels + all_train_sequences = [] + all_val_sequences = [] + all_train_labels = [] + all_val_labels = [] + all_crf_train = [] + all_crf_val = [] + all_preset_speed_train = [] + all_preset_speed_val = [] + for video_details_train, video_details_val in zip(all_video_details_train, all_video_details_val): train_frames = video_details_train["frames"] val_frames = video_details_val["frames"] train_differences = frame_difference(preprocess(train_frames)) val_differences = frame_difference(preprocess(val_frames)) + + #print(len(train_differences), train_differences[0].shape) train_sequences, train_labels = generate_frame_sequences(train_differences) val_sequences, val_labels = generate_frame_sequences(val_differences) - num_sequences_train = len(train_sequences) - num_sequences_val = len(val_sequences) - crf_array_train = np.full((num_sequences_train, 1), video_details_train['crf']) - crf_array_val = np.full((num_sequences_val, 1), video_details_val['crf']) - preset_speed_array_train = np.full((num_sequences_train, 1), video_details_train['preset_speed']) - preset_speed_array_val = np.full((num_sequences_val, 1), video_details_val['preset_speed']) + crf_array_train = np.full((len(train_sequences), 1), video_details_train['crf']) + crf_array_val = np.full((len(val_sequences), 1), video_details_val['crf']) + preset_speed_array_train = np.full((len(train_sequences), 1), video_details_train['preset_speed']) + preset_speed_array_val = np.full((len(val_sequences), 1), video_details_val['preset_speed']) - print(len(train_sequences)) - print(len(val_sequences)) + all_train_sequences.extend(train_sequences) + all_val_sequences.extend(val_sequences) + all_train_labels.extend(train_labels) + all_val_labels.extend(val_labels) + all_crf_train.extend(crf_array_train) + all_crf_val.extend(crf_array_val) + all_preset_speed_train.extend(preset_speed_array_train) + all_preset_speed_val.extend(preset_speed_array_val) - print("\nTraining the model for video:", video_details_train["video_file"]) - model.fit( - {"frames": train_sequences, "crf": crf_array_train, "preset_speed": preset_speed_array_train}, - train_labels, - batch_size=BATCH_SIZE, - epochs=EPOCHS, - validation_data=({"frames": val_sequences, "crf": crf_array_val, "preset_speed": preset_speed_array_val}, val_labels), - callbacks=[early_stop] - ) - print("\nTraining completed for video:", video_details_train["video_file"]) + # Convert lists to numpy arrays + all_train_sequences = np.array(all_train_sequences) + all_val_sequences = np.array(all_val_sequences) + all_train_labels = np.array(all_train_labels) + all_val_labels = np.array(all_val_labels) + all_crf_train = np.array(all_crf_train) + all_crf_val = np.array(all_crf_val) + all_preset_speed_train = np.array(all_preset_speed_train) + all_preset_speed_val = np.array(all_preset_speed_val) + + # Shuffle the training data + indices_train = np.arange(all_train_sequences.shape[0]) + np.random.shuffle(indices_train) + + all_train_sequences = all_train_sequences[indices_train] + all_train_labels = all_train_labels[indices_train] + all_crf_train = all_crf_train[indices_train] + all_preset_speed_train = all_preset_speed_train[indices_train] + + print("\nTraining the model on mixed sequences...") + model.fit( + {"frames": all_train_sequences, "crf": all_crf_train, "preset_speed": all_preset_speed_train}, + all_train_labels, + batch_size=BATCH_SIZE, + epochs=EPOCHS, + validation_data=({"frames": all_val_sequences, "crf": all_crf_val, "preset_speed": all_preset_speed_val}, all_val_labels), + callbacks=[early_stop] + ) + print("\nTraining completed!") save_model(model, 'model_differencing.keras')