updated
This commit is contained in:
parent
60c6c97071
commit
93ccce5ec1
6 changed files with 46 additions and 34 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -12,5 +12,5 @@
|
||||||
!video_compression_model.py
|
!video_compression_model.py
|
||||||
!global_train.py
|
!global_train.py
|
||||||
!log.py
|
!log.py
|
||||||
!test_data/training.json
|
!test_data/training/training.json
|
||||||
!test_data/validation.json
|
!test_data/validation/validation.json
|
|
@ -3,14 +3,14 @@
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
from video_compression_model import VideoCompressionModel
|
from video_compression_model import PRESET_SPEED_CATEGORIES, VideoCompressionModel
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
CHUNK_SIZE = 10 # Adjust based on available memory and video resolution
|
CHUNK_SIZE = 24 # Adjust based on available memory and video resolution
|
||||||
COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
|
COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
|
||||||
MAX_FRAMES = 0 # Limit the number of frames processed
|
MAX_FRAMES = 0 # Limit the number of frames processed
|
||||||
CRF = 25.0 # Example CRF value
|
CRF = 24.0 # Example CRF value
|
||||||
PRESET_SPEED = 4 # Index for "fast" in our defined list
|
PRESET_SPEED = "veryslow" # Index for "fast" in our defined list
|
||||||
|
|
||||||
# Load the trained model
|
# Load the trained model
|
||||||
model = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel})
|
model = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel})
|
||||||
|
@ -42,6 +42,7 @@ def predict_frame(uncompressed_frame, model, crf_value, preset_speed_value):
|
||||||
|
|
||||||
compressed_frame = model.predict({
|
compressed_frame = model.predict({
|
||||||
"compressed_frame": uncompressed_frame,
|
"compressed_frame": uncompressed_frame,
|
||||||
|
"uncompressed_frame": uncompressed_frame,
|
||||||
"crf": crf_array,
|
"crf": crf_array,
|
||||||
"preset_speed": preset_speed_array
|
"preset_speed": preset_speed_array
|
||||||
})
|
})
|
||||||
|
@ -49,7 +50,7 @@ def predict_frame(uncompressed_frame, model, crf_value, preset_speed_value):
|
||||||
display_frame = np.clip(cv2.cvtColor(compressed_frame[0], cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
|
display_frame = np.clip(cv2.cvtColor(compressed_frame[0], cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
|
||||||
|
|
||||||
cv2.imshow("comp", display_frame)
|
cv2.imshow("comp", display_frame)
|
||||||
cv2.waitKey(10)
|
cv2.waitKey(1)
|
||||||
|
|
||||||
return compressed_frame[0]
|
return compressed_frame[0]
|
||||||
|
|
||||||
|
@ -70,7 +71,7 @@ if MAX_FRAMES != 0 and total_frames > MAX_FRAMES:
|
||||||
|
|
||||||
for i in range(total_frames):
|
for i in range(total_frames):
|
||||||
uncompressed_frame = load_frame_from_video(UNCOMPRESSED_VIDEO_FILE, frame_num=i)
|
uncompressed_frame = load_frame_from_video(UNCOMPRESSED_VIDEO_FILE, frame_num=i)
|
||||||
compressed_frame = predict_frame(uncompressed_frame, model, CRF, PRESET_SPEED)
|
compressed_frame = predict_frame(uncompressed_frame, model, CRF, PRESET_SPEED_CATEGORIES.index(PRESET_SPEED))
|
||||||
|
|
||||||
compressed_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
|
compressed_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
|
||||||
compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_RGB2BGR)
|
compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_RGB2BGR)
|
||||||
|
|
|
@ -1,73 +1,73 @@
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-51_preset-ultrafast.mkv",
|
"video_file": "x264_crf-51_preset-ultrafast.mkv",
|
||||||
"uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
|
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 51,
|
"crf": 51,
|
||||||
"preset_speed": "ultrafast"
|
"preset_speed": "ultrafast"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-16_preset-veryslow.mkv",
|
"video_file": "x264_crf-16_preset-veryslow.mkv",
|
||||||
"uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
|
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 16,
|
"crf": 16,
|
||||||
"preset_speed": "veryslow"
|
"preset_speed": "veryslow"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-18_preset-ultrafast.mkv",
|
"video_file": "x264_crf-18_preset-ultrafast.mkv",
|
||||||
"uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
|
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 18,
|
"crf": 18,
|
||||||
"preset_speed": "ultrafast"
|
"preset_speed": "ultrafast"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-18_preset-veryslow.mkv",
|
"video_file": "x264_crf-18_preset-veryslow.mkv",
|
||||||
"uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
|
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 18,
|
"crf": 18,
|
||||||
"preset_speed": "veryslow"
|
"preset_speed": "veryslow"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-50_preset-veryslow.mkv",
|
"video_file": "x264_crf-50_preset-veryslow.mkv",
|
||||||
"uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
|
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 50,
|
"crf": 50,
|
||||||
"preset_speed": "veryslow"
|
"preset_speed": "veryslow"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-51_preset-fast.mkv",
|
"video_file": "x264_crf-51_preset-fast.mkv",
|
||||||
"uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
|
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 51,
|
"crf": 51,
|
||||||
"preset_speed": "fast"
|
"preset_speed": "fast"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-51_preset-faster.mkv",
|
"video_file": "x264_crf-51_preset-faster.mkv",
|
||||||
"uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
|
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 51,
|
"crf": 51,
|
||||||
"preset_speed": "faster"
|
"preset_speed": "faster"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-51_preset-medium.mkv",
|
"video_file": "x264_crf-51_preset-medium.mkv",
|
||||||
"uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
|
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 51,
|
"crf": 51,
|
||||||
"preset_speed": "medium"
|
"preset_speed": "medium"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-51_preset-slow.mkv",
|
"video_file": "x264_crf-51_preset-slow.mkv",
|
||||||
"uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
|
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 51,
|
"crf": 51,
|
||||||
"preset_speed": "slow"
|
"preset_speed": "slow"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-51_preset-slower.mkv",
|
"video_file": "x264_crf-51_preset-slower.mkv",
|
||||||
"uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
|
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 51,
|
"crf": 51,
|
||||||
"preset_speed": "slower"
|
"preset_speed": "slower"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-51_preset-superfast.mkv",
|
"video_file": "x264_crf-51_preset-superfast.mkv",
|
||||||
"uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
|
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 51,
|
"crf": 51,
|
||||||
"preset_speed": "superfast"
|
"preset_speed": "superfast"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"video_file": "x264_crf-51_preset-veryfast.mkv",
|
"video_file": "x264_crf-51_preset-veryfast.mkv",
|
||||||
"uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
|
"uncompressed_video_file": "../x264_crf-5_preset-veryslow.mkv",
|
||||||
"crf": 51,
|
"crf": 51,
|
||||||
"preset_speed": "veryfast"
|
"preset_speed": "veryfast"
|
||||||
}
|
}
|
|
@ -1,8 +0,0 @@
|
||||||
[
|
|
||||||
{
|
|
||||||
"video_file": "x264_crf-16_preset-veryslow.mkv",
|
|
||||||
"uncompressed_video_file": "x264_crf-5_preset-veryslow.mkv",
|
|
||||||
"crf": 16,
|
|
||||||
"preset_speed": "veryslow"
|
|
||||||
}
|
|
||||||
]
|
|
9
test_data/validation/validation.json
Normal file
9
test_data/validation/validation.json
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
[
|
||||||
|
|
||||||
|
{
|
||||||
|
"video_file": "Scene2_x264_crf-51_preset-veryslow.mkv",
|
||||||
|
"uncompressed_video_file": "Scene2_x264_crf-5_preset-veryslow.mkv",
|
||||||
|
"crf": 51,
|
||||||
|
"preset_speed": "veryslow"
|
||||||
|
}
|
||||||
|
]
|
|
@ -18,10 +18,12 @@ from global_train import LOGGER
|
||||||
BATCH_SIZE = 4
|
BATCH_SIZE = 4
|
||||||
EPOCHS = 100
|
EPOCHS = 100
|
||||||
LEARNING_RATE = 0.000001
|
LEARNING_RATE = 0.000001
|
||||||
TRAIN_SAMPLES = 50
|
TRAIN_SAMPLES = 100
|
||||||
MODEL_SAVE_FILE = "models/model.tf"
|
MODEL_SAVE_FILE = "models/model.tf"
|
||||||
MODEL_CHECKPOINT_DIR = "checkpoints"
|
MODEL_CHECKPOINT_DIR = "checkpoints"
|
||||||
EARLY_STOP = 10
|
EARLY_STOP = 10
|
||||||
|
WIDTH = 638
|
||||||
|
HEIGHT = 360
|
||||||
|
|
||||||
def load_video_metadata(list_path):
|
def load_video_metadata(list_path):
|
||||||
LOGGER.trace(f"Entering: load_video_metadata({list_path})")
|
LOGGER.trace(f"Entering: load_video_metadata({list_path})")
|
||||||
|
@ -67,11 +69,11 @@ def load_video_samples(list_path, samples=TRAIN_SAMPLES):
|
||||||
compressed_frames, uncompressed_frames = [], []
|
compressed_frames, uncompressed_frames = [], []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cap = cv2.VideoCapture(os.path.join("test_data/", video_file))
|
cap = cv2.VideoCapture(os.path.join(os.path.dirname(list_path), video_file))
|
||||||
cap_uncompressed = cv2.VideoCapture(os.path.join("test_data/", uncompressed_video_file))
|
cap_uncompressed = cv2.VideoCapture(os.path.join(os.path.dirname(list_path), uncompressed_video_file))
|
||||||
|
|
||||||
if not cap.isOpened() or not cap_uncompressed.isOpened():
|
if not cap.isOpened() or not cap_uncompressed.isOpened():
|
||||||
raise RuntimeError(f"Could not open video files {video_file} or {uncompressed_video_file}")
|
raise RuntimeError(f"Could not open video files {video_file} or {uncompressed_video_file}, searched under: {os.path.dirname(list_path)}")
|
||||||
|
|
||||||
for _ in range(frames_per_video):
|
for _ in range(frames_per_video):
|
||||||
ret, frame_compressed = cap.read()
|
ret, frame_compressed = cap.read()
|
||||||
|
@ -79,6 +81,14 @@ def load_video_samples(list_path, samples=TRAIN_SAMPLES):
|
||||||
|
|
||||||
if not ret or not ret_uncompressed:
|
if not ret or not ret_uncompressed:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Check frame dimensions and resize if necessary
|
||||||
|
if frame.shape[:2] != (WIDTH, HEIGHT):
|
||||||
|
LOGGER.warn(f"Resizing video: {video_file}")
|
||||||
|
frame = cv2.resize(frame, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA)
|
||||||
|
if frame_compressed.shape[:2] != (WIDTH, HEIGHT):
|
||||||
|
LOGGER.warn(f"Resizing video: {uncompressed_video_file}")
|
||||||
|
frame_compressed = cv2.resize(frame_compressed, (WIDTH, HEIGHT), interpolation=cv2.INTER_AREA)
|
||||||
|
|
||||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||||
frame_compressed = cv2.cvtColor(frame_compressed, cv2.COLOR_BGR2RGB)
|
frame_compressed = cv2.cvtColor(frame_compressed, cv2.COLOR_BGR2RGB)
|
||||||
|
@ -149,8 +159,8 @@ def main():
|
||||||
|
|
||||||
# Load training and validation samples
|
# Load training and validation samples
|
||||||
LOGGER.debug("Loading training and validation samples.")
|
LOGGER.debug("Loading training and validation samples.")
|
||||||
training_samples = load_video_samples("test_data/training.json")
|
training_samples = load_video_samples("test_data/training/training.json")
|
||||||
validation_samples = load_video_samples("test_data/validation.json", args.training_samples // 2)
|
validation_samples = load_video_samples("test_data/validation/validation.json", args.training_samples // 2)
|
||||||
|
|
||||||
train_generator = VideoDataGenerator(training_samples, args.batch_size)
|
train_generator = VideoDataGenerator(training_samples, args.batch_size)
|
||||||
val_generator = VideoDataGenerator(validation_samples, args.batch_size)
|
val_generator = VideoDataGenerator(validation_samples, args.batch_size)
|
||||||
|
|
Reference in a new issue