Initial Commit
This commit is contained in:
parent
645b6c29f7
commit
c7306a9d48
4 changed files with 191 additions and 159 deletions
168
.gitignore
vendored
168
.gitignore
vendored
|
@ -1,162 +1,12 @@
|
|||
# ---> Python
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*
|
||||
!*/
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
!.github/**
|
||||
!LICENSE
|
||||
!README.md
|
||||
!.gitignore
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
!DeepEncode.py
|
||||
!train_model.py
|
||||
!video_compression_model.py
|
||||
|
|
64
DeepEncode.py
Normal file
64
DeepEncode.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
import tensorflow as tf
|
||||
import numpy as np
|
||||
import cv2
|
||||
from video_compression_model import VideoCompressionModel
|
||||
|
||||
# Constants
|
||||
NUM_CHANNELS = 3
|
||||
|
||||
# Step 2: Load the trained model
|
||||
model = tf.keras.models.load_model('ai_rate_control_model.keras', custom_objects={'VideoCompressionModel': VideoCompressionModel})
|
||||
|
||||
# Step 3: Load the uncompressed video
|
||||
UNCOMPRESSED_VIDEO_FILE = 'test_video.mkv'
|
||||
|
||||
def load_frames_from_video(video_file, num_frames = 0):
|
||||
print("Extracting video frames...")
|
||||
cap = cv2.VideoCapture(video_file)
|
||||
frames = []
|
||||
count = 0
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
print("Max frames from file reached")
|
||||
break
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
frames.append(frame)
|
||||
count += 1
|
||||
if num_frames == 0 or count >= num_frames:
|
||||
print("Max Frames wanted reached: ", num_frames)
|
||||
break
|
||||
cap.release()
|
||||
print("Extraction Complete")
|
||||
return frames
|
||||
|
||||
uncompressed_frames = load_frames_from_video(UNCOMPRESSED_VIDEO_FILE, 200)
|
||||
if len(uncompressed_frames) == 0 or None:
|
||||
print("IO ERROR!")
|
||||
exit()
|
||||
|
||||
uncompressed_frames = np.array(uncompressed_frames) / 255.0
|
||||
|
||||
if len(uncompressed_frames) == 0 or None:
|
||||
print("np.array ERROR!")
|
||||
exit()
|
||||
|
||||
# Step 4: Compress the video frames using the loaded model
|
||||
compressed_frames = model.predict(uncompressed_frames)
|
||||
|
||||
# Step 5: Save the compressed video frames
|
||||
COMPRESSED_VIDEO_FILE = 'compressed_video.mkv'
|
||||
|
||||
def save_frames_as_video(frames, video_file):
|
||||
print("Saving video frames...")
|
||||
height, width = frames[0].shape[:2]
|
||||
fourcc = cv2.VideoWriter_fourcc(*'XVID')
|
||||
out = cv2.VideoWriter(video_file, fourcc, 24.0, (width, height))
|
||||
for frame in frames:
|
||||
frame = np.clip(frame * 255.0, 0, 255).astype(np.uint8)
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
||||
out.write(frame)
|
||||
out.release()
|
||||
|
||||
save_frames_as_video(compressed_frames, COMPRESSED_VIDEO_FILE)
|
||||
print("Compression completed.")
|
91
train_model.py
Normal file
91
train_model.py
Normal file
|
@ -0,0 +1,91 @@
|
|||
import os
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
import cv2
|
||||
from video_compression_model import VideoCompressionModel
|
||||
|
||||
# Constants
|
||||
NUM_CHANNELS = 3 # Number of color channels in the video frames (RGB images have 3 channels)
|
||||
BATCH_SIZE = 32 # Batch size used during training
|
||||
EPOCHS = 20 # Number of training epochs
|
||||
CHECKPOINT_FILEPATH = "models/checkpoint-{epoch:02d}.keras"
|
||||
|
||||
# Step 1: Data Preparation
|
||||
TRAIN_VIDEO_FILE = 'native_video.mkv' # The training video file name
|
||||
VAL_VIDEO_FILE = 'training_video.mkv' # The validation video file name
|
||||
TRAIN_SAMPLES = 2 # Number of video frames used for training
|
||||
VAL_SAMPLES = 2 # Number of video frames used for validation
|
||||
|
||||
def load_frames_from_video(video_file, num_frames):
|
||||
print("Extracting video frames...")
|
||||
cap = cv2.VideoCapture(video_file)
|
||||
frames = []
|
||||
count = 0
|
||||
frame_width, frame_height = None, None # Initialize the frame dimensions
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
if frame_width is None or frame_height is None:
|
||||
frame_height, frame_width = frame.shape[:2] # Get the frame dimensions from the first frame
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
frames.append(frame)
|
||||
count += 1
|
||||
if count >= num_frames:
|
||||
break
|
||||
cap.release()
|
||||
return frames, frame_width, frame_height # Return frames and frame dimensions
|
||||
|
||||
train_frames, FRAME_WIDTH, FRAME_HEIGHT = load_frames_from_video(TRAIN_VIDEO_FILE, num_frames=TRAIN_SAMPLES)
|
||||
val_frames, _, _ = load_frames_from_video(VAL_VIDEO_FILE, num_frames=VAL_SAMPLES)
|
||||
|
||||
|
||||
print("Number of training frames:", len(train_frames))
|
||||
print("Number of validation frames:", len(val_frames))
|
||||
|
||||
def preprocess(frames):
|
||||
frames = np.array(frames) / 255.0
|
||||
return frames
|
||||
|
||||
train_frames = preprocess(train_frames)
|
||||
val_frames = preprocess(val_frames)
|
||||
|
||||
print("training frames:", len(train_frames))
|
||||
print("validation frames:", len(val_frames))
|
||||
|
||||
# Step 2: Model Architecture
|
||||
model = VideoCompressionModel()
|
||||
|
||||
model.compile(loss='mean_squared_error', optimizer='adam', run_eagerly=True)
|
||||
|
||||
# Adjusting the input shape for training and validation
|
||||
frame_height, frame_width = train_frames[0].shape[:2]
|
||||
|
||||
# Use the resized frames as target data
|
||||
train_targets = train_frames
|
||||
val_targets = val_frames
|
||||
|
||||
# Create the "models" directory if it doesn't exist
|
||||
os.makedirs("models", exist_ok=True)
|
||||
|
||||
# Create the ModelCheckpoint callback
|
||||
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||
filepath=CHECKPOINT_FILEPATH,
|
||||
save_weights_only=False, # Save the entire model (including architecture)
|
||||
monitor='val_loss', # Metric to monitor for saving the best model (optional)
|
||||
save_best_only=True # Save only the best model based on the monitored metric (optional)
|
||||
)
|
||||
|
||||
print("\nTraining the model...")
|
||||
model.fit(
|
||||
train_frames, [train_targets, tf.zeros_like(train_targets)],
|
||||
batch_size=BATCH_SIZE,
|
||||
epochs=EPOCHS,
|
||||
validation_data=(val_frames, [val_targets, tf.zeros_like(val_targets)]),
|
||||
callbacks=[model_checkpoint_callback] # Add the ModelCheckpoint callback
|
||||
)
|
||||
print("\nTraining completed.")
|
||||
|
||||
# Step 3: Save the trained model
|
||||
model.save('ai_rate_control_model.keras')
|
||||
print("Model saved successfully!")
|
27
video_compression_model.py
Normal file
27
video_compression_model.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
import tensorflow as tf
|
||||
|
||||
class VideoCompressionModel(tf.keras.Model):
|
||||
def __init__(self, NUM_CHANNELS=3):
|
||||
super(VideoCompressionModel, self).__init__()
|
||||
|
||||
# Encoder layers
|
||||
self.encoder = tf.keras.Sequential([
|
||||
tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same', input_shape=(None, None, NUM_CHANNELS)),
|
||||
# Add more encoder layers as needed
|
||||
])
|
||||
|
||||
# Decoder layers
|
||||
self.decoder = tf.keras.Sequential([
|
||||
tf.keras.layers.Conv2DTranspose(32, (3, 3), activation='relu', padding='same'),
|
||||
# Add more decoder layers as needed
|
||||
tf.keras.layers.Conv2D(NUM_CHANNELS, (3, 3), activation='sigmoid', padding='same') # Output layer for video frames
|
||||
])
|
||||
|
||||
def call(self, inputs):
|
||||
# Encoding the video frames
|
||||
compressed_representation = self.encoder(inputs)
|
||||
|
||||
# Decoding to generate compressed video frames
|
||||
reconstructed_frames = self.decoder(compressed_representation)
|
||||
|
||||
return reconstructed_frames
|
Reference in a new issue