update

commit 8df4df7972, parent 4d29fffba1
3 changed files with 51 additions and 20 deletions

@@ -7,13 +7,13 @@ import numpy as np
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

 import tensorflow as tf
-from featureExtraction import combined, combined_loss, psnr, scale_crf, scale_speed_preset, ssim
+from featureExtraction import combined, combined_loss, combined_loss_weighted_psnr, psnr, scale_crf, scale_speed_preset, ssim
 from globalVars import PRESET_SPEED_CATEGORIES, clear_screen
 from video_compression_model import VideoCompressionModel, combine_batch

 # Constants
 COMPRESSED_VIDEO_FILE = 'compressed_video.avi'
-MAX_FRAMES = 200  # Limit the number of frames processed
+MAX_FRAMES = 0  # Limit the number of frames processed
 CRF = 10
 SPEED = "ultrafast"
 MODEL_PATH = 'models/model.tf'

@@ -33,7 +33,7 @@ def parse_arguments():
     parser.add_argument('-p', '--model_path', default=MODEL_PATH, help='Path to the trained model')
     parser.add_argument('-i', '--uncompressed_video_file', default=UNCOMPRESSED_VIDEO_FILE, help='Path to the uncompressed video file')
     parser.add_argument('-d', '--display_output', action='store_true', default=DISPLAY_OUTPUT, help='Display real-time output to screen')
-    parser.add_argument('--keep_black_bars', action='store_true', help='Keep black bars from the video', default=False)
+    parser.add_argument('--keep_black_bars', action='store_false', help='Keep black bars from the video', default=True)

     args = parser.parse_args()

@@ -95,18 +95,35 @@ def load_frame_from_video(video_file, frame_num):


 def predict_frame(uncompressed_frame, model, crf, speed):
     # Scale the CRF and Speed values
     scaled_crf = scale_crf(crf)
     scaled_speed = scale_speed_preset(PRESET_SPEED_CATEGORIES.index(speed))
-    frame = combine_batch(uncompressed_frame, scaled_crf, scaled_speed, resize=False)
-    compressed_frame = model.predict([np.expand_dims(frame, axis=0)])[0]
-    return np.clip(compressed_frame[:, :, :3] * 255.0, 0, 255).astype(np.uint8)
+
+    # Preprocess the frame
+    frame = combine_batch(uncompressed_frame, resize=False)
+
+    # Predict using the model
+    inputs = {
+        'image': np.expand_dims(frame, axis=0),
+        'CRF': np.array([scaled_crf]),
+        'Speed': np.array([scaled_speed])
+    }
+    compressed_frame = model.predict(inputs)[0]
+
+    # Post-process the output frame
+    return np.clip(compressed_frame * 255.0, 0, 255).astype(np.uint8)


 def main():
-    model = tf.keras.models.load_model(MODEL_PATH, custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr, 'ssim': ssim, 'combined': combined, 'combined_loss': combined_loss})
+    model = tf.keras.models.load_model(MODEL_PATH, custom_objects={'VideoCompressionModel': VideoCompressionModel, 'psnr': psnr, 'ssim': ssim, 'combined': combined, 'combined_loss': combined_loss, 'combined_loss_weighted_psnr': combined_loss_weighted_psnr})
     cap = cv2.VideoCapture(UNCOMPRESSED_VIDEO_FILE)

-    total_frames = min(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), MAX_FRAMES)
+    if MAX_FRAMES > 0:
+        total_frames = min(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), MAX_FRAMES)
+    else:
+        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

     height, width, fps = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FPS))

     cap.release()
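
Note on the predict_frame rewrite above: the model is now called with a dict keyed by input name instead of a single concatenated tensor. A minimal sketch of that calling convention, using a toy stand-in model (the real VideoCompressionModel is assumed to expose inputs named 'image', 'CRF' and 'Speed'):

    import numpy as np
    import tensorflow as tf

    # Toy stand-in with the same input names the diff relies on.
    image_in = tf.keras.Input(shape=(64, 64, 3), name='image')
    crf_in = tf.keras.Input(shape=(1,), name='CRF')
    speed_in = tf.keras.Input(shape=(1,), name='Speed')
    cs = tf.keras.layers.Concatenate()([crf_in, speed_in])   # (batch, 2)
    cs = tf.keras.layers.Dense(3)(cs)                        # (batch, 3)
    cs = tf.keras.layers.Reshape((1, 1, 3))(cs)              # broadcastable over H, W
    x = tf.keras.layers.Conv2D(3, 3, padding='same')(image_in)
    x = tf.keras.layers.Lambda(lambda t: tf.nn.sigmoid(t[0] + t[1]))([x, cs])
    model = tf.keras.Model(inputs=[image_in, crf_in, speed_in], outputs=x)

    # Keras matches dict keys to input names, so argument order cannot drift.
    frame = np.random.rand(64, 64, 3).astype(np.float32)
    out = model.predict({
        'image': np.expand_dims(frame, axis=0),
        'CRF': np.array([[0.5]]),
        'Speed': np.array([[0.3]]),
    })[0]
    print(out.shape)  # (64, 64, 3)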

@@ -127,6 +144,7 @@ def main():

         compressed_frame = predict_frame(uncompressed_frame, model, CRF, SPEED)
         compressed_frame = cv2.resize(compressed_frame, (width, height))
+        compressed_frame = cv2.cvtColor(compressed_frame, cv2.COLOR_RGB2BGR)

         out.write(compressed_frame)
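
Note on the added cvtColor call: OpenCV's VideoWriter and imwrite expect BGR channel order, while the model's output is RGB; without the conversion, red and blue end up swapped in the written video. A quick self-contained check:

    import cv2
    import numpy as np

    rgb = np.zeros((1, 1, 3), dtype=np.uint8)
    rgb[0, 0] = (255, 0, 0)                     # pure red in RGB order
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    print(bgr[0, 0])                            # [0 0 255]: OpenCV reads B=0, G=0, R=255, i.e. red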

@@ -42,8 +42,8 @@ from globalVars import HEIGHT, WIDTH, MAX_FRAMES, LOGGER, clear_screen, load_vid

 # Constants
 BATCH_SIZE = 25
-EPOCHS = 100
-LEARNING_RATE = 0.0001
+EPOCHS = 1000
+LEARNING_RATE = 0.005
 DECAY_STEPS = 160
 DECAY_RATE = 0.9
 MODEL_SAVE_FILE = "models/model.tf"
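
Note on the new training constants: the commit raises the initial learning rate 50x (0.0001 to 0.005) and extends training to 1000 epochs. DECAY_STEPS and DECAY_RATE suggest the rate is decayed during training; a minimal sketch of the usual pairing, assuming (the diff does not show this wiring) that the constants feed an exponential-decay schedule:

    import tensorflow as tf

    # Assumed wiring, not shown in the diff.
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=0.005,  # LEARNING_RATE
        decay_steps=160,              # DECAY_STEPS
        decay_rate=0.9,               # DECAY_RATE
    )
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
    # Effective rate after n steps: 0.005 * 0.9 ** (n / 160)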

@@ -66,8 +66,10 @@ class ImageLoggingCallback(Callback):
         return np.stack(converted, axis=0)

     def on_epoch_end(self, epoch, logs=None):
+        # where total_batches is the number of batches in the validation dataset
+        skip_batches = np.random.randint(0, 100)
         # Get the first batch from the validation dataset
-        validation_data = next(iter(self.validation_dataset.take(1)))
+        validation_data = next(iter(self.validation_dataset.skip(skip_batches).take(1)))

         # Extract the inputs from the batch_input_images dictionary
         actual_images = validation_data[0]['image']
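
Note on the random-batch change above: skip(np.random.randint(0, 100)) assumes the validation set always has at least 101 batches; on a shorter set, take(1) after the skip can come back empty and next() raises StopIteration. A safer variant (hypothetical helper, not part of the commit):

    import numpy as np
    import tensorflow as tf

    def random_validation_batch(dataset):
        # Bound the skip by the dataset's real size. cardinality() can report
        # UNKNOWN (-2) for generator-backed datasets; max(..., 1) then falls
        # back to picking the first batch.
        n_batches = int(dataset.cardinality().numpy())
        skip = np.random.randint(0, max(n_batches, 1))
        return next(iter(dataset.skip(skip).take(1)))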

@@ -82,7 +84,7 @@ class ImageLoggingCallback(Callback):

         # Save the reconstructed frame to the specified folder
         reconstructed_path = os.path.join(self.log_dir, f"epoch_{epoch}.png")
-        cv2.imwrite(reconstructed_path, reconstructed_frame[0])  # Saving only the first image as an example
+        cv2.imwrite(reconstructed_path, cv2.cvtColor(reconstructed_frame[0], cv2.COLOR_RGB2BGR))  # Saving only the first image as an example

         # Log images to TensorBoard
         with self.writer.as_default():

@@ -145,13 +147,13 @@ def main():
     # Load all video metadata
     all_videos = load_video_metadata("test_data/validation/validation.json")

-    tf.random.set_seed(RANDOM_SEED)
+    #tf.random.set_seed(RANDOM_SEED)

     # Shuffle the data using the specified seed
     random.shuffle(all_videos, random.seed(RANDOM_SEED))

     # Split into training and validation
-    split_index = int(0.6 * len(all_videos))
+    split_index = int(0.7 * len(all_videos))
     training_videos = all_videos[:split_index]
     validation_videos = all_videos[split_index:]
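
Side note on the seeded shuffle kept as context above: random.shuffle(all_videos, random.seed(RANDOM_SEED)) only works because random.seed() returns None, which shuffle accepts as "use the default RNG"; shuffle's second parameter was removed in Python 3.11, so that line breaks there. The equivalent, forward-compatible form:

    import random

    random.seed(RANDOM_SEED)   # RANDOM_SEED as defined in the training script
    random.shuffle(all_videos)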

@@ -166,7 +168,14 @@ def main():

     if args.continue_training:
-        MODEL = tf.keras.models.load_model(args.continue_training)
+        MODEL = tf.keras.models.load_model(args.continue_training, custom_objects={
+            'VideoCompressionModel': VideoCompressionModel,
+            'psnr': psnr,
+            'ssim': ssim,
+            'combined': combined,
+            'combined_loss': combined_loss,
+            'combined_loss_weighted_psnr': combined_loss_weighted_psnr
+        })
     else:
         MODEL = VideoCompressionModel()
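
Keras cannot deserialize a saved model that references custom classes, losses, or metrics unless it can resolve their names, hence the custom_objects dict here and in the prediction script. An alternative that avoids repeating the dict at every load_model call (an option, not what this commit does) is to register each custom piece once:

    import tensorflow as tf

    @tf.keras.utils.register_keras_serializable(package='video')
    def psnr(y_true, y_pred):
        # Stand-in body; the real psnr metric lives in featureExtraction.
        return tf.image.psnr(y_true, y_pred, max_val=1.0)

    # After registration, tf.keras.models.load_model(path) resolves 'psnr'
    # without a custom_objects entry.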

@@ -78,7 +78,7 @@ def create_dataset(videos, batch_size, max_frames=None):
         output_signature=output_signature
     )

-    dataset = dataset.batch(batch_size).shuffle(20).prefetch(1)
+    dataset = dataset.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)

     return dataset
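
Note on moving shuffle before batch: shuffling after batching only permutes whole batches, so examples inside a batch stay glued together across epochs; shuffling elements first mixes examples into different batches. A small demonstration:

    import tensorflow as tf

    ds = tf.data.Dataset.range(10)
    print(list(ds.batch(5).shuffle(2).as_numpy_iterator()))
    # e.g. [array([5..9]), array([0..4])] - intact blocks, merely reordered
    print(list(ds.shuffle(10).batch(5).as_numpy_iterator()))
    # e.g. [array([3, 0, 7, 9, 1]), array([5, 2, 8, 4, 6])] - examples mixed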

@@ -160,13 +160,16 @@ class VideoCompressionModel(tf.keras.Model):
         # New shape: [batch_size, 1, 1, 128]
         crf_speed_features = tf.reshape(crf_speed_features, [-1, 1, 1, 128])

-        # Tile the tensor to match spatial dimensions of encoded tensor
-        # Tiled shape: [batch_size, 90, 160, 128]
-        crf_speed_features = tf.tile(crf_speed_features, [1, 90, 160, 1])
-
         # Pass the image through the encoder
         encoded = self.encoder(image)

+        # Dynamically compute the spatial dimensions of the encoded tensor
+        encoded_shape = tf.shape(encoded)
+        height, width = encoded_shape[1], encoded_shape[2]
+
+        # Tile the crf_speed_features tensor to match the spatial dimensions of the encoded tensor
+        crf_speed_features = tf.tile(crf_speed_features, [1, height, width, 1])
+
         # Concatenate the encoded tensor with the crf_speed_features tensor
         combined_features = tf.concat([encoded, crf_speed_features], axis=-1)
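
Note on the tiling change above: the hard-coded [1, 90, 160, 1] tile only worked for one input resolution; reading the encoder's output shape at runtime lets the conditioning features follow whatever spatial grid the encoder actually produces. The pattern in isolation (toy shapes, not the model's real dimensions):

    import tensorflow as tf

    crf_speed = tf.random.uniform([2, 1, 1, 128])  # per-sample feature vector
    encoded = tf.random.uniform([2, 45, 80, 64])   # toy encoder output
    shape = tf.shape(encoded)
    tiled = tf.tile(crf_speed, [1, shape[1], shape[2], 1])
    combined = tf.concat([encoded, tiled], axis=-1)
    print(combined.shape)  # (2, 45, 80, 192)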

@@ -176,3 +179,4 @@ class VideoCompressionModel(tf.keras.Model):
         return decoded
+