Commit 4c003ca4 authored by Ahmet Oner

Add black and ruff formatters

parent 76b041a5
@@ -103,9 +103,11 @@ Unreleased
- Updated default model paths to `~/.cache/whisper` or `/root/.cache/whisper`.
- For customization, modify the `ASR_MODEL_PATH` environment variable.
- Ensure Docker volume is set for the corresponding directory to use caching.
```bash
docker run -d -p 9000:9000 -e ASR_MODEL_PATH=/data/whisper -v $PWD/yourlocaldir:/data/whisper onerahmet/openai-whisper-asr-webservice:latest
```
- Removed the `triton` dependency from `poetry.lock` to ensure the stability of the pipeline for `ARM-based` Docker images

[1.1.1] (2023-05-29)
...
@@ -2,16 +2,17 @@
![Docker Pulls](https://img.shields.io/docker/pulls/onerahmet/openai-whisper-asr-webservice.svg)
![Build](https://img.shields.io/github/actions/workflow/status/ahmetoner/whisper-asr-webservice/docker-publish.yml.svg)
![Licence](https://img.shields.io/github/license/ahmetoner/whisper-asr-webservice.svg)

# Whisper ASR Webservice

Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multitask model that can perform multilingual speech recognition as well as speech translation and language identification. For more details: [github.com/openai/whisper](https://github.com/openai/whisper/)

## Features

Current release (v1.4.1) supports the following whisper models:

- [openai/whisper](https://github.com/openai/whisper)@[v20231117](https://github.com/openai/whisper/releases/tag/v20231117)
-- [SYSTRAN/faster-whisper](https://github.com/SYSTRAN/faster-whisper)@[v0.10.0](https://github.com/SYSTRAN/faster-whisper/releases/tag/0.10.0)
+- [SYSTRAN/faster-whisper](https://github.com/SYSTRAN/faster-whisper)@[v1.0.2](https://github.com/SYSTRAN/faster-whisper/releases/tag/1.0.2)

## Quick Usage
@@ -33,6 +34,7 @@ for more information:
- [Docker Hub](https://hub.docker.com/r/onerahmet/openai-whisper-asr-webservice)

## Documentation

Explore the documentation by clicking [here](https://ahmetoner.github.io/whisper-asr-webservice).

## Credits
...
import os
from io import StringIO
from threading import Lock
-from typing import Union, BinaryIO
+from typing import BinaryIO, Union

import torch
import whisper
from faster_whisper import WhisperModel

-from .utils import ResultWriter, WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON
+from .utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteVTT

model_name = os.getenv("ASR_MODEL", "base")
model_path = os.getenv("ASR_MODEL_PATH", os.path.join(os.path.expanduser("~"), ".cache", "whisper"))
@@ -23,14 +22,12 @@ else:
    model_quantization = os.getenv("ASR_QUANTIZATION", "int8")

model = WhisperModel(
-    model_size_or_path=model_name,
-    device=device,
-    compute_type=model_quantization,
-    download_root=model_path
+    model_size_or_path=model_name, device=device, compute_type=model_quantization, download_root=model_path
)
model_lock = Lock()

def transcribe(
    audio,
    task: Union[str, None],
@@ -56,11 +53,7 @@ def transcribe(
        for segment in segment_generator:
            segments.append(segment)
            text = text + segment.text
-        result = {
-            "language": options_dict.get("language", info.language),
-            "segments": segments,
-            "text": text
-        }
+        result = {"language": options_dict.get("language", info.language), "segments": segments, "text": text}

    output_file = StringIO()
    write_result(result, output_file, output)
@@ -81,9 +74,7 @@ def language_detection(audio):
    return detected_lang_code


-def write_result(
-    result: dict, file: BinaryIO, output: Union[str, None]
-):
+def write_result(result: dict, file: BinaryIO, output: Union[str, None]):
    if output == "srt":
        WriteSRT(ResultWriter).write_result(result, file=file)
    elif output == "vtt":
@@ -95,4 +86,4 @@ def write_result(
    elif output == "txt":
        WriteTXT(ResultWriter).write_result(result, file=file)
    else:
-        return 'Please select an output method!'
+        return "Please select an output method!"
@@ -69,6 +69,7 @@ class WriteTSV(ResultWriter):
    an environment setting a language encoding that causes the decimal in a floating point number
    to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
    """

    extension: str = "tsv"

    def write_result(self, result: dict, file: TextIO):
...
@@ -5,7 +5,7 @@ from typing import BinaryIO, Union
import torch
import whisper
-from whisper.utils import ResultWriter, WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON
+from whisper.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteVTT

model_name = os.getenv("ASR_MODEL", "base")
model_path = os.getenv("ASR_MODEL_PATH", os.path.join(os.path.expanduser("~"), ".cache", "whisper"))
@@ -24,7 +24,7 @@ def transcribe(
    initial_prompt: Union[str, None],
    vad_filter: Union[bool, None],
    word_timestamps: Union[bool, None],
-    output
+    output,
):
    options_dict = {"task": task}
    if language:
@@ -58,14 +58,8 @@ def language_detection(audio):
    return detected_lang_code


-def write_result(
-    result: dict, file: BinaryIO, output: Union[str, None]
-):
-    options = {
-        'max_line_width': 1000,
-        'max_line_count': 10,
-        'highlight_words': False
-    }
+def write_result(result: dict, file: BinaryIO, output: Union[str, None]):
+    options = {"max_line_width": 1000, "max_line_count": 10, "highlight_words": False}
    if output == "srt":
        WriteSRT(ResultWriter).write_result(result, file=file, options=options)
    elif output == "vtt":
@@ -77,4 +71,4 @@ def write_result(
    elif output == "txt":
        WriteTXT(ResultWriter).write_result(result, file=file, options=options)
    else:
-        return 'Please select an output method!'
+        return "Please select an output method!"
import importlib.metadata
import os
from os import path
-from typing import BinaryIO, Union, Annotated
+from typing import Annotated, BinaryIO, Union
+from urllib.parse import quote

import ffmpeg
import numpy as np
-from fastapi import FastAPI, File, UploadFile, Query, applications
+from fastapi import FastAPI, File, Query, UploadFile, applications
from fastapi.openapi.docs import get_swagger_ui_html
-from fastapi.responses import StreamingResponse, RedirectResponse
+from fastapi.responses import RedirectResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from whisper import tokenizer
-from urllib.parse import quote

ASR_ENGINE = os.getenv("ASR_ENGINE", "openai_whisper")
if ASR_ENGINE == "faster_whisper":
-    from .faster_whisper.core import transcribe, language_detection
+    from .faster_whisper.core import language_detection, transcribe
else:
-    from .openai_whisper.core import transcribe, language_detection
+    from .openai_whisper.core import language_detection, transcribe

SAMPLE_RATE = 16000
-LANGUAGE_CODES = sorted(list(tokenizer.LANGUAGES.keys()))
+LANGUAGE_CODES = sorted(tokenizer.LANGUAGES.keys())

-projectMetadata = importlib.metadata.metadata('whisper-asr-webservice')
+projectMetadata = importlib.metadata.metadata("whisper-asr-webservice")
app = FastAPI(
-    title=projectMetadata['Name'].title().replace('-', ' '),
-    description=projectMetadata['Summary'],
-    version=projectMetadata['Version'],
-    contact={
-        "url": projectMetadata['Home-page']
-    },
+    title=projectMetadata["Name"].title().replace("-", " "),
+    description=projectMetadata["Summary"],
+    version=projectMetadata["Version"],
+    contact={"url": projectMetadata["Home-page"]},
    swagger_ui_parameters={"defaultModelsExpandDepth": -1},
-    license_info={
-        "name": "MIT License",
-        "url": projectMetadata['License']
-    }
+    license_info={"name": "MIT License", "url": projectMetadata["License"]},
)

assets_path = os.getcwd() + "/swagger-ui-assets"
if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/swagger-ui-bundle.js"):
    app.mount("/assets", StaticFiles(directory=assets_path), name="static")

    def swagger_monkey_patch(*args, **kwargs):
        return get_swagger_ui_html(
            *args,
@@ -50,7 +44,6 @@ if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/
            swagger_js_url="/assets/swagger-ui-bundle.js",
        )

    applications.get_swagger_ui_html = swagger_monkey_patch
@@ -61,33 +54,38 @@ async def index():
@app.post("/asr", tags=["Endpoints"])
async def asr(
-    audio_file: UploadFile = File(...),
+    audio_file: UploadFile = File(...),  # noqa: B008
    encode: bool = Query(default=True, description="Encode audio first through ffmpeg"),
    task: Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]),
    language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES),
    initial_prompt: Union[str, None] = Query(default=None),
-    vad_filter: Annotated[bool | None, Query(
-        description="Enable the voice activity detection (VAD) to filter out parts of the audio without speech",
-        include_in_schema=(True if ASR_ENGINE == "faster_whisper" else False)
-    )] = False,
+    vad_filter: Annotated[
+        bool | None,
+        Query(
+            description="Enable the voice activity detection (VAD) to filter out parts of the audio without speech",
+            include_in_schema=(True if ASR_ENGINE == "faster_whisper" else False),
+        ),
+    ] = False,
    word_timestamps: bool = Query(default=False, description="Word level timestamps"),
-    output: Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"])
+    output: Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"]),
):
-    result = transcribe(load_audio(audio_file.file, encode), task, language, initial_prompt, vad_filter, word_timestamps, output)
+    result = transcribe(
+        load_audio(audio_file.file, encode), task, language, initial_prompt, vad_filter, word_timestamps, output
+    )
    return StreamingResponse(
        result,
        media_type="text/plain",
        headers={
-            'Asr-Engine': ASR_ENGINE,
-            'Content-Disposition': f'attachment; filename="{quote(audio_file.filename)}.{output}"'
-        }
+            "Asr-Engine": ASR_ENGINE,
+            "Content-Disposition": f'attachment; filename="{quote(audio_file.filename)}.{output}"',
+        },
    )
@app.post("/detect-language", tags=["Endpoints"]) @app.post("/detect-language", tags=["Endpoints"])
async def detect_language( async def detect_language(
audio_file: UploadFile = File(...), audio_file: UploadFile = File(...), # noqa: B008
encode: bool = Query(default=True, description="Encode audio first through FFmpeg") encode: bool = Query(default=True, description="Encode audio first through FFmpeg"),
): ):
detected_lang_code = language_detection(load_audio(audio_file.file, encode)) detected_lang_code = language_detection(load_audio(audio_file.file, encode))
return {"detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code": detected_lang_code} return {"detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code": detected_lang_code}
......
@@ -77,4 +77,3 @@ poetry run gunicorn --bind 0.0.0.0:9000 --workers 1 --timeout 0 app.webservice:a
```sh
docker-compose -f docker-compose.gpu.yml up --build
```
@@ -29,7 +29,8 @@ There are 2 endpoints available:
| encode | true (default) |

Example request with cURL

-```
+```bash
curl -X POST -H "content-type: multipart/form-data" -F "audio_file=@/path/to/file" 0.0.0.0:9000/asr?output=json
```
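
The service also exposes a second endpoint, `/detect-language` (its handler appears in the webservice diff above). A request against it might look like the following sketch, mirroring the `/asr` example:

```bash
# Detect the spoken language of an audio file; returns the language name and code as JSON
curl -X POST -H "content-type: multipart/form-data" -F "audio_file=@/path/to/file" 0.0.0.0:9000/detect-language
```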
...
@@ -19,7 +19,6 @@ Available ASR_MODELs are `tiny`, `base`, `small`, `medium`, `large` (only OpenAI
For English-only applications, the `.en` models tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
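
For instance, an English-only model can be selected when the container starts. A minimal sketch, assuming the published Docker image and the `ASR_MODEL` variable described above:

```sh
# Start the webservice with the English-only base model
docker run -d -p 9000:9000 -e ASR_MODEL=base.en onerahmet/openai-whisper-asr-webservice:latest
```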

### Configuring the `Model Path`

```sh
...
Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multitask model that can perform multilingual speech recognition as well as speech translation and language identification.

## Features

Current release (v1.4.1) supports the following whisper models:

- [openai/whisper](https://github.com/openai/whisper)@[v20231117](https://github.com/openai/whisper/releases/tag/v20231117)
-- [SYSTRAN/faster-whisper](https://github.com/SYSTRAN/faster-whisper)@[v0.10.0](https://github.com/SYSTRAN/faster-whisper/releases/tag/0.10.0)
+- [SYSTRAN/faster-whisper](https://github.com/SYSTRAN/faster-whisper)@[v1.0.2](https://github.com/SYSTRAN/faster-whisper/releases/tag/1.0.2)

## Quick Usage
...
@@ -29,11 +29,12 @@ Docker Hub: <https://hub.docker.com/r/onerahmet/openai-whisper-asr-webservice>
docker run -d --gpus all -p 9000:9000 -e ASR_MODEL=base -e ASR_ENGINE=openai_whisper onerahmet/openai-whisper-asr-webservice:latest-gpu
```

-> Interactive Swagger API documentation is available at http://localhost:9000/docs
+> Interactive Swagger API documentation is available at <http://localhost:9000/docs>

![Swagger UI](assets/images/swagger-ui.png)

## Cache

The ASR model is downloaded each time you start the container; with the large model this can take some time.
If you want to decrease container start-up time by skipping the download, you can store the cache directory (`~/.cache/whisper` or `/root/.cache/whisper`) on persistent storage.
The next time you start the container, the ASR model will be loaded from the cache instead of being downloaded again.
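
A minimal sketch of such a run, assuming the default in-container cache path `/root/.cache/whisper` and a local `./cache` directory as the persistent store:

```sh
# Mount a local directory over the model cache so downloads survive container restarts
docker run -d -p 9000:9000 -v $PWD/cache:/root/.cache/whisper onerahmet/openai-whisper-asr-webservice:latest
```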
...
@@ -4,10 +4,7 @@ version = "1.5.0-dev"
description = "Whisper ASR Webservice is a general-purpose speech recognition webservice."
homepage = "https://github.com/ahmetoner/whisper-asr-webservice/"
license = "https://github.com/ahmetoner/whisper-asr-webservice/blob/main/LICENCE"
-authors = [
-    "Ahmet Öner",
-    "Besim Alibegovic",
-]
+authors = ["Ahmet Öner", "Besim Alibegovic"]
readme = "README.md"
packages = [{ include = "app" }]
...
@@ -39,7 +36,47 @@ torch = [
[tool.poetry.dev-dependencies]
pytest = "^6.2.5"
+ruff = "^0.5.0"
+black = "^24.4.2"
+mkdocs = "^1.6.0"
+mkdocs-material = "^9.5.27"
+pymdown-extensions = "^10.8.1"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
+
+[tool.black]
+skip-string-normalization = true
+line-length = 120
+
+[tool.ruff]
+# Same as Black.
+line-length = 120
+
+[tool.ruff.lint]
+select = [
+    "E",  # pycodestyle errors (settings from FastAPI, thanks, @tiangolo!)
+    "W",  # pycodestyle warnings
+    "F",  # pyflakes
+    "I",  # isort
+    "C",  # flake8-comprehensions
+    "B",  # flake8-bugbear
+]
+ignore = [
+    "E501",  # line too long, handled by black
+    "C901",  # too complex
+]
+
+[tool.ruff.lint.isort]
+order-by-type = true
+relative-imports-order = "closest-to-furthest"
+extra-standard-library = ["typing"]
+section-order = [
+    "future",
+    "standard-library",
+    "third-party",
+    "first-party",
+    "local-folder",
+]
+known-first-party = []
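
With this configuration in place, the formatters added by this commit can be run through Poetry. A sketch, assuming the standard `black` and `ruff` CLIs with the `app` package as the target:

```sh
# Reformat the code with black (120-character lines, per [tool.black])
poetry run black app

# Lint and auto-fix with ruff, including import sorting via the "I" rule set
poetry run ruff check --fix app
```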