semi-working
This commit is contained in:
parent
e7af02cb4f
commit
54fa90247a
4 changed files with 38 additions and 11 deletions
|
@ -16,7 +16,7 @@ from video_compression_model import VideoCompressionModel
|
|||
COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
|
||||
MAX_FRAMES = 0 # Limit the number of frames processed
|
||||
CRF = 51
|
||||
SPEED = PRESET_SPEED_CATEGORIES.index("ultrafast")
|
||||
SPEED = PRESET_SPEED_CATEGORIES.index("veryslow")
|
||||
|
||||
# Load the trained model
|
||||
MODEL = tf.keras.models.load_model('models/model.tf', custom_objects={'VideoCompressionModel': VideoCompressionModel})
|
||||
|
@ -30,21 +30,21 @@ def load_frame_from_video(video_file, frame_num):
|
|||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
return None
|
||||
#frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
cap.release()
|
||||
|
||||
return frame
|
||||
|
||||
def predict_frame(uncompressed_frame):
|
||||
|
||||
#display_frame = np.clip(cv2.cvtColor(uncompressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
|
||||
#cv2.imshow("uncomp", uncompressed_frame)
|
||||
display_frame = np.clip(cv2.cvtColor(uncompressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
|
||||
cv2.imshow("uncomp", uncompressed_frame)
|
||||
|
||||
frame = preprocess_frame(uncompressed_frame, CRF, SPEED)
|
||||
|
||||
compressed_frame = MODEL.predict([np.expand_dims(frame, axis=0)])[0]
|
||||
compressed_frame = compressed_frame[:, :, :3] # Keep only the first 3 channels (BGR)
|
||||
|
||||
display_frame = np.clip(cv2.cvtColor(compressed_frame, cv2.COLOR_BGR2RGB) * 255.0, 0, 255).astype(np.uint8)
|
||||
display_frame = np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)
|
||||
|
||||
cv2.imshow("comp", display_frame)
|
||||
cv2.waitKey(1)
|
||||
|
|
|
@ -22,5 +22,29 @@
|
|||
"original_video_file": "Scene5.mkv",
|
||||
"crf": 51,
|
||||
"preset_speed": "veryslow"
|
||||
},
|
||||
{
|
||||
"compressed_video_file": "Scene6_x264_crf-25_preset-medium.mkv",
|
||||
"original_video_file": "Scene6.mkv",
|
||||
"crf": 25,
|
||||
"preset_speed": "medium"
|
||||
},
|
||||
{
|
||||
"compressed_video_file": "Scene7_x264_crf-34_preset-slower.mkv",
|
||||
"original_video_file": "Scene7.mkv",
|
||||
"crf": 34,
|
||||
"preset_speed": "slower"
|
||||
},
|
||||
{
|
||||
"compressed_video_file": "Scene8_x264_crf-12_preset-faster.mkv",
|
||||
"original_video_file": "Scene8.mkv",
|
||||
"crf": 12,
|
||||
"preset_speed": "faster"
|
||||
},
|
||||
{
|
||||
"compressed_video_file": "Scene9_x264_crf-15_preset-slow.mkv",
|
||||
"original_video_file": "Scene9.mkv",
|
||||
"crf": 15,
|
||||
"preset_speed": "slow"
|
||||
}
|
||||
]
|
||||
|
|
|
@ -16,7 +16,7 @@ from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER
|
|||
# Constants
|
||||
BATCH_SIZE = 16
|
||||
EPOCHS = 100
|
||||
LEARNING_RATE = 0.001
|
||||
LEARNING_RATE = 0.000001
|
||||
MODEL_SAVE_FILE = "models/model.tf"
|
||||
MODEL_CHECKPOINT_DIR = "checkpoints"
|
||||
EARLY_STOP = 10
|
||||
|
|
|
@ -19,8 +19,9 @@ def data_generator(videos, batch_size):
|
|||
# Iterate over each video
|
||||
for video_details in videos:
|
||||
# Get the paths for compressed and original (uncompressed) video files
|
||||
video_path = os.path.join(os.path.dirname("test_data/validation/validation.json"), video_details["compressed_video_file"])
|
||||
uncompressed_video_path = os.path.join(os.path.dirname("test_data/validation/validation.json"), video_details["original_video_file"])
|
||||
base_dir = os.path.dirname("test_data/validation/validation.json")
|
||||
video_path = os.path.join(base_dir, video_details["compressed_video_file"])
|
||||
uncompressed_video_path = os.path.join(base_dir, video_details["original_video_file"])
|
||||
|
||||
CRF = video_details["crf"] / 51
|
||||
SPEED = PRESET_SPEED_CATEGORIES.index(video_details["preset_speed"])
|
||||
|
@ -87,12 +88,14 @@ class VideoCompressionModel(tf.keras.Model):
|
|||
tf.keras.layers.UpSampling2D((2, 2)),
|
||||
tf.keras.layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'),
|
||||
tf.keras.layers.UpSampling2D((2, 2)),
|
||||
tf.keras.layers.Conv2DTranspose(1, (3, 3), activation='sigmoid', padding='same')
|
||||
tf.keras.layers.Conv2DTranspose(NUM_COLOUR_CHANNELS + 2, (3, 3), activation='sigmoid', padding='same')
|
||||
])
|
||||
|
||||
def call(self, inputs):
    """Run ``inputs`` through the encoder and then the decoder.

    Args:
        inputs: Value handed unchanged to ``self.encoder``.

    Returns:
        Whatever ``self.decoder`` produces for the encoder's output.
    """
    # Encode the input
    encoded = self.encoder(inputs)
    # Decode the encoded representation
    return self.decoder(encoded)
|
||||
|
||||
|
|
Reference in a new issue