Skip to content
Snippets Groups Projects
Commit e104eca0 authored by Aidan Crowther
Browse files

Re-enable model after unloading

parent b7a2b4a0
No related branches found
No related tags found
No related merge requests found
......@@ -13,21 +13,12 @@ from .utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteV
model_name = os.getenv("ASR_MODEL", "base")
model_path = os.getenv("ASR_MODEL_PATH", os.path.join(os.path.expanduser("~"), ".cache", "whisper"))
model = None
model_lock = Lock()
# More about available quantization levels is here:
# https://opennmt.net/CTranslate2/quantization.html
if torch.cuda.is_available():
device = "cuda"
model_quantization = os.getenv("ASR_QUANTIZATION", "float32")
else:
device = "cpu"
model_quantization = os.getenv("ASR_QUANTIZATION", "int8")
model = WhisperModel(
model_size_or_path=model_name, device=device, compute_type=model_quantization, download_root=model_path
)
model_lock = Lock()
last_activity_time = time.time()
idle_timeout = int(os.getenv("IDLE_TIMEOUT", 300)) # default to 5 minutes
......@@ -40,13 +31,31 @@ def monitor_idleness():
release_model()
break
def load_model():
    """Load the Whisper model into the module-level ``model`` and arm the idle monitor.

    The device is ``"cuda"`` when ``torch.cuda.is_available()`` and ``"cpu"``
    otherwise. The compute type is read from ``ASR_QUANTIZATION`` and defaults
    to ``"float32"`` on CUDA and ``"int8"`` on CPU.

    A new ``monitor_idleness`` thread is started on every load. The monitor
    loop breaks out right after it releases the model, so a thread started
    only once at import time would never unload a model that was lazily
    reloaded later.
    """
    global model, device, model_quantization

    # More about available quantization levels is here:
    # https://opennmt.net/CTranslate2/quantization.html
    if torch.cuda.is_available():
        device, default_quantization = "cuda", "float32"
    else:
        device, default_quantization = "cpu", "int8"
    model_quantization = os.getenv("ASR_QUANTIZATION", default_quantization)

    model = WhisperModel(
        model_size_or_path=model_name,
        device=device,
        compute_type=model_quantization,
        download_root=model_path,
    )

    # Daemon thread: it must not keep the process alive on shutdown.
    Thread(target=monitor_idleness, daemon=True).start()


# Eager first load at import time; this also starts the first idle monitor.
load_model()
def release_model():
    """Drop the module-level ``model`` and reclaim the memory it held.

    After this returns, the global ``model`` is ``None``.
    """
    global model

    # Rebind instead of `del`: deleting a global leaves the name unbound
    # until it is reassigned, so any `model is None` check in that window
    # would raise NameError instead of seeing "unloaded".
    model = None

    # Collect first so objects kept alive only by reference cycles are
    # destroyed before the CUDA caching allocator is asked to give back
    # its unused blocks.
    gc.collect()
    torch.cuda.empty_cache()

    print("Model unloaded due to timeout")
def transcribe(
audio,
......@@ -60,6 +69,9 @@ def transcribe(
global last_activity_time
last_activity_time = time.time()
with model_lock:
if(model is None): load_model()
options_dict = {"task": task}
if language:
options_dict["language"] = language
......@@ -89,6 +101,9 @@ def language_detection(audio):
global last_activity_time
last_activity_time = time.time()
with model_lock:
if(model is None): load_model()
# load audio and pad/trim it to fit 30 seconds
audio = whisper.pad_or_trim(audio)
......
......@@ -11,12 +11,9 @@ from whisper.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT,
model_name = os.getenv("ASR_MODEL", "base")
model_path = os.getenv("ASR_MODEL_PATH", os.path.join(os.path.expanduser("~"), ".cache", "whisper"))
if torch.cuda.is_available():
model = whisper.load_model(model_name, download_root=model_path).cuda()
else:
model = whisper.load_model(model_name, download_root=model_path)
model = None
model_lock = Lock()
last_activity_time = time.time()
idle_timeout = int(os.getenv("IDLE_TIMEOUT", 300)) # default to 5 minutes
......@@ -29,13 +26,25 @@ def monitor_idleness():
release_model()
break
def load_model():
    """Load the Whisper model into the module-level ``model`` and arm the idle monitor.

    The model is loaded with ``whisper.load_model`` and moved to the GPU with
    ``.cuda()`` when ``torch.cuda.is_available()``.

    A new ``monitor_idleness`` thread is started on every load. The monitor
    loop breaks out right after it releases the model, so a thread started
    only once at import time would never unload a model that was lazily
    reloaded later.
    """
    global model

    loaded = whisper.load_model(model_name, download_root=model_path)
    model = loaded.cuda() if torch.cuda.is_available() else loaded

    # Daemon thread: it must not keep the process alive on shutdown.
    Thread(target=monitor_idleness, daemon=True).start()


# Eager first load at import time; this also starts the first idle monitor.
load_model()
def release_model():
    """Drop the module-level ``model`` and reclaim the memory it held.

    After this returns, the global ``model`` is ``None``.
    """
    global model

    # Rebind instead of `del`: deleting a global leaves the name unbound
    # until it is reassigned, so any `model is None` check in that window
    # would raise NameError instead of seeing "unloaded".
    model = None

    # Collect first so tensors kept alive only by reference cycles are
    # freed into the CUDA caching allocator before the cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

    print("Model unloaded due to timeout")
def transcribe(
audio,
......@@ -49,6 +58,9 @@ def transcribe(
global last_activity_time
last_activity_time = time.time()
with model_lock:
if(model is None): load_model()
options_dict = {"task": task}
if language:
options_dict["language"] = language
......@@ -70,6 +82,9 @@ def language_detection(audio):
global last_activity_time
last_activity_time = time.time()
with model_lock:
if(model is None): load_model()
# load audio and pad/trim it to fit 30 seconds
audio = whisper.pad_or_trim(audio)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment