Skip to content
Snippets Groups Projects
Commit e104eca0 authored by Aidan Crowther's avatar Aidan Crowther
Browse files

Re-enable model after unloading

parent b7a2b4a0
Branches
No related tags found
No related merge requests found
...@@ -13,21 +13,12 @@ from .utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteV ...@@ -13,21 +13,12 @@ from .utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteV
model_name = os.getenv("ASR_MODEL", "base") model_name = os.getenv("ASR_MODEL", "base")
model_path = os.getenv("ASR_MODEL_PATH", os.path.join(os.path.expanduser("~"), ".cache", "whisper")) model_path = os.getenv("ASR_MODEL_PATH", os.path.join(os.path.expanduser("~"), ".cache", "whisper"))
model = None
model_lock = Lock()
# More about available quantization levels is here: # More about available quantization levels is here:
# https://opennmt.net/CTranslate2/quantization.html # https://opennmt.net/CTranslate2/quantization.html
if torch.cuda.is_available():
device = "cuda"
model_quantization = os.getenv("ASR_QUANTIZATION", "float32")
else:
device = "cpu"
model_quantization = os.getenv("ASR_QUANTIZATION", "int8")
model = WhisperModel(
model_size_or_path=model_name, device=device, compute_type=model_quantization, download_root=model_path
)
model_lock = Lock()
last_activity_time = time.time() last_activity_time = time.time()
idle_timeout = int(os.getenv("IDLE_TIMEOUT", 300)) # default to 5 minutes idle_timeout = int(os.getenv("IDLE_TIMEOUT", 300)) # default to 5 minutes
...@@ -40,13 +31,31 @@ def monitor_idleness(): ...@@ -40,13 +31,31 @@ def monitor_idleness():
release_model() release_model()
break break
def load_model():
    """Instantiate the global faster-whisper model for the configured device.

    Rebinds the module globals ``model``, ``device`` and
    ``model_quantization`` so an unloaded model can be restored on demand.
    """
    global model, device, model_quantization
    cuda_ok = torch.cuda.is_available()
    device = "cuda" if cuda_ok else "cpu"
    # float32 on GPU, int8 on CPU, unless overridden via ASR_QUANTIZATION.
    # More about available quantization levels:
    # https://opennmt.net/CTranslate2/quantization.html
    model_quantization = os.getenv("ASR_QUANTIZATION", "float32" if cuda_ok else "int8")
    model = WhisperModel(
        model_size_or_path=model_name,
        device=device,
        compute_type=model_quantization,
        download_root=model_path,
    )
Thread(target=monitor_idleness, daemon=True).start() Thread(target=monitor_idleness, daemon=True).start()
load_model()
def release_model():
    """Free the global ASR model after the idle timeout.

    Drops the only strong reference to the model, collects it, returns any
    cached CUDA blocks to the driver, then resets the global to ``None`` so
    the next request lazily reloads it via ``load_model()``.
    """
    global model
    del model
    # Collect the model object *before* emptying the CUDA cache: blocks
    # still referenced by the uncollected model cannot be released, so the
    # original empty_cache-then-collect order returned less memory.
    gc.collect()
    torch.cuda.empty_cache()
    model = None  # sentinel checked by the request handlers
    print("Model unloaded due to timeout")
def transcribe( def transcribe(
audio, audio,
...@@ -60,6 +69,9 @@ def transcribe( ...@@ -60,6 +69,9 @@ def transcribe(
global last_activity_time global last_activity_time
last_activity_time = time.time() last_activity_time = time.time()
with model_lock:
if(model is None): load_model()
options_dict = {"task": task} options_dict = {"task": task}
if language: if language:
options_dict["language"] = language options_dict["language"] = language
...@@ -89,6 +101,9 @@ def language_detection(audio): ...@@ -89,6 +101,9 @@ def language_detection(audio):
global last_activity_time global last_activity_time
last_activity_time = time.time() last_activity_time = time.time()
with model_lock:
if(model is None): load_model()
# load audio and pad/trim it to fit 30 seconds # load audio and pad/trim it to fit 30 seconds
audio = whisper.pad_or_trim(audio) audio = whisper.pad_or_trim(audio)
......
...@@ -11,12 +11,9 @@ from whisper.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, ...@@ -11,12 +11,9 @@ from whisper.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT,
model_name = os.getenv("ASR_MODEL", "base") model_name = os.getenv("ASR_MODEL", "base")
model_path = os.getenv("ASR_MODEL_PATH", os.path.join(os.path.expanduser("~"), ".cache", "whisper")) model_path = os.getenv("ASR_MODEL_PATH", os.path.join(os.path.expanduser("~"), ".cache", "whisper"))
model = None
if torch.cuda.is_available():
model = whisper.load_model(model_name, download_root=model_path).cuda()
else:
model = whisper.load_model(model_name, download_root=model_path)
model_lock = Lock() model_lock = Lock()
last_activity_time = time.time() last_activity_time = time.time()
idle_timeout = int(os.getenv("IDLE_TIMEOUT", 300)) # default to 5 minutes idle_timeout = int(os.getenv("IDLE_TIMEOUT", 300)) # default to 5 minutes
...@@ -29,13 +26,25 @@ def monitor_idleness(): ...@@ -29,13 +26,25 @@ def monitor_idleness():
release_model() release_model()
break break
def load_model():
    """(Re)load the global openai-whisper model, on GPU when available."""
    global model
    loaded = whisper.load_model(model_name, download_root=model_path)
    # Move to CUDA only when a GPU is actually present.
    model = loaded.cuda() if torch.cuda.is_available() else loaded
Thread(target=monitor_idleness, daemon=True).start() Thread(target=monitor_idleness, daemon=True).start()
load_model()
def release_model():
    """Free the global ASR model after the idle timeout.

    Drops the only strong reference to the model, collects it, returns any
    cached CUDA blocks to the driver, then resets the global to ``None`` so
    the next request lazily reloads it via ``load_model()``.
    """
    global model
    del model
    # Collect the model object *before* emptying the CUDA cache: blocks
    # still referenced by the uncollected model cannot be released, so the
    # original empty_cache-then-collect order returned less memory.
    gc.collect()
    torch.cuda.empty_cache()
    model = None  # sentinel checked by the request handlers
    print("Model unloaded due to timeout")
def transcribe( def transcribe(
audio, audio,
...@@ -49,6 +58,9 @@ def transcribe( ...@@ -49,6 +58,9 @@ def transcribe(
global last_activity_time global last_activity_time
last_activity_time = time.time() last_activity_time = time.time()
with model_lock:
if(model is None): load_model()
options_dict = {"task": task} options_dict = {"task": task}
if language: if language:
options_dict["language"] = language options_dict["language"] = language
...@@ -70,6 +82,9 @@ def language_detection(audio): ...@@ -70,6 +82,9 @@ def language_detection(audio):
global last_activity_time global last_activity_time
last_activity_time = time.time() last_activity_time = time.time()
with model_lock:
if(model is None): load_model()
# load audio and pad/trim it to fit 30 seconds # load audio and pad/trim it to fit 30 seconds
audio = whisper.pad_or_trim(audio) audio = whisper.pad_or_trim(audio)
......
Loading...
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment