diff --git a/train_model.py b/train_model.py index dd93f80..b122989 100644 --- a/train_model.py +++ b/train_model.py @@ -14,8 +14,18 @@ from featureExtraction import psnr os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' +import gc import tensorflow as tf -from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint +from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, Callback + +gpus = tf.config.experimental.list_physical_devices('GPU') +if gpus: + try: + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + except RuntimeError as e: + print(e) + from video_compression_model import VideoCompressionModel, data_generator