Skip to content
Snippets Groups Projects
Commit 7b111d6e authored by Ahmet Öner's avatar Ahmet Öner
Browse files

Reformat the code and organize the imports.

parent 7d3e8876
No related branches found
No related tags found
No related merge requests found
import gc
import os
import time
from io import StringIO
from threading import Lock, Thread
from typing import BinaryIO, Union
import time
import gc
import torch
import whisper
......@@ -16,15 +16,13 @@ model_path = os.getenv("ASR_MODEL_PATH", os.path.join(os.path.expanduser("~"), "
model = None
model_lock = Lock()
# More about available quantization levels is here:
# https://opennmt.net/CTranslate2/quantization.html
last_activity_time = time.time()
idle_timeout = int(os.getenv("IDLE_TIMEOUT", 0)) # default to being disabled
def monitor_idleness():
global model
if(idle_timeout <= 0): return
if idle_timeout <= 0: return
while True:
time.sleep(15)
if time.time() - last_activity_time > idle_timeout:
......@@ -32,9 +30,12 @@ def monitor_idleness():
release_model()
break
def load_model():
global model, device, model_quantization
# More about available quantization levels is here:
# https://opennmt.net/CTranslate2/quantization.html
if torch.cuda.is_available():
device = "cuda"
model_quantization = os.getenv("ASR_QUANTIZATION", "float32")
......@@ -48,8 +49,10 @@ def load_model():
Thread(target=monitor_idleness, daemon=True).start()
load_model()
def release_model():
global model
del model
......@@ -58,6 +61,7 @@ def release_model():
model = None
print("Model unloaded due to timeout")
def transcribe(
audio,
task: Union[str, None],
......@@ -71,7 +75,7 @@ def transcribe(
last_activity_time = time.time()
with model_lock:
if(model is None): load_model()
if model is None: load_model()
options_dict = {"task": task}
if language:
......@@ -103,7 +107,7 @@ def language_detection(audio):
last_activity_time = time.time()
with model_lock:
if(model is None): load_model()
if model is None: load_model()
# load audio and pad/trim it to fit 30 seconds
audio = whisper.pad_or_trim(audio)
......
import gc
import os
import time
from io import StringIO
from threading import Lock, Thread
from typing import BinaryIO, Union
import time
import gc
import torch
import whisper
......@@ -17,16 +17,18 @@ model_lock = Lock()
last_activity_time = time.time()
idle_timeout = int(os.getenv("IDLE_TIMEOUT", 0)) # default to being disabled
def monitor_idleness():
global model
if(idle_timeout <= 0): return
if idle_timeout <= 0: return
while True:
time.sleep(15) # check every minute
time.sleep(15)
if time.time() - last_activity_time > idle_timeout:
with model_lock:
release_model()
break
def load_model():
global model
......@@ -37,8 +39,10 @@ def load_model():
Thread(target=monitor_idleness, daemon=True).start()
load_model()
def release_model():
global model
del model
......@@ -47,6 +51,7 @@ def release_model():
model = None
print("Model unloaded due to timeout")
def transcribe(
audio,
task: Union[str, None],
......@@ -60,7 +65,7 @@ def transcribe(
last_activity_time = time.time()
with model_lock:
if(model is None): load_model()
if model is None: load_model()
options_dict = {"task": task}
if language:
......@@ -84,7 +89,7 @@ def language_detection(audio):
last_activity_time = time.time()
with model_lock:
if(model is None): load_model()
if model is None: load_model()
# load audio and pad/trim it to fit 30 seconds
audio = whisper.pad_or_trim(audio)
......
......@@ -37,6 +37,7 @@ assets_path = os.getcwd() + "/swagger-ui-assets"
if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/swagger-ui-bundle.js"):
app.mount("/assets", StaticFiles(directory=assets_path), name="static")
def swagger_monkey_patch(*args, **kwargs):
return get_swagger_ui_html(
*args,
......@@ -46,6 +47,7 @@ if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/
swagger_js_url="/assets/swagger-ui-bundle.js",
)
applications.get_swagger_ui_html = swagger_monkey_patch
......@@ -90,7 +92,8 @@ async def detect_language(
encode: bool = Query(default=True, description="Encode audio first through FFmpeg"),
):
detected_lang_code, confidence = language_detection(load_audio(audio_file.file, encode))
return {"detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code": detected_lang_code, "confidence": confidence}
return {"detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code": detected_lang_code,
"confidence": confidence}
def load_audio(file: BinaryIO, encode=True, sr: int = SAMPLE_RATE):
......@@ -125,6 +128,7 @@ def load_audio(file: BinaryIO, encode=True, sr: int = SAMPLE_RATE):
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
@click.command()
@click.option(
"-h",
......@@ -147,5 +151,6 @@ def start(
):
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
start()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment