update
This commit is contained in:
parent
1d98bc84a2
commit
fde856f3ec
6 changed files with 107 additions and 109 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -12,5 +12,7 @@
|
||||||
!video_compression_model.py
|
!video_compression_model.py
|
||||||
!global_train.py
|
!global_train.py
|
||||||
!log.py
|
!log.py
|
||||||
|
!featureExtraction.py
|
||||||
|
!globalVars.py
|
||||||
!test_data/training/training.json
|
!test_data/training/training.json
|
||||||
!test_data/validation/validation.json
|
!test_data/validation/validation.json
|
|
@ -2,22 +2,22 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from featureExtraction import preprocess_frame
|
||||||
|
|
||||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
from video_compression_model import PRESET_SPEED_CATEGORIES, VideoCompressionModel
|
from video_compression_model import VideoCompressionModel
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
CHUNK_SIZE = 24 # Adjust based on available memory and video resolution
|
CHUNK_SIZE = 24 # Adjust based on available memory and video resolution
|
||||||
COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
|
COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
|
||||||
MAX_FRAMES = 0 # Limit the number of frames processed
|
MAX_FRAMES = 0 # Limit the number of frames processed
|
||||||
CRF = 24.0 # Example CRF value
|
|
||||||
PRESET_SPEED = "veryslow" # Index for "fast" in our defined list
|
|
||||||
|
|
||||||
# Load the trained model
|
# Load the trained model
|
||||||
model = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel})
|
MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel})
|
||||||
|
|
||||||
# Load the uncompressed video
|
# Load the uncompressed video
|
||||||
UNCOMPRESSED_VIDEO_FILE = 'test_data/training_video.mkv'
|
UNCOMPRESSED_VIDEO_FILE = 'test_data/training_video.mkv'
|
||||||
|
@ -28,39 +28,27 @@ def load_frame_from_video(video_file, frame_num):
|
||||||
ret, frame = cap.read()
|
ret, frame = cap.read()
|
||||||
if not ret:
|
if not ret:
|
||||||
return None
|
return None
|
||||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0 # Normalize and convert to float32
|
#frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||||
cap.release()
|
cap.release()
|
||||||
|
|
||||||
return frame
|
return frame
|
||||||
|
|
||||||
def predict_frame(uncompressed_frame, model, crf_value, preset_speed_value):
|
def predict_frame(uncompressed_frame):
|
||||||
crf_array = np.array([crf_value])
|
|
||||||
preset_speed_array = np.array([preset_speed_value])
|
|
||||||
|
|
||||||
crf_array = np.expand_dims(np.array([crf_value]), axis=-1) # Shape: (1, 1)
|
|
||||||
preset_speed_array = np.expand_dims(np.array([preset_speed_value]), axis=-1) # Shape: (1, 1)
|
|
||||||
|
|
||||||
|
#display_frame = np.clip(cv2.cvtColor(uncompressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
|
||||||
|
cv2.imshow("uncomp", uncompressed_frame)
|
||||||
|
cv2.waitKey(1)
|
||||||
|
|
||||||
# Expand dimensions to include batch size
|
combined_feature, _ = preprocess_frame(uncompressed_frame)
|
||||||
uncompressed_frame = np.expand_dims(uncompressed_frame, 0)
|
|
||||||
|
|
||||||
#display_frame = np.clip(cv2.cvtColor(uncompressed_frame[0], cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
|
|
||||||
#cv2.imshow("uncomp", display_frame)
|
|
||||||
#cv2.waitKey(0)
|
|
||||||
|
|
||||||
compressed_frame = model.predict({
|
compressed_frame = MODEL.predict(np.expand_dims(combined_feature, axis=0))[0]
|
||||||
"compressed_frame": uncompressed_frame,
|
|
||||||
"uncompressed_frame": uncompressed_frame,
|
|
||||||
"crf": crf_array,
|
|
||||||
"preset_speed": preset_speed_array
|
|
||||||
})
|
|
||||||
|
|
||||||
display_frame = np.clip(cv2.cvtColor(compressed_frame[0], cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
|
display_frame = np.clip(cv2.cvtColor(compressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
|
||||||
|
|
||||||
cv2.imshow("comp", display_frame)
|
cv2.imshow("comp", display_frame)
|
||||||
cv2.waitKey(1)
|
cv2.waitKey(1)
|
||||||
|
|
||||||
return compressed_frame[0]
|
return compressed_frame
|
||||||
|
|
||||||
cap = cv2.VideoCapture(UNCOMPRESSED_VIDEO_FILE)
|
cap = cv2.VideoCapture(UNCOMPRESSED_VIDEO_FILE)
|
||||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
|
@ -79,7 +67,7 @@ if MAX_FRAMES != 0 and total_frames > MAX_FRAMES:
|
||||||
|
|
||||||
for i in range(total_frames):
|
for i in range(total_frames):
|
||||||
uncompressed_frame = load_frame_from_video(UNCOMPRESSED_VIDEO_FILE, frame_num=i)
|
uncompressed_frame = load_frame_from_video(UNCOMPRESSED_VIDEO_FILE, frame_num=i)
|
||||||
compressed_frame = predict_frame(uncompressed_frame, model, CRF, PRESET_SPEED_CATEGORIES.index(PRESET_SPEED))
|
compressed_frame = predict_frame(uncompressed_frame)
|
||||||
|
|
||||||
compressed_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
|
compressed_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
|
||||||
compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_RGB2BGR)
|
compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_RGB2BGR)
|
||||||
|
|
48
featureExtraction.py
Normal file
48
featureExtraction.py
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
# featureExtraction.py
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from globalVars import HEIGHT, WIDTH
|
||||||
|
|
||||||
|
|
||||||
|
def extract_edge_features(frame):
|
||||||
|
"""
|
||||||
|
Extract edge features using Canny edge detection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- frame (ndarray): Image frame.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- ndarray: Edge feature map.
|
||||||
|
"""
|
||||||
|
edges = cv2.Canny(frame, threshold1=100, threshold2=200)
|
||||||
|
return edges.astype(np.float32) / 255.0
|
||||||
|
|
||||||
|
def extract_histogram_features(frame, bins=64):
|
||||||
|
"""
|
||||||
|
Extract histogram features from a frame.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- frame (ndarray): Image frame.
|
||||||
|
- bins (int): Number of bins for the histogram.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- ndarray: Normalized histogram feature vector.
|
||||||
|
"""
|
||||||
|
histogram, _ = np.histogram(frame.flatten(), bins=bins, range=[0, 255])
|
||||||
|
return histogram.astype(np.float32) / frame.size
|
||||||
|
|
||||||
|
def preprocess_frame(frame):
|
||||||
|
# Check frame dimensions and resize if necessary
|
||||||
|
if frame.shape[:2] != (HEIGHT, WIDTH):
|
||||||
|
frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_NEAREST)
|
||||||
|
|
||||||
|
# Extract features
|
||||||
|
edge_feature = extract_edge_features(frame)
|
||||||
|
histogram_feature = extract_histogram_features(frame)
|
||||||
|
histogram_feature_image = np.full((HEIGHT, WIDTH), histogram_feature.mean()) # Convert histogram feature to image-like shape
|
||||||
|
combined_feature = np.stack([edge_feature, histogram_feature_image], axis=-1)
|
||||||
|
|
||||||
|
compressed_frame = frame / 255.0 # Assuming the frame is uint8, scale to [0, 1]
|
||||||
|
return combined_feature, compressed_frame
|
|
@ -1,3 +1,7 @@
|
||||||
import log
|
import log
|
||||||
|
|
||||||
LOGGER = log.Logger(level="DEBUG", logfile="training.log", reset_logfile=True)
|
LOGGER = log.Logger(level="DEBUG", logfile="training.log", reset_logfile=True)
|
||||||
|
|
||||||
|
NUM_CHANNELS = 3
|
||||||
|
WIDTH = 640
|
||||||
|
HEIGHT = 360
|
|
@ -4,12 +4,14 @@ import os
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from train_model_V2 import VideoCompressionModel
|
|
||||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
|
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
|
||||||
|
|
||||||
from global_train import LOGGER
|
from video_compression_model import VideoCompressionModel, data_generator
|
||||||
|
|
||||||
|
from globalVars import HEIGHT, WIDTH, LOGGER
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
BATCH_SIZE = 16
|
BATCH_SIZE = 16
|
||||||
|
@ -18,10 +20,6 @@ LEARNING_RATE = 0.01
|
||||||
MODEL_SAVE_FILE = "models/model.tf"
|
MODEL_SAVE_FILE = "models/model.tf"
|
||||||
MODEL_CHECKPOINT_DIR = "checkpoints"
|
MODEL_CHECKPOINT_DIR = "checkpoints"
|
||||||
EARLY_STOP = 10
|
EARLY_STOP = 10
|
||||||
|
|
||||||
NUM_CHANNELS = 3
|
|
||||||
WIDTH = 640
|
|
||||||
HEIGHT = 360
|
|
||||||
|
|
||||||
def save_model(model):
|
def save_model(model):
|
||||||
try:
|
try:
|
||||||
|
@ -33,34 +31,6 @@ def save_model(model):
|
||||||
LOGGER.error(f"Error saving the model: {e}")
|
LOGGER.error(f"Error saving the model: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def extract_edge_features(frame):
|
|
||||||
"""
|
|
||||||
Extract edge features using Canny edge detection.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
- frame (ndarray): Image frame.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- ndarray: Edge feature map.
|
|
||||||
"""
|
|
||||||
edges = cv2.Canny(frame, threshold1=100, threshold2=200)
|
|
||||||
return edges.astype(np.float32) / 255.0
|
|
||||||
|
|
||||||
def extract_histogram_features(frame, bins=64):
|
|
||||||
"""
|
|
||||||
Extract histogram features from a frame.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
- frame (ndarray): Image frame.
|
|
||||||
- bins (int): Number of bins for the histogram.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- ndarray: Normalized histogram feature vector.
|
|
||||||
"""
|
|
||||||
histogram, _ = np.histogram(frame.flatten(), bins=bins, range=[0, 255])
|
|
||||||
return histogram.astype(np.float32) / frame.size
|
|
||||||
|
|
||||||
|
|
||||||
def load_video_metadata(list_path):
|
def load_video_metadata(list_path):
|
||||||
"""
|
"""
|
||||||
|
@ -85,57 +55,16 @@ def load_video_metadata(list_path):
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
LOGGER.error(f"Error decoding JSON from {list_path}.")
|
LOGGER.error(f"Error decoding JSON from {list_path}.")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def data_generator(videos, batch_size):
|
|
||||||
while True:
|
|
||||||
for video_details in videos:
|
|
||||||
video_path = os.path.join(os.path.dirname("test_data/validation/validation.json"), video_details["compressed_video_file"])
|
|
||||||
cap = cv2.VideoCapture(video_path)
|
|
||||||
|
|
||||||
feature_batch = []
|
|
||||||
compressed_frame_batch = []
|
|
||||||
|
|
||||||
while cap.isOpened():
|
|
||||||
ret, frame = cap.read()
|
|
||||||
if not ret:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Check frame dimensions and resize if necessary
|
|
||||||
if frame.shape[:2] != (HEIGHT, WIDTH):
|
|
||||||
frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_NEAREST)
|
|
||||||
|
|
||||||
# Extract features
|
|
||||||
edge_feature = extract_edge_features(frame)
|
|
||||||
histogram_feature = extract_histogram_features(frame)
|
|
||||||
histogram_feature_image = np.full((HEIGHT, WIDTH), histogram_feature.mean()) # Convert histogram feature to image-like shape
|
|
||||||
combined_feature = np.stack([edge_feature, histogram_feature_image], axis=-1)
|
|
||||||
|
|
||||||
compressed_frame = frame / 255.0 # Assuming the frame is uint8, scale to [0, 1]
|
|
||||||
|
|
||||||
feature_batch.append(combined_feature)
|
|
||||||
compressed_frame_batch.append(compressed_frame)
|
|
||||||
|
|
||||||
if len(feature_batch) == batch_size:
|
|
||||||
yield (np.array(feature_batch), np.array(compressed_frame_batch))
|
|
||||||
feature_batch = []
|
|
||||||
compressed_frame_batch = []
|
|
||||||
|
|
||||||
cap.release()
|
|
||||||
|
|
||||||
# If there are frames left that don't fill a whole batch, send them anyway
|
|
||||||
if len(feature_batch) > 0:
|
|
||||||
yield (np.array(feature_batch), np.array(compressed_frame_batch))
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
global BATCH_SIZE, EPOCHS, TRAIN_SAMPLES, LEARNING_RATE, MODEL_SAVE_FILE
|
global BATCH_SIZE, EPOCHS, LEARNING_RATE, MODEL_SAVE_FILE
|
||||||
# Argument parsing
|
# Argument parsing
|
||||||
parser = argparse.ArgumentParser(description="Train the video compression model.")
|
parser = argparse.ArgumentParser(description="Train the video compression model.")
|
||||||
parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
|
parser.add_argument('-b', '--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training.')
|
||||||
parser.add_argument('-e', '--epochs', type=int, default=EPOCHS, help='Number of epochs for training.')
|
parser.add_argument('-e', '--epochs', type=int, default=EPOCHS, help='Number of epochs for training.')
|
||||||
parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
|
parser.add_argument('-l', '--learning_rate', type=float, default=LEARNING_RATE, help='Learning rate for training.')
|
||||||
parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
|
parser.add_argument('-c', '--continue_training', type=str, nargs='?', const=MODEL_SAVE_FILE, default=None, help='Path to the saved model to continue training. If used without a value, defaults to the MODEL_SAVE_FILE.')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
BATCH_SIZE = args.batch_size
|
BATCH_SIZE = args.batch_size
|
||||||
|
|
|
@ -1,23 +1,50 @@
|
||||||
# video_compression_model.py
|
# video_compression_model.py
|
||||||
|
|
||||||
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from featureExtraction import preprocess_frame
|
||||||
|
|
||||||
from global_train import LOGGER
|
from globalVars import HEIGHT, LOGGER, WIDTH
|
||||||
|
|
||||||
PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
|
#PRESET_SPEED_CATEGORIES = ["ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow"]
|
||||||
NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
|
#NUM_PRESET_SPEEDS = len(PRESET_SPEED_CATEGORIES)
|
||||||
NUM_CHANNELS = 3
|
|
||||||
WIDTH = 640
|
|
||||||
HEIGHT = 360
|
|
||||||
|
|
||||||
#from tensorflow.keras.mixed_precision import Policy
|
#from tensorflow.keras.mixed_precision import Policy
|
||||||
|
|
||||||
#policy = Policy('mixed_float16')
|
#policy = Policy('mixed_float16')
|
||||||
#tf.keras.mixed_precision.set_global_policy(policy)
|
#tf.keras.mixed_precision.set_global_policy(policy)
|
||||||
|
|
||||||
|
def data_generator(videos, batch_size):
|
||||||
|
while True:
|
||||||
|
for video_details in videos:
|
||||||
|
video_path = os.path.join(os.path.dirname("test_data/validation/validation.json"), video_details["compressed_video_file"])
|
||||||
|
cap = cv2.VideoCapture(video_path)
|
||||||
|
|
||||||
|
feature_batch = []
|
||||||
|
compressed_frame_batch = []
|
||||||
|
|
||||||
|
while cap.isOpened():
|
||||||
|
ret, frame = cap.read()
|
||||||
|
if not ret:
|
||||||
|
break
|
||||||
|
|
||||||
|
combined_feature, compressed_frame = preprocess_frame(frame)
|
||||||
|
|
||||||
|
feature_batch.append(combined_feature)
|
||||||
|
compressed_frame_batch.append(compressed_frame)
|
||||||
|
|
||||||
|
if len(feature_batch) == batch_size:
|
||||||
|
yield (np.array(feature_batch), np.array(compressed_frame_batch))
|
||||||
|
feature_batch = []
|
||||||
|
compressed_frame_batch = []
|
||||||
|
|
||||||
|
cap.release()
|
||||||
|
|
||||||
|
# If there are frames left that don't fill a whole batch, send them anyway
|
||||||
|
if len(feature_batch) > 0:
|
||||||
|
yield (np.array(feature_batch), np.array(compressed_frame_batch))
|
||||||
|
|
||||||
class VideoCompressionModel(tf.keras.Model):
|
class VideoCompressionModel(tf.keras.Model):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
Reference in a new issue