semi-working
This commit is contained in:
parent
e7af02cb4f
commit
54fa90247a
4 changed files with 38 additions and 11 deletions
|
@ -16,7 +16,7 @@ from video_compression_model import VideoCompressionModel
|
||||||
COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
|
COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
|
||||||
MAX_FRAMES = 0 # Limit the number of frames processed
|
MAX_FRAMES = 0 # Limit the number of frames processed
|
||||||
CRF = 51
|
CRF = 51
|
||||||
SPEED = PRESET_SPEED_CATEGORIES.index("ultrafast")
|
SPEED = PRESET_SPEED_CATEGORIES.index("veryslow")
|
||||||
|
|
||||||
# Load the trained model
|
# Load the trained model
|
||||||
MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel})
|
MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel})
|
||||||
|
@ -30,21 +30,21 @@ def load_frame_from_video(video_file, frame_num):
|
||||||
ret, frame = cap.read()
|
ret, frame = cap.read()
|
||||||
if not ret:
|
if not ret:
|
||||||
return None
|
return None
|
||||||
#frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
||||||
cap.release()
|
cap.release()
|
||||||
|
|
||||||
return frame
|
return frame
|
||||||
|
|
||||||
def predict_frame(uncompressed_frame):
|
def predict_frame(uncompressed_frame):
|
||||||
|
|
||||||
#display_frame = np.clip(cv2.cvtColor(uncompressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
|
display_frame = np.clip(cv2.cvtColor(uncompressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
|
||||||
#cv2.imshow("uncomp", uncompressed_frame)
|
cv2.imshow("uncomp", uncompressed_frame)
|
||||||
|
|
||||||
frame = preprocess_frame(uncompressed_frame, CRF, SPEED)
|
frame = preprocess_frame(uncompressed_frame, CRF, SPEED)
|
||||||
|
|
||||||
compressed_frame = MODEL.predict([np.expand_dims(frame, axis=0)])[0]
|
compressed_frame = MODEL.predict([np.expand_dims(frame, axis=0)])[0]
|
||||||
|
compressed_frame = compressed_frame[:, :, :3] # Keep only the first 3 channels (BGR)
|
||||||
|
|
||||||
display_frame = np.clip(cv2.cvtColor(compressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
|
display_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
|
||||||
|
|
||||||
cv2.imshow("comp", display_frame)
|
cv2.imshow("comp", display_frame)
|
||||||
cv2.waitKey(1)
|
cv2.waitKey(1)
|
||||||
|
|
|
@ -22,5 +22,29 @@
|
||||||
"original_video_file": "Scene5.mkv",
|
"original_video_file": "Scene5.mkv",
|
||||||
"crf": 51,
|
"crf": 51,
|
||||||
"preset_speed": "veryslow"
|
"preset_speed": "veryslow"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"compressed_video_file": "Scene6_x264_crf-25_preset-medium.mkv",
|
||||||
|
"original_video_file": "Scene6.mkv",
|
||||||
|
"crf": 25,
|
||||||
|
"preset_speed": "medium"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"compressed_video_file": "Scene7_x264_crf-34_preset-slower.mkv",
|
||||||
|
"original_video_file": "Scene7.mkv",
|
||||||
|
"crf": 34,
|
||||||
|
"preset_speed": "slower"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"compressed_video_file": "Scene8_x264_crf-12_preset-faster.mkv",
|
||||||
|
"original_video_file": "Scene8.mkv",
|
||||||
|
"crf": 12,
|
||||||
|
"preset_speed": "faster"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"compressed_video_file": "Scene9_x264_crf-15_preset-slow.mkv",
|
||||||
|
"original_video_file": "Scene9.mkv",
|
||||||
|
"crf": 15,
|
||||||
|
"preset_speed": "slow"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
|
@ -16,7 +16,7 @@ from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
|
||||||
# Constants
|
# Constants
|
||||||
BATCH_SIZE = 16
|
BATCH_SIZE = 16
|
||||||
EPOCHS = 100
|
EPOCHS = 100
|
||||||
LEARNING_RATE = 0.001
|
LEARNING_RATE = 0.000001
|
||||||
MODEL_SAVE_FILE = "models/model.tf"
|
MODEL_SAVE_FILE = "models/model.tf"
|
||||||
MODEL_CHECKPOINT_DIR = "checkpoints"
|
MODEL_CHECKPOINT_DIR = "checkpoints"
|
||||||
EARLY_STOP = 10
|
EARLY_STOP = 10
|
||||||
|
|
|
@ -19,8 +19,9 @@ def data_generator(videos, batch_size):
|
||||||
# Iterate over each video
|
# Iterate over each video
|
||||||
for video_details in videos:
|
for video_details in videos:
|
||||||
# Get the paths for compressed and original (uncompressed) video files
|
# Get the paths for compressed and original (uncompressed) video files
|
||||||
video_path = os.path.join(os.path.dirname("test_data/validation/validation.json"), video_details["compressed_video_file"])
|
base_dir = os.path.dirname("test_data/validation/validation.json")
|
||||||
uncompressed_video_path = os.path.join(os.path.dirname("test_data/validation/validation.json"), video_details["original_video_file"])
|
video_path = os.path.join(base_dir, video_details["compressed_video_file"])
|
||||||
|
uncompressed_video_path = os.path.join(base_dir, video_details["original_video_file"])
|
||||||
|
|
||||||
CRF = video_details["crf"] / 51
|
CRF = video_details["crf"] / 51
|
||||||
SPEED = PRESET_SPEED_CATEGORIES.index(video_details["preset_speed"])
|
SPEED = PRESET_SPEED_CATEGORIES.index(video_details["preset_speed"])
|
||||||
|
@ -87,12 +88,14 @@ class VideoCompressionModel(tf.keras.Model):
|
||||||
tf.keras.layers.UpSampling2D((2, 2)),
|
tf.keras.layers.UpSampling2D((2, 2)),
|
||||||
tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'),
|
tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'),
|
||||||
tf.keras.layers.UpSampling2D((2, 2)),
|
tf.keras.layers.UpSampling2D((2, 2)),
|
||||||
tf.keras.layers.Conv2DTranspose(1, (3, 3), activation='sigmoid', padding='same')
|
tf.keras.layers.Conv2DTranspose(NUM_COLOUR_CHANNELS + 2, (3, 3), activation='sigmoid', padding='same')
|
||||||
])
|
])
|
||||||
|
|
||||||
def call(self, inputs):
|
def call(self, inputs):
|
||||||
# Encode the input
|
#print("Input shape:", inputs.shape)
|
||||||
encoded = self.encoder(inputs)
|
encoded = self.encoder(inputs)
|
||||||
# Decode the encoded representation
|
#print("Encoded shape:", encoded.shape)
|
||||||
decoded = self.decoder(encoded)
|
decoded = self.decoder(encoded)
|
||||||
|
#print("Decoded shape:", decoded.shape)
|
||||||
return decoded
|
return decoded
|
||||||
|
|
||||||
|
|
Reference in a new issue