From c7306a9d4827e1d73687ae673efce2640415ef7a Mon Sep 17 00:00:00 2001 From: Jordon Brooks Date: Mon, 24 Jul 2023 16:47:07 +0100 Subject: [PATCH] Initial Commit --- .gitignore | 168 ++----------------------------------- DeepEncode.py | 64 ++++++++++++++ train_model.py | 91 ++++++++++++++++++++ video_compression_model.py | 27 ++++++ 4 files changed, 191 insertions(+), 159 deletions(-) create mode 100644 DeepEncode.py create mode 100644 train_model.py create mode 100644 video_compression_model.py diff --git a/.gitignore b/.gitignore index 5d381cc..dae2ae5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,162 +1,12 @@ -# ---> Python -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class +* +!*/ -# C extensions -*.so +!.github/** +!LICENSE +!README.md +!.gitignore -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. -# https://pdm.fming.dev/#use-with-ide -.pdm.toml - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +!DeepEncode.py +!train_model.py +!video_compression_model.py diff --git a/DeepEncode.py b/DeepEncode.py new file mode 100644 index 0000000..fc7ba24 --- /dev/null +++ b/DeepEncode.py @@ -0,0 +1,64 @@ +import tensorflow as tf +import numpy as np +import cv2 +from video_compression_model import VideoCompressionModel + +# Constants +NUM_CHANNELS = 3 + +# Step 2: Load the trained model +model = tf.keras.models.load_model('ai_rate_control_model.keras', custom_objects={'VideoCompressionModel': VideoCompressionModel}) + +# Step 3: Load the uncompressed video +UNCOMPRESSED_VIDEO_FILE = 'test_video.mkv' + +def load_frames_from_video(video_file, num_frames = 0): + print("Extracting video frames...") + cap = cv2.VideoCapture(video_file) + frames = [] + count = 0 + while True: + ret, frame = cap.read() + if not ret: + print("Max frames from file reached") + break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(frame) + count += 1 + if num_frames == 0 or count >= num_frames: + print("Max Frames wanted reached: ", num_frames) + break + cap.release() + print("Extraction Complete") + return frames + +uncompressed_frames = load_frames_from_video(UNCOMPRESSED_VIDEO_FILE, 200) +if len(uncompressed_frames) == 0 or None: + print("IO ERROR!") + exit() + +uncompressed_frames = np.array(uncompressed_frames) / 255.0 + +if len(uncompressed_frames) == 0 or None: + print("np.array ERROR!") + exit() + +# Step 4: Compress the video frames using the loaded model +compressed_frames = model.predict(uncompressed_frames) + +# Step 5: Save the compressed video frames +COMPRESSED_VIDEO_FILE = 'compressed_video.mkv' + +def save_frames_as_video(frames, video_file): + print("Saving video frames...") + height, width = frames[0].shape[:2] + fourcc = cv2.VideoWriter_fourcc(*'XVID') + out = cv2.VideoWriter(video_file, fourcc, 24.0, (width, height)) + for frame in frames: + frame = np.clip(frame * 255.0, 0, 255).astype(np.uint8) + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + out.write(frame) + out.release() + +save_frames_as_video(compressed_frames, COMPRESSED_VIDEO_FILE) +print("Compression completed.") diff --git a/train_model.py b/train_model.py new file mode 100644 index 0000000..837124a --- /dev/null +++ b/train_model.py @@ -0,0 +1,91 @@ +import os +import tensorflow as tf +import numpy as np +import cv2 +from video_compression_model import VideoCompressionModel + +# Constants +NUM_CHANNELS = 3 # Number of color channels in the video frames (RGB images have 3 channels) +BATCH_SIZE = 32 # Batch size used during training +EPOCHS = 20 # Number of training epochs +CHECKPOINT_FILEPATH = "models/checkpoint-{epoch:02d}.keras" + +# Step 1: Data Preparation +TRAIN_VIDEO_FILE = 'native_video.mkv' # The training video file name +VAL_VIDEO_FILE = 'training_video.mkv' # The validation video file name +TRAIN_SAMPLES = 2 # Number of video frames used for training +VAL_SAMPLES = 2 # Number of video frames used for validation + +def load_frames_from_video(video_file, num_frames): + print("Extracting video frames...") + cap = cv2.VideoCapture(video_file) + frames = [] + count = 0 + frame_width, frame_height = None, None # Initialize the frame dimensions + while True: + ret, frame = cap.read() + if not ret: + break + if frame_width is None or frame_height is None: + frame_height, frame_width = frame.shape[:2] # Get the frame dimensions from the first frame + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(frame) + count += 1 + if count >= num_frames: + break + cap.release() + return frames, frame_width, frame_height # Return frames and frame dimensions + +train_frames, FRAME_WIDTH, FRAME_HEIGHT = load_frames_from_video(TRAIN_VIDEO_FILE, num_frames=TRAIN_SAMPLES) +val_frames, _, _ = load_frames_from_video(VAL_VIDEO_FILE, num_frames=VAL_SAMPLES) + + +print("Number of training frames:", len(train_frames)) +print("Number of validation frames:", len(val_frames)) + +def preprocess(frames): + frames = np.array(frames) / 255.0 + return frames + +train_frames = preprocess(train_frames) +val_frames = preprocess(val_frames) + +print("training frames:", len(train_frames)) +print("validation frames:", len(val_frames)) + +# Step 2: Model Architecture +model = VideoCompressionModel() + +model.compile(loss='mean_squared_error', optimizer='adam', run_eagerly=True) + +# Adjusting the input shape for training and validation +frame_height, frame_width = train_frames[0].shape[:2] + +# Use the resized frames as target data +train_targets = train_frames +val_targets = val_frames + +# Create the "models" directory if it doesn't exist +os.makedirs("models", exist_ok=True) + +# Create the ModelCheckpoint callback +model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( + filepath=CHECKPOINT_FILEPATH, + save_weights_only=False, # Save the entire model (including architecture) + monitor='val_loss', # Metric to monitor for saving the best model (optional) + save_best_only=True # Save only the best model based on the monitored metric (optional) +) + +print("\nTraining the model...") +model.fit( + train_frames, [train_targets, tf.zeros_like(train_targets)], + batch_size=BATCH_SIZE, + epochs=EPOCHS, + validation_data=(val_frames, [val_targets, tf.zeros_like(val_targets)]), + callbacks=[model_checkpoint_callback] # Add the ModelCheckpoint callback +) +print("\nTraining completed.") + +# Step 3: Save the trained model +model.save('ai_rate_control_model.keras') +print("Model saved successfully!") diff --git a/video_compression_model.py b/video_compression_model.py new file mode 100644 index 0000000..7f49848 --- /dev/null +++ b/video_compression_model.py @@ -0,0 +1,27 @@ +import tensorflow as tf + +class VideoCompressionModel(tf.keras.Model): + def __init__(self, NUM_CHANNELS=3): + super(VideoCompressionModel, self).__init__() + + # Encoder layers + self.encoder = tf.keras.Sequential([ + tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same', input_shape=(None, None, NUM_CHANNELS)), + # Add more encoder layers as needed + ]) + + # Decoder layers + self.decoder = tf.keras.Sequential([ + tf.keras.layers.Conv2DTranspose(32, (3, 3), activation='relu', padding='same'), + # Add more decoder layers as needed + tf.keras.layers.Conv2D(NUM_CHANNELS, (3, 3), activation='sigmoid', padding='same') # Output layer for video frames + ]) + + def call(self, inputs): + # Encoding the video frames + compressed_representation = self.encoder(inputs) + + # Decoding to generate compressed video frames + reconstructed_frames = self.decoder(compressed_representation) + + return reconstructed_frames \ No newline at end of file