semi-working

This commit is contained in:
Jordon Brooks 2023-08-13 20:48:00 +01:00
parent e7af02cb4f
commit 54fa90247a
4 changed files with 38 additions and 11 deletions

View file

@ -16,7 +16,7 @@ from video_compression_model import VideoCompressionModel
COMPRESSED_VIDEO_FILE = 'compressed_video.avi' COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
MAX_FRAMES = 0 # Limit the number of frames processed MAX_FRAMES = 0 # Limit the number of frames processed
CRF = 51 CRF = 51
SPEED = PRESET_SPEED_CATEGORIES.index("ultrafast") SPEED = PRESET_SPEED_CATEGORIES.index("veryslow")
# Load the trained model # Load the trained model
MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel}) MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel})
@ -30,21 +30,21 @@ def load_frame_from_video(video_file, frame_num):
ret, frame = cap.read() ret, frame = cap.read()
if not ret: if not ret:
return None return None
#frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
cap.release() cap.release()
return frame return frame
def predict_frame(uncompressed_frame): def predict_frame(uncompressed_frame):
#display_frame = np.clip(cv2.cvtColor(uncompressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8) display_frame = np.clip(cv2.cvtColor(uncompressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
#cv2.imshow("uncomp", uncompressed_frame) cv2.imshow("uncomp", uncompressed_frame)
frame = preprocess_frame(uncompressed_frame, CRF, SPEED) frame = preprocess_frame(uncompressed_frame, CRF, SPEED)
compressed_frame = MODEL.predict([np.expand_dims(frame, axis=0)])[0] compressed_frame = MODEL.predict([np.expand_dims(frame, axis=0)])[0]
compressed_frame = compressed_frame[:, :, :3] # Keep only the first 3 channels (BGR)
display_frame = np.clip(cv2.cvtColor(compressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8) display_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
cv2.imshow("comp", display_frame) cv2.imshow("comp", display_frame)
cv2.waitKey(1) cv2.waitKey(1)

View file

@ -22,5 +22,29 @@
"original_video_file": "Scene5.mkv", "original_video_file": "Scene5.mkv",
"crf": 51, "crf": 51,
"preset_speed": "veryslow" "preset_speed": "veryslow"
},
{
"compressed_video_file": "Scene6_x264_crf-25_preset-medium.mkv",
"original_video_file": "Scene6.mkv",
"crf": 25,
"preset_speed": "medium"
},
{
"compressed_video_file": "Scene7_x264_crf-34_preset-slower.mkv",
"original_video_file": "Scene7.mkv",
"crf": 34,
"preset_speed": "slower"
},
{
"compressed_video_file": "Scene8_x264_crf-12_preset-faster.mkv",
"original_video_file": "Scene8.mkv",
"crf": 12,
"preset_speed": "faster"
},
{
"compressed_video_file": "Scene9_x264_crf-15_preset-slow.mkv",
"original_video_file": "Scene9.mkv",
"crf": 15,
"preset_speed": "slow"
} }
] ]

View file

@ -16,7 +16,7 @@ from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
# Constants # Constants
BATCH_SIZE = 16 BATCH_SIZE = 16
EPOCHS = 100 EPOCHS = 100
LEARNING_RATE = 0.001 LEARNING_RATE = 0.000001
MODEL_SAVE_FILE = "models/model.tf" MODEL_SAVE_FILE = "models/model.tf"
MODEL_CHECKPOINT_DIR = "checkpoints" MODEL_CHECKPOINT_DIR = "checkpoints"
EARLY_STOP = 10 EARLY_STOP = 10

View file

@ -19,8 +19,9 @@ def data_generator(videos, batch_size):
# Iterate over each video # Iterate over each video
for video_details in videos: for video_details in videos:
# Get the paths for compressed and original (uncompressed) video files # Get the paths for compressed and original (uncompressed) video files
video_path = os.path.join(os.path.dirname("test_data/validation/validation.json"), video_details["compressed_video_file"]) base_dir = os.path.dirname("test_data/validation/validation.json")
uncompressed_video_path = os.path.join(os.path.dirname("test_data/validation/validation.json"), video_details["original_video_file"]) video_path = os.path.join(base_dir, video_details["compressed_video_file"])
uncompressed_video_path = os.path.join(base_dir, video_details["original_video_file"])
CRF = video_details["crf"] / 51 CRF = video_details["crf"] / 51
SPEED = PRESET_SPEED_CATEGORIES.index(video_details["preset_speed"]) SPEED = PRESET_SPEED_CATEGORIES.index(video_details["preset_speed"])
@ -87,12 +88,14 @@ class VideoCompressionModel(tf.keras.Model):
tf.keras.layers.UpSampling2D((2, 2)), tf.keras.layers.UpSampling2D((2, 2)),
tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'), tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'),
tf.keras.layers.UpSampling2D((2, 2)), tf.keras.layers.UpSampling2D((2, 2)),
tf.keras.layers.Conv2DTranspose(1, (3, 3), activation='sigmoid', padding='same') tf.keras.layers.Conv2DTranspose(NUM_COLOUR_CHANNELS + 2, (3, 3), activation='sigmoid', padding='same')
]) ])
def call(self, inputs): def call(self, inputs):
# Encode the input #print("Input shape:", inputs.shape)
encoded = self.encoder(inputs) encoded = self.encoder(inputs)
# Decode the encoded representation #print("Encoded shape:", encoded.shape)
decoded = self.decoder(encoded) decoded = self.decoder(encoded)
#print("Decoded shape:", decoded.shape)
return decoded return decoded