Skip to content
Snippets Groups Projects
Unverified Commit 7d3e8876 authored by Ahmet Öner's avatar Ahmet Öner Committed by GitHub
Browse files

Merge pull request #256 from aidancrowther/add-vram-flush

Add vram flushing support
parents 2143d1da 9e8f8e26
No related branches found
No related tags found
No related merge requests found
......@@ -6,6 +6,7 @@ Unreleased
### Added
- Timeout configured to allow model to be unloaded when idle
- Added detection confidence to langauge detection endpoint
- Set mel generation to adjust n_dims automatically to match the loaded model
......
......
import os
from io import StringIO
from threading import Lock
from threading import Lock, Thread
from typing import BinaryIO, Union
import time
import gc
import torch
import whisper
......@@ -11,9 +13,28 @@ from .utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteV
model_name = os.getenv("ASR_MODEL", "base")
model_path = os.getenv("ASR_MODEL_PATH", os.path.join(os.path.expanduser("~"), ".cache", "whisper"))
model = None
model_lock = Lock()
# More about available quantization levels is here:
# https://opennmt.net/CTranslate2/quantization.html
last_activity_time = time.time()
idle_timeout = int(os.getenv("IDLE_TIMEOUT", 0)) # default to being disabled
def monitor_idleness():
global model
if(idle_timeout <= 0): return
while True:
time.sleep(15)
if time.time() - last_activity_time > idle_timeout:
with model_lock:
release_model()
break
def load_model():
global model, device, model_quantization
if torch.cuda.is_available():
device = "cuda"
model_quantization = os.getenv("ASR_QUANTIZATION", "float32")
......@@ -25,8 +46,17 @@ model = WhisperModel(
model_size_or_path=model_name, device=device, compute_type=model_quantization, download_root=model_path
)
model_lock = Lock()
Thread(target=monitor_idleness, daemon=True).start()
load_model()
def release_model():
global model
del model
torch.cuda.empty_cache()
gc.collect()
model = None
print("Model unloaded due to timeout")
def transcribe(
audio,
......@@ -37,6 +67,12 @@ def transcribe(
word_timestamps: Union[bool, None],
output,
):
global last_activity_time
last_activity_time = time.time()
with model_lock:
if(model is None): load_model()
options_dict = {"task": task}
if language:
options_dict["language"] = language
......@@ -63,6 +99,12 @@ def transcribe(
def language_detection(audio):
global last_activity_time
last_activity_time = time.time()
with model_lock:
if(model is None): load_model()
# load audio and pad/trim it to fit 30 seconds
audio = whisper.pad_or_trim(audio)
......
......
import os
from io import StringIO
from threading import Lock
from threading import Lock, Thread
from typing import BinaryIO, Union
import time
import gc
import torch
import whisper
......@@ -9,13 +11,41 @@ from whisper.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT,
model_name = os.getenv("ASR_MODEL", "base")
model_path = os.getenv("ASR_MODEL_PATH", os.path.join(os.path.expanduser("~"), ".cache", "whisper"))
model = None
model_lock = Lock()
last_activity_time = time.time()
idle_timeout = int(os.getenv("IDLE_TIMEOUT", 0)) # default to being disabled
def monitor_idleness():
global model
if(idle_timeout <= 0): return
while True:
time.sleep(15) # check every minute
if time.time() - last_activity_time > idle_timeout:
with model_lock:
release_model()
break
def load_model():
global model
if torch.cuda.is_available():
model = whisper.load_model(model_name, download_root=model_path).cuda()
else:
model = whisper.load_model(model_name, download_root=model_path)
model_lock = Lock()
Thread(target=monitor_idleness, daemon=True).start()
load_model()
def release_model():
global model
del model
torch.cuda.empty_cache()
gc.collect()
model = None
print("Model unloaded due to timeout")
def transcribe(
audio,
......@@ -26,6 +56,12 @@ def transcribe(
word_timestamps: Union[bool, None],
output,
):
global last_activity_time
last_activity_time = time.time()
with model_lock:
if(model is None): load_model()
options_dict = {"task": task}
if language:
options_dict["language"] = language
......@@ -44,6 +80,12 @@ def transcribe(
def language_detection(audio):
global last_activity_time
last_activity_time = time.time()
with model_lock:
if(model is None): load_model()
# load audio and pad/trim it to fit 30 seconds
audio = whisper.pad_or_trim(audio)
......
......
File changed. Contains only whitespace changes. Show whitespace changes.
......@@ -24,3 +24,11 @@ For English-only applications, the `.en` models tend to perform better, especial
```sh
export ASR_MODEL_PATH=/data/whisper
```
### Configuring the `Model Unloading Timeout`
```sh
export IDLE_TIMEOUT=300
```
Defaults to 0. After no activity for this period (in seconds), unload the model until it is requested again. Setting `0` disables the timeout, keeping the model loaded indefinitely.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment