Unverified commit 22765007, authored by Ahmet Öner, committed by GitHub

Merge pull request #229 from ahmetoner/add-formatters

Add black and ruff formatters
parents 76b041a5 4c003ca4
@@ -103,9 +103,11 @@ Unreleased
- Updated default model paths to `~/.cache/whisper` or `/root/.cache/whisper`.
- For customization, modify the `ASR_MODEL_PATH` environment variable.
- Ensure Docker volume is set for the corresponding directory to use caching.
```bash
docker run -d -p 9000:9000 -e ASR_MODEL_PATH=/data/whisper -v $PWD/yourlocaldir:/data/whisper onerahmet/openai-whisper-asr-webservice:latest
```
- Removed the `triton` dependency from `poetry.lock` to ensure the stability of the pipeline for `ARM-based` Docker images.
[1.1.1] (2023-05-29)
......
@@ -2,16 +2,17 @@
![Docker Pulls](https://img.shields.io/docker/pulls/onerahmet/openai-whisper-asr-webservice.svg)
![Build](https://img.shields.io/github/actions/workflow/status/ahmetoner/whisper-asr-webservice/docker-publish.yml.svg)
![Licence](https://img.shields.io/github/license/ahmetoner/whisper-asr-webservice.svg)
# Whisper ASR Webservice
Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multitask model that can perform multilingual speech recognition as well as speech translation and language identification. For more details: [github.com/openai/whisper](https://github.com/openai/whisper/)
## Features
The current release (v1.4.1) supports the following Whisper models:
- [openai/whisper](https://github.com/openai/whisper)@[v20231117](https://github.com/openai/whisper/releases/tag/v20231117)
-- [SYSTRAN/faster-whisper](https://github.com/SYSTRAN/faster-whisper)@[v0.10.0](https://github.com/SYSTRAN/faster-whisper/releases/tag/0.10.0)
+- [SYSTRAN/faster-whisper](https://github.com/SYSTRAN/faster-whisper)@[v1.0.2](https://github.com/SYSTRAN/faster-whisper/releases/tag/1.0.2)
## Quick Usage
@@ -33,6 +34,7 @@ for more information:
- [Docker Hub](https://hub.docker.com/r/onerahmet/openai-whisper-asr-webservice)
## Documentation
Explore the documentation by clicking [here](https://ahmetoner.github.io/whisper-asr-webservice).
## Credits
......
import os
from io import StringIO
from threading import Lock
-from typing import Union, BinaryIO
+from typing import BinaryIO, Union
import torch
import whisper
from faster_whisper import WhisperModel
-from .utils import ResultWriter, WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON
+from .utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteVTT
model_name = os.getenv("ASR_MODEL", "base")
model_path = os.getenv("ASR_MODEL_PATH", os.path.join(os.path.expanduser("~"), ".cache", "whisper"))
@@ -23,14 +22,12 @@ else:
    model_quantization = os.getenv("ASR_QUANTIZATION", "int8")
model = WhisperModel(
-    model_size_or_path=model_name,
-    device=device,
-    compute_type=model_quantization,
-    download_root=model_path
+    model_size_or_path=model_name, device=device, compute_type=model_quantization, download_root=model_path
)
model_lock = Lock()
def transcribe(
    audio,
    task: Union[str, None],
@@ -56,11 +53,7 @@ def transcribe(
    for segment in segment_generator:
        segments.append(segment)
        text = text + segment.text
-    result = {
-        "language": options_dict.get("language", info.language),
-        "segments": segments,
-        "text": text
-    }
+    result = {"language": options_dict.get("language", info.language), "segments": segments, "text": text}
    output_file = StringIO()
    write_result(result, output_file, output)
@@ -81,9 +74,7 @@ def language_detection(audio):
    return detected_lang_code

-def write_result(
-    result: dict, file: BinaryIO, output: Union[str, None]
-):
+def write_result(result: dict, file: BinaryIO, output: Union[str, None]):
    if output == "srt":
        WriteSRT(ResultWriter).write_result(result, file=file)
    elif output == "vtt":
@@ -95,4 +86,4 @@ def write_result(
    elif output == "txt":
        WriteTXT(ResultWriter).write_result(result, file=file)
    else:
-        return 'Please select an output method!'
+        return "Please select an output method!"
@@ -69,6 +69,7 @@ class WriteTSV(ResultWriter):
    an environment setting a language encoding that causes the decimal in a floating point number
    to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
    """

    extension: str = "tsv"

    def write_result(self, result: dict, file: TextIO):
......
@@ -5,7 +5,7 @@ from typing import BinaryIO, Union
import torch
import whisper
-from whisper.utils import ResultWriter, WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON
+from whisper.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteVTT
model_name = os.getenv("ASR_MODEL", "base")
model_path = os.getenv("ASR_MODEL_PATH", os.path.join(os.path.expanduser("~"), ".cache", "whisper"))
@@ -24,7 +24,7 @@ def transcribe(
    initial_prompt: Union[str, None],
    vad_filter: Union[bool, None],
    word_timestamps: Union[bool, None],
-    output
+    output,
):
    options_dict = {"task": task}
    if language:
@@ -58,14 +58,8 @@ def language_detection(audio):
    return detected_lang_code

-def write_result(
-    result: dict, file: BinaryIO, output: Union[str, None]
-):
-    options = {
-        'max_line_width': 1000,
-        'max_line_count': 10,
-        'highlight_words': False
-    }
+def write_result(result: dict, file: BinaryIO, output: Union[str, None]):
+    options = {"max_line_width": 1000, "max_line_count": 10, "highlight_words": False}
if output == "srt":
WriteSRT(ResultWriter).write_result(result, file=file, options=options)
elif output == "vtt":
@@ -77,4 +71,4 @@ def write_result(
    elif output == "txt":
        WriteTXT(ResultWriter).write_result(result, file=file, options=options)
    else:
-        return 'Please select an output method!'
+        return "Please select an output method!"
import importlib.metadata
import os
from os import path
-from typing import BinaryIO, Union, Annotated
+from typing import Annotated, BinaryIO, Union
from urllib.parse import quote
import ffmpeg
import numpy as np
-from fastapi import FastAPI, File, UploadFile, Query, applications
+from fastapi import FastAPI, File, Query, UploadFile, applications
from fastapi.openapi.docs import get_swagger_ui_html
-from fastapi.responses import StreamingResponse, RedirectResponse
+from fastapi.responses import RedirectResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from whisper import tokenizer
-from urllib.parse import quote
ASR_ENGINE = os.getenv("ASR_ENGINE", "openai_whisper")
if ASR_ENGINE == "faster_whisper":
-    from .faster_whisper.core import transcribe, language_detection
+    from .faster_whisper.core import language_detection, transcribe
else:
-    from .openai_whisper.core import transcribe, language_detection
+    from .openai_whisper.core import language_detection, transcribe
SAMPLE_RATE = 16000
-LANGUAGE_CODES = sorted(list(tokenizer.LANGUAGES.keys()))
+LANGUAGE_CODES = sorted(tokenizer.LANGUAGES.keys())
-projectMetadata = importlib.metadata.metadata('whisper-asr-webservice')
+projectMetadata = importlib.metadata.metadata("whisper-asr-webservice")
app = FastAPI(
-    title=projectMetadata['Name'].title().replace('-', ' '),
-    description=projectMetadata['Summary'],
-    version=projectMetadata['Version'],
-    contact={
-        "url": projectMetadata['Home-page']
-    },
+    title=projectMetadata["Name"].title().replace("-", " "),
+    description=projectMetadata["Summary"],
+    version=projectMetadata["Version"],
+    contact={"url": projectMetadata["Home-page"]},
    swagger_ui_parameters={"defaultModelsExpandDepth": -1},
-    license_info={
-        "name": "MIT License",
-        "url": projectMetadata['License']
-    }
+    license_info={"name": "MIT License", "url": projectMetadata["License"]},
)
assets_path = os.getcwd() + "/swagger-ui-assets"
if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/swagger-ui-bundle.js"):
    app.mount("/assets", StaticFiles(directory=assets_path), name="static")

    def swagger_monkey_patch(*args, **kwargs):
        return get_swagger_ui_html(
            *args,
@@ -50,7 +44,6 @@ if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/
            swagger_js_url="/assets/swagger-ui-bundle.js",
        )
    applications.get_swagger_ui_html = swagger_monkey_patch
@@ -61,33 +54,38 @@ async def index():
@app.post("/asr", tags=["Endpoints"])
async def asr(
-    audio_file: UploadFile = File(...),
+    audio_file: UploadFile = File(...),  # noqa: B008
    encode: bool = Query(default=True, description="Encode audio first through ffmpeg"),
    task: Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]),
    language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES),
    initial_prompt: Union[str, None] = Query(default=None),
-    vad_filter: Annotated[bool | None, Query(
+    vad_filter: Annotated[
+        bool | None,
+        Query(
            description="Enable the voice activity detection (VAD) to filter out parts of the audio without speech",
-        include_in_schema=(True if ASR_ENGINE == "faster_whisper" else False)
-    )] = False,
+            include_in_schema=(True if ASR_ENGINE == "faster_whisper" else False),
+        ),
+    ] = False,
    word_timestamps: bool = Query(default=False, description="Word level timestamps"),
-    output: Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"])
+    output: Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"]),
):
-    result = transcribe(load_audio(audio_file.file, encode), task, language, initial_prompt, vad_filter, word_timestamps, output)
+    result = transcribe(
+        load_audio(audio_file.file, encode), task, language, initial_prompt, vad_filter, word_timestamps, output
+    )
    return StreamingResponse(
        result,
        media_type="text/plain",
        headers={
-            'Asr-Engine': ASR_ENGINE,
-            'Content-Disposition': f'attachment; filename="{quote(audio_file.filename)}.{output}"'
-        }
+            "Asr-Engine": ASR_ENGINE,
+            "Content-Disposition": f'attachment; filename="{quote(audio_file.filename)}.{output}"',
+        },
    )
@app.post("/detect-language", tags=["Endpoints"])
async def detect_language(
-    audio_file: UploadFile = File(...),
-    encode: bool = Query(default=True, description="Encode audio first through FFmpeg")
+    audio_file: UploadFile = File(...),  # noqa: B008
+    encode: bool = Query(default=True, description="Encode audio first through FFmpeg"),
):
    detected_lang_code = language_detection(load_audio(audio_file.file, encode))
    return {"detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code": detected_lang_code}
......
@@ -77,4 +77,3 @@ poetry run gunicorn --bind 0.0.0.0:9000 --workers 1 --timeout 0 app.webservice:a
```sh
docker-compose -f docker-compose.gpu.yml up --build
```
@@ -29,7 +29,8 @@ There are 2 endpoints available:
| encode | true (default) |
Example request with cURL
-```
+```bash
curl -X POST -H "content-type: multipart/form-data" -F "audio_file=@/path/to/file" 0.0.0.0:9000/asr?output=json
```
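The second endpoint can be exercised the same way. The call below is an illustrative sketch against the `/detect-language` route defined in `app.webservice`; the host and file path are placeholders:

```bash
curl -X POST -H "content-type: multipart/form-data" -F "audio_file=@/path/to/file" 0.0.0.0:9000/detect-language
```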
......
@@ -19,7 +19,6 @@ Available ASR_MODELs are `tiny`, `base`, `small`, `medium`, `large` (only OpenAI
For English-only applications, the `.en` models tend to perform better, especially for the `tiny.en` and `base.en` models. We observed that the difference becomes less significant for the `small.en` and `medium.en` models.
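As an illustrative sketch (the model name and image tag are examples, not prescriptions), an English-only model can be selected at container start through the `ASR_MODEL` environment variable:

```bash
docker run -d -p 9000:9000 -e ASR_MODEL=base.en onerahmet/openai-whisper-asr-webservice:latest
```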
### Configuring the `Model Path`
```sh
......
Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multitask model that can perform multilingual speech recognition as well as speech translation and language identification.
## Features
The current release (v1.4.1) supports the following Whisper models:
- [openai/whisper](https://github.com/openai/whisper)@[v20231117](https://github.com/openai/whisper/releases/tag/v20231117)
-- [SYSTRAN/faster-whisper](https://github.com/SYSTRAN/faster-whisper)@[v0.10.0](https://github.com/SYSTRAN/faster-whisper/releases/tag/0.10.0)
+- [SYSTRAN/faster-whisper](https://github.com/SYSTRAN/faster-whisper)@[v1.0.2](https://github.com/SYSTRAN/faster-whisper/releases/tag/1.0.2)
## Quick Usage
......
@@ -29,11 +29,12 @@ Docker Hub: <https://hub.docker.com/r/onerahmet/openai-whisper-asr-webservice>
docker run -d --gpus all -p 9000:9000 -e ASR_MODEL=base -e ASR_ENGINE=openai_whisper onerahmet/openai-whisper-asr-webservice:latest-gpu
```
-> Interactive Swagger API documentation is available at http://localhost:9000/docs
+> Interactive Swagger API documentation is available at <http://localhost:9000/docs>
![Swagger UI](assets/images/swagger-ui.png)
## Cache
The ASR model is downloaded each time you start the container; with the large model this can take some time.
If you want to decrease your container start time by skipping the download, you can store the cache directory (`~/.cache/whisper` or `/root/.cache/whisper`) on persistent storage.
The next time you start the container, the ASR model will be loaded from the cache instead of being downloaded again.
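A minimal sketch of such a mount, assuming the in-container cache path named above and a local `./cache` directory (both are examples to adapt):

```bash
docker run -d -p 9000:9000 -v $PWD/cache:/root/.cache/whisper onerahmet/openai-whisper-asr-webservice:latest
```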
......
@@ -4,10 +4,7 @@ version = "1.5.0-dev"
description = "Whisper ASR Webservice is a general-purpose speech recognition webservice."
homepage = "https://github.com/ahmetoner/whisper-asr-webservice/"
license = "https://github.com/ahmetoner/whisper-asr-webservice/blob/main/LICENCE"
-authors = [
-    "Ahmet Öner",
-    "Besim Alibegovic",
-]
+authors = ["Ahmet Öner", "Besim Alibegovic"]
readme = "README.md"
packages = [{ include = "app" }]
@@ -39,7 +36,47 @@ torch = [
[tool.poetry.dev-dependencies]
pytest = "^6.2.5"
+ruff = "^0.5.0"
+black = "^24.4.2"
+mkdocs = "^1.6.0"
+mkdocs-material = "^9.5.27"
+pymdown-extensions = "^10.8.1"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
+[tool.black]
+skip-string-normalization = true
+line-length = 120
+
+[tool.ruff]
+# Same as Black.
+line-length = 120
+
+[tool.ruff.lint]
+select = [
+    "E",  # pycodestyle errors (settings from FastAPI, thanks, @tiangolo!)
+    "W",  # pycodestyle warnings
+    "F",  # pyflakes
+    "I",  # isort
+    "C",  # flake8-comprehensions
+    "B",  # flake8-bugbear
+]
+ignore = [
+    "E501",  # line too long, handled by black
+    "C901",  # too complex
+]
+
+[tool.ruff.lint.isort]
+order-by-type = true
+relative-imports-order = "closest-to-furthest"
+extra-standard-library = ["typing"]
+section-order = [
+    "future",
+    "standard-library",
+    "third-party",
+    "first-party",
+    "local-folder",
+]
+known-first-party = []
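With these tables in place, the formatters can be run through Poetry; a sketch, assuming the `app` package layout declared above (both are standard black/ruff invocations, not commands taken from this diff):

```bash
poetry run black app
poetry run ruff check --fix app
```

Both tools pick up the 120-character line length and the isort-style import ordering from the `[tool.black]` and `[tool.ruff]` sections.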