Working GPU model
This commit is contained in:
parent
5085c87300
commit
dea59068fb
3 changed files with 190 additions and 108 deletions
|
@ -1,15 +1,19 @@
|
|||
# DeepEncode.py
|
||||
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
import cv2
|
||||
from video_compression_model import VideoCompressionModel
|
||||
|
||||
# Constants
|
||||
CHUNK_SIZE = 24 # Adjust based on available memory and video resolution
|
||||
COMPRESSED_VIDEO_FILE = 'compressed_video.mp4'
|
||||
MAX_FRAMES = 24 # Limit the number of frames processed
|
||||
CHUNK_SIZE = 10 # Adjust based on available memory and video resolution
|
||||
COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
|
||||
MAX_FRAMES = 0 # Limit the number of frames processed
|
||||
CRF = 25.0 # Example CRF value
|
||||
PRESET_SPEED = 4 # Index for "fast" in our defined list
|
||||
|
||||
# Load the trained model
|
||||
model = tf.keras.models.load_model('models/model.keras', custom_objects={'VideoCompressionModel': VideoCompressionModel})
|
||||
model = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel})
|
||||
|
||||
# Load the uncompressed video
|
||||
UNCOMPRESSED_VIDEO_FILE = 'test_data/training_video.mkv'
|
||||
|
@ -23,46 +27,56 @@ def load_frame_from_video(video_file, frame_num):
|
|||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0 # Normalize and convert to float32
|
||||
cap.release()
|
||||
|
||||
#display_frame = np.clip(frame * 255.0, 0, 255).astype(np.uint8)
|
||||
#cv2.imshow("uncomp", display_frame)
|
||||
#cv2.waitKey(0) # Add this line to hold the display window until a key is pressed
|
||||
|
||||
|
||||
return frame
|
||||
|
||||
def predict_frame(uncompressed_frame, model, crf_value, preset_speed_value):
|
||||
crf_array = np.array([crf_value])
|
||||
preset_speed_array = np.array([preset_speed_value])
|
||||
|
||||
# Expand dimensions to include batch size
|
||||
uncompressed_frame = np.expand_dims(uncompressed_frame, 0)
|
||||
|
||||
#display_frame = np.clip(cv2.cvtColor(uncompressed_frame[0], cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
|
||||
#cv2.imshow("uncomp", display_frame)
|
||||
#cv2.waitKey(10)
|
||||
|
||||
compressed_frame = model.predict({
|
||||
"frame": np.array([uncompressed_frame]),
|
||||
"uncompressed_frame": uncompressed_frame,
|
||||
"compressed_frame": uncompressed_frame,
|
||||
"crf": crf_array,
|
||||
"preset_speed": preset_speed_array
|
||||
})
|
||||
|
||||
display_frame = np.clip(cv2.cvtColor(compressed_frame[0], cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
|
||||
|
||||
cv2.imshow("comp", display_frame)
|
||||
cv2.waitKey(10)
|
||||
|
||||
return compressed_frame[0]
|
||||
|
||||
cap = cv2.VideoCapture(UNCOMPRESSED_VIDEO_FILE)
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
height, width = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
cap.release()
|
||||
|
||||
fourcc = cv2.VideoWriter_fourcc(*'XVID')
|
||||
out = cv2.VideoWriter(COMPRESSED_VIDEO_FILE, fourcc, 24.0, (width, height))
|
||||
|
||||
if not out.isOpened():
|
||||
print("Error: VideoWriter could not be opened.")
|
||||
exit()
|
||||
|
||||
if MAX_FRAMES != 0 and total_frames > MAX_FRAMES:
|
||||
total_frames = MAX_FRAMES
|
||||
|
||||
crf_value = 25.0 # Example CRF value
|
||||
preset_speed_value = 2 # Index for "fast" in our defined list
|
||||
|
||||
height, width = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
fourcc = cv2.VideoWriter_fourcc(*'H264')
|
||||
out = cv2.VideoWriter(COMPRESSED_VIDEO_FILE, fourcc, 24.0, (width, height))
|
||||
|
||||
for i in range(total_frames):
|
||||
uncompressed_frame = load_frame_from_video(UNCOMPRESSED_VIDEO_FILE, frame_num=i)
|
||||
compressed_frame = predict_frame(uncompressed_frame, model, crf_value, preset_speed_value)
|
||||
compressed_frame = predict_frame(uncompressed_frame, model, CRF, PRESET_SPEED)
|
||||
|
||||
compressed_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
|
||||
compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_RGB2BGR)
|
||||
out.write(compressed_frame)
|
||||
cv2.imshow("output", compressed_frame)
|
||||
#cv2.imshow("output", compressed_frame)
|
||||
|
||||
out.release()
|
||||
print("Compression completed.")
|
||||
|
|
168
train_model.py
168
train_model.py
|
@ -2,106 +2,168 @@ import os
|
|||
import json
|
||||
import numpy as np
|
||||
import cv2
|
||||
import argparse
|
||||
import tensorflow as tf
|
||||
from video_compression_model import NUM_CHANNELS, VideoCompressionModel, PRESET_SPEED_CATEGORIES
|
||||
from tensorflow.keras.callbacks import EarlyStopping
|
||||
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
|
||||
|
||||
print(tf.config.list_physical_devices('GPU'))
|
||||
print("GPUs Detected:", tf.config.list_physical_devices('GPU'))
|
||||
|
||||
# Constants
|
||||
BATCH_SIZE = 8
|
||||
EPOCHS = 50
|
||||
TRAIN_SAMPLES = 5
|
||||
BATCH_SIZE = 16
|
||||
EPOCHS = 40
|
||||
LEARNING_RATE = 0.00001
|
||||
TRAIN_SAMPLES = 100
|
||||
MODEL_SAVE_FILE = "models/model.tf"
|
||||
MODEL_CHECKPOINT_DIR = "checkpoints"
|
||||
CONTINUE_TRAINING = None
|
||||
|
||||
def load_list(list_path):
|
||||
with open(list_path, "r") as json_file:
|
||||
video_details_list = json.load(json_file)
|
||||
return video_details_list
|
||||
|
||||
def load_frame_from_video(video_file):
|
||||
print("Extracting video frame...")
|
||||
cap = cv2.VideoCapture(video_file)
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
return None
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
cap.release()
|
||||
return frame
|
||||
|
||||
def preprocess(frame):
|
||||
return frame / 255.0
|
||||
|
||||
def save_model(model, file):
|
||||
os.makedirs("models", exist_ok=True)
|
||||
model.save(os.path.join("models/", file))
|
||||
print("Model saved successfully!")
|
||||
|
||||
def load_video_from_list(list_path):
|
||||
details_list = load_list(list_path)
|
||||
all_frames = []
|
||||
all_details = []
|
||||
|
||||
num_videos = len(details_list)
|
||||
frames_per_video = int(TRAIN_SAMPLES / num_videos)
|
||||
|
||||
print(f"Loading {frames_per_video} across {num_videos} videos")
|
||||
|
||||
for video_details in details_list:
|
||||
VIDEO_FILE = video_details["video_file"]
|
||||
UNCOMPRESSED_VIDEO_FILE = video_details["uncompressed_video_file"]
|
||||
CRF = video_details['crf'] / 63.0
|
||||
PRESET_SPEED = PRESET_SPEED_CATEGORIES.index(video_details['preset_speed'])
|
||||
video_details['preset_speed'] = PRESET_SPEED
|
||||
|
||||
frame = load_frame_from_video(os.path.join("test_data/", VIDEO_FILE))
|
||||
frames = []
|
||||
frames_compressed = []
|
||||
|
||||
if frame is not None:
|
||||
all_frames.append(preprocess(frame))
|
||||
cap = cv2.VideoCapture(os.path.join("test_data/", VIDEO_FILE))
|
||||
cap_uncompressed = cv2.VideoCapture(os.path.join("test_data/", UNCOMPRESSED_VIDEO_FILE))
|
||||
|
||||
for _ in range(frames_per_video):
|
||||
ret, frame_compressed = cap.read()
|
||||
ret_uncompressed, frame = cap_uncompressed.read()
|
||||
|
||||
if not ret or not ret_uncompressed:
|
||||
continue
|
||||
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
frame_compressed = cv2.cvtColor(frame_compressed, cv2.COLOR_BGR2RGB)
|
||||
|
||||
frames.append(preprocess(frame))
|
||||
frames_compressed.append(preprocess(frame_compressed))
|
||||
|
||||
for uncompressed_frame, compressed_frame in zip(frames, frames_compressed):
|
||||
all_details.append({
|
||||
"frame": frame,
|
||||
"frame": uncompressed_frame,
|
||||
"compressed_frame": compressed_frame,
|
||||
"crf": CRF,
|
||||
"preset_speed": PRESET_SPEED,
|
||||
"video_file": VIDEO_FILE
|
||||
})
|
||||
|
||||
cap.release()
|
||||
cap_uncompressed.release()
|
||||
|
||||
return all_details
|
||||
|
||||
def preprocess(frame):
|
||||
return frame / 255.0
|
||||
|
||||
def save_model(model):
|
||||
os.makedirs("models", exist_ok=True)
|
||||
model.save(MODEL_SAVE_FILE, save_format='tf')
|
||||
print("Model saved successfully!")
|
||||
|
||||
def main():
|
||||
global BATCH_SIZE, EPOCHS, TRAIN_SAMPLES, LEARNING_RATE, CONTINUE_TRAINING
|
||||
|
||||
# Argument parsing
|
||||
parser = argparse.ArgumentParser(description="Train the video compression model.")
|
||||
parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
|
||||
parser.add_argument('-e', '--epochs', type=int, default=EPOCHS, help='Number of epochs for training.')
|
||||
parser.add_argument('-s', '--training_samples', type=int, default=TRAIN_SAMPLES, help='Number of training samples.')
|
||||
parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
|
||||
parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Use the parsed arguments in your script
|
||||
BATCH_SIZE = args.batch_size
|
||||
EPOCHS = args.epochs
|
||||
TRAIN_SAMPLES = args.training_samples
|
||||
LEARNING_RATE = args.learning_rate
|
||||
CONTINUE_TRAINING = args.continue_training
|
||||
|
||||
print("Training configuration:")
|
||||
print(f"Batch size: {BATCH_SIZE}")
|
||||
print(f"Epochs: {EPOCHS}")
|
||||
print(f"Training samples: {TRAIN_SAMPLES}")
|
||||
print(f"Learning rate: {LEARNING_RATE}")
|
||||
print(f"Continue training from: {CONTINUE_TRAINING}")
|
||||
|
||||
all_video_details_train = load_video_from_list("test_data/training.json")
|
||||
all_video_details_val = load_video_from_list("test_data/validation.json")
|
||||
|
||||
model = VideoCompressionModel(NUM_CHANNELS)
|
||||
model.compile(loss='mean_squared_error', optimizer='adam')
|
||||
early_stop = EarlyStopping(monitor='val_loss', patience=3, verbose=1, restore_best_weights=True)
|
||||
|
||||
# Prepare data
|
||||
all_train_frames = []
|
||||
all_val_frames = []
|
||||
all_crf_train = []
|
||||
all_crf_val = []
|
||||
all_preset_speed_train = []
|
||||
all_preset_speed_val = []
|
||||
|
||||
for video_details_train, video_details_val in zip(all_video_details_train, all_video_details_val):
|
||||
all_train_frames.append(video_details_train["frame"])
|
||||
all_val_frames.append(video_details_val["frame"])
|
||||
all_crf_train.append(video_details_train['crf'])
|
||||
all_crf_val.append(video_details_val['crf'])
|
||||
all_preset_speed_train.append(video_details_train['preset_speed'])
|
||||
all_preset_speed_val.append(video_details_val['preset_speed'])
|
||||
all_train_frames = [video_details["frame"] for video_details in all_video_details_train]
|
||||
all_train_compressed_frames = [video_details["compressed_frame"] for video_details in all_video_details_train]
|
||||
all_val_frames = [video_details["frame"] for video_details in all_video_details_val]
|
||||
all_val_compressed_frames = [video_details["compressed_frame"] for video_details in all_video_details_val]
|
||||
all_crf_train = [video_details['crf'] for video_details in all_video_details_train]
|
||||
all_crf_val = [video_details['crf'] for video_details in all_video_details_val]
|
||||
all_preset_speed_train = [video_details['preset_speed'] for video_details in all_video_details_train]
|
||||
all_preset_speed_val = [video_details['preset_speed'] for video_details in all_video_details_val]
|
||||
|
||||
# Convert lists to numpy arrays
|
||||
all_train_frames = np.array(all_train_frames)
|
||||
all_train_compressed_frames = np.array(all_train_compressed_frames)
|
||||
all_val_frames = np.array(all_val_frames)
|
||||
all_val_compressed_frames = np.array(all_val_compressed_frames)
|
||||
all_crf_train = np.array(all_crf_train)
|
||||
all_crf_val = np.array(all_crf_val)
|
||||
all_preset_speed_train = np.array(all_preset_speed_train)
|
||||
all_preset_speed_val = np.array(all_preset_speed_val)
|
||||
|
||||
print("\nTraining the model on frame pairs...")
|
||||
if CONTINUE_TRAINING:
|
||||
print("loading model:", CONTINUE_TRAINING)
|
||||
model = tf.keras.models.load_model(CONTINUE_TRAINING) # Load from the specified file
|
||||
else:
|
||||
model = VideoCompressionModel()
|
||||
|
||||
# Define the optimizer with a specific learning rate
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
|
||||
|
||||
os.makedirs(MODEL_CHECKPOINT_DIR, exist_ok=True)
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
filepath=os.path.join(MODEL_CHECKPOINT_DIR, "epoch-{epoch:02d}.tf"),
|
||||
save_weights_only=False,
|
||||
save_best_only=False,
|
||||
verbose=1,
|
||||
save_format="tf"
|
||||
)
|
||||
|
||||
#tf.config.run_functions_eagerly(True)
|
||||
|
||||
model.compile(loss='mean_squared_error', optimizer=optimizer)
|
||||
early_stop = EarlyStopping(monitor='val_loss', patience=5, verbose=1, restore_best_weights=True)
|
||||
|
||||
print("\nTraining the model...")
|
||||
model.fit(
|
||||
{"frame": all_train_frames, "crf": all_crf_train, "preset_speed": all_preset_speed_train},
|
||||
all_val_frames, # Target is the compressed frame
|
||||
{"uncompressed_frame": all_train_frames, "compressed_frame": all_train_compressed_frames, "crf": all_crf_train, "preset_speed": all_preset_speed_train},
|
||||
all_train_compressed_frames, # Target is the compressed frame
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS,
|
||||
validation_data=({"frame": all_val_frames, "crf": all_crf_val, "preset_speed": all_preset_speed_val}, all_val_frames),
|
||||
callbacks=[early_stop]
|
||||
validation_data=({"uncompressed_frame": all_val_frames, "compressed_frame": all_val_compressed_frames, "crf": all_crf_val, "preset_speed": all_preset_speed_val}, all_val_compressed_frames),
|
||||
callbacks=[early_stop, checkpoint_callback]
|
||||
)
|
||||
print("\nTraining completed!")
|
||||
|
||||
save_model(model, 'model.keras')
|
||||
save_model(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -1,61 +1,67 @@
|
|||
# video_compression_model.py
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
|
||||
NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
|
||||
NUM_CHANNELS = 3 # Number of color channels in the video frames (RGB images have 3 channels)
|
||||
|
||||
#policy = tf.keras.mixed_precision.Policy('mixed_float16')
|
||||
#tf.keras.mixed_precision.set_global_policy(policy)
|
||||
NUM_CHANNELS = 3
|
||||
|
||||
class VideoCompressionModel(tf.keras.Model):
|
||||
def __init__(self, NUM_CHANNELS=3, NUM_FRAMES=5, regularization_factor=1e-4):
|
||||
def __init__(self):
|
||||
super(VideoCompressionModel, self).__init__()
|
||||
|
||||
self.NUM_CHANNELS = NUM_CHANNELS
|
||||
# Inputs
|
||||
self.crf_input = tf.keras.layers.InputLayer(name='crf', input_shape=(1,))
|
||||
self.preset_speed_input = tf.keras.layers.InputLayer(name='preset_speed', input_shape=(1,))
|
||||
self.uncompressed_frame_input = tf.keras.layers.InputLayer(name='uncompressed_frame', input_shape=(None, None, NUM_CHANNELS))
|
||||
self.compressed_frame_input = tf.keras.layers.InputLayer(name='compressed_frame', input_shape=(None, None, NUM_CHANNELS))
|
||||
|
||||
# Regularization
|
||||
self.regularizer = tf.keras.regularizers.l2(regularization_factor)
|
||||
|
||||
# Embedding layer for preset_speed
|
||||
self.preset_embedding = tf.keras.layers.Embedding(NUM_PRESET_SPEEDS, 16, embeddings_regularizer=self.regularizer)
|
||||
# Embedding for speed preset and FC layer for CRF and preset speed
|
||||
self.embedding = tf.keras.layers.Embedding(NUM_PRESET_SPEEDS, 16)
|
||||
self.fc = tf.keras.layers.Dense(32, activation='relu')
|
||||
|
||||
# Encoder layers
|
||||
self.encoder = tf.keras.Sequential([
|
||||
tf.keras.layers.ZeroPadding2D(padding=((1, 1), (1, 1))), # Padding to preserve spatial dimensions
|
||||
tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same', kernel_regularizer=self.regularizer),
|
||||
tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(None, None, 2 * NUM_CHANNELS + 32)),
|
||||
tf.keras.layers.BatchNormalization(),
|
||||
tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
|
||||
tf.keras.layers.BatchNormalization(),
|
||||
tf.keras.layers.MaxPooling2D((2, 2)),
|
||||
# Add more encoder layers as needed
|
||||
tf.keras.layers.Dropout(0.3)
|
||||
])
|
||||
|
||||
# Decoder layers
|
||||
self.decoder = tf.keras.Sequential([
|
||||
tf.keras.layers.Conv2DTranspose(32, (3, 3), activation='relu', padding='same', kernel_regularizer=self.regularizer),
|
||||
tf.keras.layers.Conv2DTranspose(128, (3, 3), activation='relu', padding='same'),
|
||||
tf.keras.layers.BatchNormalization(),
|
||||
tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'),
|
||||
tf.keras.layers.BatchNormalization(),
|
||||
tf.keras.layers.UpSampling2D((2, 2)),
|
||||
# Add more decoder layers as needed
|
||||
tf.keras.layers.Conv2D(NUM_CHANNELS, (3, 3), activation='sigmoid', padding='same', kernel_regularizer=self.regularizer), # Output layer for video frames
|
||||
tf.keras.layers.Cropping2D(cropping=((1, 1), (1, 1))) # Adjust cropping to ensure dimensions match
|
||||
|
||||
tf.keras.layers.Dropout(0.3),
|
||||
tf.keras.layers.Conv2D(NUM_CHANNELS, (3, 3), activation='sigmoid', padding='same') # Output layer for video frames
|
||||
])
|
||||
|
||||
def call(self, inputs):
|
||||
frame = inputs["frame"]
|
||||
crf = tf.expand_dims(inputs["crf"], -1)
|
||||
preset_speed = inputs["preset_speed"]
|
||||
uncompressed_frame, compressed_frame, crf, preset_speed = inputs['uncompressed_frame'], inputs['compressed_frame'], inputs['crf'], inputs['preset_speed']
|
||||
|
||||
# Convert preset_speed to embeddings
|
||||
preset_embedding = self.preset_embedding(preset_speed)
|
||||
preset_embedding = tf.keras.layers.Flatten()(preset_embedding)
|
||||
# Convert frames to float32
|
||||
uncompressed_frame = tf.cast(uncompressed_frame, tf.float32)
|
||||
compressed_frame = tf.cast(compressed_frame, tf.float32)
|
||||
|
||||
# Concatenate crf and preset_embedding to frames
|
||||
frame_shape = tf.shape(frame)
|
||||
repeated_crf = tf.tile(tf.reshape(crf, (-1, 1, 1, 1)), [1, frame_shape[1], frame_shape[2], 1])
|
||||
repeated_preset = tf.tile(tf.reshape(preset_embedding, (-1, 1, 1, 16)), [1, frame_shape[1], frame_shape[2], 1])
|
||||
# Integrate CRF and preset speed into the network
|
||||
preset_speed_embedded = self.embedding(preset_speed)
|
||||
crf_expanded = tf.expand_dims(crf, -1)
|
||||
integrated_info = tf.keras.layers.Concatenate(axis=-1)([crf_expanded, tf.keras.layers.Flatten()(preset_speed_embedded)])
|
||||
integrated_info = self.fc(integrated_info)
|
||||
|
||||
frame = tf.concat([tf.cast(frame, tf.float32), repeated_crf, repeated_preset], axis=-1)
|
||||
# Integrate the CRF and preset speed information into the frames as additional channels (features)
|
||||
_, height, width, _ = uncompressed_frame.shape
|
||||
integrated_info_repeated = tf.tile(tf.reshape(integrated_info, [-1, 1, 1, 32]), [1, height, width, 1])
|
||||
|
||||
# Encoding the frame
|
||||
compressed_representation = self.encoder(frame)
|
||||
# Merge uncompressed and compressed frames
|
||||
frames_merged = tf.keras.layers.Concatenate(axis=-1)([uncompressed_frame, compressed_frame, integrated_info_repeated])
|
||||
|
||||
# Decoding to generate compressed frame
|
||||
compressed_representation = self.encoder(frames_merged)
|
||||
reconstructed_frame = self.decoder(compressed_representation)
|
||||
|
||||
return reconstructed_frame
|
||||
|
|
Reference in a new issue