From 0cd35b893d7beb2f715966d2bdf8fe3db40b72aa Mon Sep 17 00:00:00 2001 From: charnesp <charles.nespoulous@gmail.com> Date: Sun, 16 Feb 2025 18:33:31 +0100 Subject: [PATCH] Integrate Whisperx (#267) * Initial commit for whisperx * whisperx working * Correct Dockerfile * Add possibility to prioritize self-hosted runners --------- Co-authored-by: Charles N <charles@chouette.vision> --- .devcontainer/devcontainer.json | 44 ++++++++++ .devcontainer/docker-compose.yml | 30 +++++++ .github/workflows/docker-publish.yml | 6 +- .gitignore | 4 +- Dockerfile | 13 +++ Dockerfile.gpu | 12 +++ README.md | 1 + app/asr_models/asr_model.py | 27 ++++-- app/asr_models/faster_whisper_engine.py | 1 + app/asr_models/mbain_whisperx_engine.py | 111 ++++++++++++++++++++++++ app/asr_models/openai_whisper_engine.py | 33 ++++--- app/config.py | 3 + app/factory/asr_model_factory.py | 3 + app/utils.py | 66 ++++++++++---- app/webservice.py | 74 ++++++++++------ 15 files changed, 356 insertions(+), 72 deletions(-) create mode 100644 .devcontainer/devcontainer.json create mode 100644 .devcontainer/docker-compose.yml create mode 100644 app/asr_models/mbain_whisperx_engine.py diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..36e62bd --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,44 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-docker-compose +{ + "name": "Existing Docker Compose (Extend)", + + // Update the 'dockerComposeFile' list if you have more compose files or use different names. + // The .devcontainer/docker-compose.yml file contains any overrides you need/want to make. + "dockerComposeFile": [ + "../docker-compose.yml", + "docker-compose.yml" + ], + + // The 'service' property is the name of the service for the container that VS Code should + // use. Update this value and .devcontainer/docker-compose.yml to the real service name. + "service": "whisper-asr-webservice", + + // The optional 'workspaceFolder' property is the path VS Code should open by default when + // connected. This is typically a file mount in .devcontainer/docker-compose.yml + "workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}", + + // "overrideCommand": "/bin/sh -c 'while sleep 1000; do :; done'" + "overrideCommand": true + + // Features to add to the dev container. More info: https://containers.dev/features. + // "features": {}, + + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + + // Uncomment the next line if you want start specific services in your Docker Compose config. + // "runServices": [], + + // Uncomment the next line if you want to keep your containers running after VS Code shuts down. + // "shutdownAction": "none", + + // Uncomment the next line to run commands after the container is created. + // "postCreateCommand": "cat /etc/os-release", + + // Configure tool-specific properties. + // "customizations": {}, + + // Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root. 
+ // "remoteUser": "devcontainer" +} diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml new file mode 100644 index 0000000..c3e29c5 --- /dev/null +++ b/.devcontainer/docker-compose.yml @@ -0,0 +1,30 @@ +version: '3.4' +services: + # Update this to the name of the service you want to work with in your docker-compose.yml file + whisper-asr-webservice: + # Uncomment if you want to override the service's Dockerfile to one in the .devcontainer + # folder. Note that the path of the Dockerfile and context is relative to the *primary* + # docker-compose.yml file (the first in the devcontainer.json "dockerComposeFile" + # array). The sample below assumes your primary file is in the root of your project. + # + # build: + # context: . + # dockerfile: .devcontainer/Dockerfile + env_file: .devcontainer/dev.env + environment: + ASR_ENGINE: ${ASR_ENGINE} + HF_TOKEN: ${HF_TOKEN} + + volumes: + # Update this to wherever you want VS Code to mount the folder of your project + - ..:/workspaces:cached + + # Uncomment the next four lines if you will use a ptrace-based debugger like C++, Go, and Rust. + # cap_add: + # - SYS_PTRACE + # security_opt: + # - seccomp:unconfined + + # Overrides default command so things don't shut down after the process ends. + command: sleep infinity + diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index 571d2b8..574ea12 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -12,7 +12,7 @@ env: REPO_NAME: ${{secrets.REPO_NAME}} jobs: build: - runs-on: ubuntu-latest + runs-on: [self-hosted, ubuntu-latest] strategy: matrix: include: @@ -22,6 +22,10 @@ jobs: tag_extension: -gpu platforms: linux/amd64 steps: + - name: Remove unnecessary files + run: | + sudo rm -rf /usr/share/dotnet + sudo rm -rf "$AGENT_TOOLSDIRECTORY" - name: Checkout uses: actions/checkout@v3 - name: Set up QEMU diff --git a/.gitignore b/.gitignore index 4dbf939..3a43ec4 100644 --- a/.gitignore +++ b/.gitignore @@ -41,4 +41,6 @@ pip-wheel-metadata poetry/core/* -public \ No newline at end of file +public + +.devcontainer/dev.env diff --git a/Dockerfile b/Dockerfile index bdae06f..08cfa3a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,6 +8,8 @@ RUN export DEBIAN_FRONTEND=noninteractive \ pkg-config \ yasm \ ca-certificates \ + gcc \ + python3-dev \ && rm -rf /var/lib/apt/lists/* RUN git clone https://github.com/FFmpeg/FFmpeg.git --depth 1 --branch n6.1.1 --single-branch /FFmpeg-6.1.1 @@ -42,6 +44,12 @@ FROM swaggerapi/swagger-ui:v5.9.1 AS swagger-ui FROM python:3.10-bookworm +RUN export DEBIAN_FRONTEND=noninteractive \ + && apt-get -qq update \ + && apt-get -qq install --no-install-recommends \ + libsndfile1 \ + && rm -rf /var/lib/apt/lists/* + ENV POETRY_VENV=/app/.venv RUN python3 -m venv $POETRY_VENV \ @@ -61,6 +69,11 @@ COPY --from=swagger-ui /usr/share/nginx/html/swagger-ui-bundle.js swagger-ui-ass RUN poetry config virtualenvs.in-project true RUN poetry install +RUN $POETRY_VENV/bin/pip install pandas transformers nltk pyannote.audio +RUN git clone --depth 1 https://github.com/m-bain/whisperX.git \ + && cd whisperX \ + && $POETRY_VENV/bin/pip install -e . 
+ EXPOSE 9000 ENTRYPOINT ["whisper-asr-webservice"] diff --git a/Dockerfile.gpu b/Dockerfile.gpu index b605707..709681f 100644 --- a/Dockerfile.gpu +++ b/Dockerfile.gpu @@ -43,6 +43,13 @@ FROM swaggerapi/swagger-ui:v5.9.1 AS swagger-ui FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04 ENV PYTHON_VERSION=3.10 + +RUN export DEBIAN_FRONTEND=noninteractive \ + && apt-get -qq update \ + && apt-get -qq install --no-install-recommends \ + libsndfile1 \ + && rm -rf /var/lib/apt/lists/* + ENV POETRY_VENV=/app/.venv RUN export DEBIAN_FRONTEND=noninteractive \ @@ -79,6 +86,11 @@ COPY --from=swagger-ui /usr/share/nginx/html/swagger-ui-bundle.js swagger-ui-ass RUN poetry install RUN $POETRY_VENV/bin/pip install torch==1.13.1+cu117 -f https://download.pytorch.org/whl/torch +RUN $POETRY_VENV/bin/pip install pandas transformers nltk pyannote.audio +RUN git clone --depth 1 https://github.com/m-bain/whisperX.git \ + && cd whisperX \ + && $POETRY_VENV/bin/pip install -e . + EXPOSE 9000 CMD whisper-asr-webservice diff --git a/README.md b/README.md index b521283..21dcb00 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ Current release (v1.7.1) supports following whisper models: - [openai/whisper](https://github.com/openai/whisper)@[v20240930](https://github.com/openai/whisper/releases/tag/v20240930) - [SYSTRAN/faster-whisper](https://github.com/SYSTRAN/faster-whisper)@[v1.1.0](https://github.com/SYSTRAN/faster-whisper/releases/tag/v1.1.0) +- [whisperX](https://github.com/m-bain/whisperX)@[v3.1.1](https://github.com/m-bain/whisperX/releases/tag/v3.1.1) ## Quick Usage diff --git a/app/asr_models/asr_model.py b/app/asr_models/asr_model.py index fc22050..8ec9890 100644 --- a/app/asr_models/asr_model.py +++ b/app/asr_models/asr_model.py @@ -13,7 +13,10 @@ class ASRModel(ABC): """ Abstract base class for ASR (Automatic Speech Recognition) models. """ + model = None + diarize_model = None # used for WhisperX + x_models = dict() # used for WhisperX model_lock = Lock() last_activity_time = time.time() @@ -28,14 +31,17 @@ class ASRModel(ABC): pass @abstractmethod - def transcribe(self, - audio, - task: Union[str, None], - language: Union[str, None], - initial_prompt: Union[str, None], - vad_filter: Union[bool, None], - word_timestamps: Union[bool, None] - ): + def transcribe( + self, + audio, + task: Union[str, None], + language: Union[str, None], + initial_prompt: Union[str, None], + vad_filter: Union[bool, None], + word_timestamps: Union[bool, None], + options: Union[dict, None], + output, + ): """ Perform transcription on the given audio file. """ @@ -52,7 +58,8 @@ class ASRModel(ABC): """ Monitors the idleness of the ASR model and releases the model if it has been idle for too long. 
""" - if CONFIG.MODEL_IDLE_TIMEOUT <= 0: return + if CONFIG.MODEL_IDLE_TIMEOUT <= 0: + return while True: time.sleep(15) if time.time() - self.last_activity_time > CONFIG.MODEL_IDLE_TIMEOUT: @@ -68,4 +75,6 @@ class ASRModel(ABC): torch.cuda.empty_cache() gc.collect() self.model = None + self.diarize_model = None + self.x_models = dict() print("Model unloaded due to timeout") diff --git a/app/asr_models/faster_whisper_engine.py b/app/asr_models/faster_whisper_engine.py index a0b8620..3cf21f8 100644 --- a/app/asr_models/faster_whisper_engine.py +++ b/app/asr_models/faster_whisper_engine.py @@ -32,6 +32,7 @@ class FasterWhisperASR(ASRModel): initial_prompt: Union[str, None], vad_filter: Union[bool, None], word_timestamps: Union[bool, None], + options: Union[dict, None], output, ): self.last_activity_time = time.time() diff --git a/app/asr_models/mbain_whisperx_engine.py b/app/asr_models/mbain_whisperx_engine.py new file mode 100644 index 0000000..0e104bf --- /dev/null +++ b/app/asr_models/mbain_whisperx_engine.py @@ -0,0 +1,111 @@ +from typing import BinaryIO, Union +from io import StringIO +import whisperx +import whisper +from whisperx.utils import SubtitlesWriter, ResultWriter + +from app.asr_models.asr_model import ASRModel +from app.config import CONFIG +from app.utils import WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON + + +class WhisperXASR(ASRModel): + def __init__(self): + self.x_models = dict() + + def load_model(self): + + asr_options = {"without_timestamps": False} + self.model = whisperx.load_model( + CONFIG.MODEL_NAME, device=CONFIG.DEVICE, compute_type="float32", asr_options=asr_options + ) + + if CONFIG.HF_TOKEN != "": + self.diarize_model = whisperx.DiarizationPipeline(use_auth_token=CONFIG.HF_TOKEN, device=CONFIG.DEVICE) + + def transcribe( + self, + audio, + task: Union[str, None], + language: Union[str, None], + initial_prompt: Union[str, None], + vad_filter: Union[bool, None], + word_timestamps: Union[bool, None], + options: Union[dict, None], + output, + ): + options_dict = {"task": task} + if language: + options_dict["language"] = language + if initial_prompt: + options_dict["initial_prompt"] = initial_prompt + with self.model_lock: + if self.model is None: + self.load_model() + result = self.model.transcribe(audio, **options_dict) + + # Load the required model and cache it + # If we transcribe models in many different languages, this may lead to OOM propblems + if result["language"] in self.x_models: + model_x, metadata = self.x_models[result["language"]] + else: + self.x_models[result["language"]] = whisperx.load_align_model( + language_code=result["language"], device=CONFIG.DEVICE + ) + model_x, metadata = self.x_models[result["language"]] + + # Align whisper output + result = whisperx.align( + result["segments"], model_x, metadata, audio, CONFIG.DEVICE, return_char_alignments=False + ) + + if options.get("diarize", False): + if CONFIG.HF_TOKEN == "": + print("Warning! HF_TOKEN is not set. 
Diarization may not work as expected.") + min_speakers = options.get("min_speakers", None) + max_speakers = options.get("max_speakers", None) + # add min/max number of speakers if known + diarize_segments = self.diarize_model(audio, min_speakers, max_speakers) + result = whisperx.assign_word_speakers(diarize_segments, result) + + output_file = StringIO() + self.write_result(result, output_file, output) + output_file.seek(0) + + return output_file + + def language_detection(self, audio): + # load audio and pad/trim it to fit 30 seconds + audio = whisper.pad_or_trim(audio) + + # make log-Mel spectrogram and move to the same device as the model + mel = whisper.log_mel_spectrogram(audio).to(self.model.device) + + # detect the spoken language + with self.model_lock: + if self.model is None: + self.load_model() + _, probs = self.model.detect_language(mel) + detected_lang_code = max(probs, key=probs.get) + + return detected_lang_code + + def write_result(self, result: dict, file: BinaryIO, output: Union[str, None]): + if output == "srt": + if CONFIG.HF_TOKEN != "": + WriteSRT(SubtitlesWriter).write_result(result, file=file, options={}) + else: + WriteSRT(ResultWriter).write_result(result, file=file, options={}) + elif output == "vtt": + if CONFIG.HF_TOKEN != "": + WriteVTT(SubtitlesWriter).write_result(result, file=file, options={}) + else: + WriteVTT(ResultWriter).write_result(result, file=file, options={}) + elif output == "tsv": + WriteTSV(ResultWriter).write_result(result, file=file, options={}) + elif output == "json": + WriteJSON(ResultWriter).write_result(result, file=file, options={}) + elif output == "txt": + WriteTXT(ResultWriter).write_result(result, file=file, options={}) + else: + return 'Please select an output method!' diff --git a/app/asr_models/openai_whisper_engine.py b/app/asr_models/openai_whisper_engine.py index c6c37dc..205aff8 100644 --- a/app/asr_models/openai_whisper_engine.py +++ b/app/asr_models/openai_whisper_engine.py @@ -16,32 +16,28 @@ class OpenAIWhisperASR(ASRModel): def load_model(self): if torch.cuda.is_available(): - self.model = whisper.load_model( - name=CONFIG.MODEL_NAME, - download_root=CONFIG.MODEL_PATH - ).cuda() + self.model = whisper.load_model(name=CONFIG.MODEL_NAME, download_root=CONFIG.MODEL_PATH).cuda() else: - self.model = whisper.load_model( - name=CONFIG.MODEL_NAME, - download_root=CONFIG.MODEL_PATH - ) + self.model = whisper.load_model(name=CONFIG.MODEL_NAME, download_root=CONFIG.MODEL_PATH) Thread(target=self.monitor_idleness, daemon=True).start() def transcribe( - self, - audio, - task: Union[str, None], - language: Union[str, None], - initial_prompt: Union[str, None], - vad_filter: Union[bool, None], - word_timestamps: Union[bool, None], - output, + self, + audio, + task: Union[str, None], + language: Union[str, None], + initial_prompt: Union[str, None], + vad_filter: Union[bool, None], + word_timestamps: Union[bool, None], + options: Union[dict, None], + output, ): self.last_activity_time = time.time() with self.model_lock: - if self.model is None: self.load_model() + if self.model is None: + self.load_model() options_dict = {"task": task} if language: @@ -64,7 +60,8 @@ class OpenAIWhisperASR(ASRModel): self.last_activity_time = time.time() with self.model_lock: - if self.model is None: self.load_model() + if self.model is None: + self.load_model() # load audio and pad/trim it to fit 30 seconds audio = whisper.pad_or_trim(audio) diff --git a/app/config.py b/app/config.py index 6420269..31779aa 100644 --- a/app/config.py +++ b/app/config.py 
@@ -11,6 +11,9 @@ class CONFIG: # Determine the ASR engine ('faster_whisper' or 'openai_whisper') ASR_ENGINE = os.getenv("ASR_ENGINE", "openai_whisper") + # Retrieve Huggingface Token + HF_TOKEN = os.getenv("HF_TOKEN", "") + # Determine the computation device (GPU or CPU) DEVICE = os.getenv("ASR_DEVICE", "cuda" if torch.cuda.is_available() else "cpu") diff --git a/app/factory/asr_model_factory.py b/app/factory/asr_model_factory.py index 0f97fb5..10588dd 100644 --- a/app/factory/asr_model_factory.py +++ b/app/factory/asr_model_factory.py @@ -1,6 +1,7 @@ from app.asr_models.asr_model import ASRModel from app.asr_models.faster_whisper_engine import FasterWhisperASR from app.asr_models.openai_whisper_engine import OpenAIWhisperASR +from app.asr_models.mbain_whisperx_engine import WhisperXASR from app.config import CONFIG @@ -11,5 +12,7 @@ class ASRModelFactory: return OpenAIWhisperASR() elif CONFIG.ASR_ENGINE == "faster_whisper": return FasterWhisperASR() + elif CONFIG.ASR_ENGINE == "whisperx": + return WhisperXASR() else: raise ValueError(f"Unsupported ASR engine: {CONFIG.ASR_ENGINE}") diff --git a/app/utils.py b/app/utils.py index 154dbbb..31c52fa 100644 --- a/app/utils.py +++ b/app/utils.py @@ -1,7 +1,7 @@ import json import os -from dataclasses import asdict -from typing import TextIO, BinaryIO +from dataclasses import asdict, is_dataclass +from typing import TextIO, BinaryIO, Union import ffmpeg import numpy as np @@ -23,14 +23,42 @@ class ResultWriter: with open(output_path, "w", encoding="utf-8") as f: self.write_result(result, file=f) - def write_result(self, result: dict, file: TextIO): + def write_result(self, result: dict, file: TextIO, options: Union[dict, None]): raise NotImplementedError + + def format_segments_in_result(self, result: dict): + if "segments" in result: + # Check if result["segments"] is a list + if isinstance(result["segments"], list): + # Check if the list is empty + if not result["segments"]: + # Handle the empty list case, you can choose to leave it as is or set it to an empty list + pass + else: + # Check if the first item in the list is a dataclass instance + if is_dataclass(result["segments"][0]): + result["segments"] = [asdict(segment) for segment in result["segments"]] + # If it's already a list of dicts, leave it as is + elif isinstance(result["segments"][0], dict): + pass + else: + # Handle the case where the list contains neither dataclass instances nor dicts + # You can choose to leave it as is or raise an error + pass + elif isinstance(result["segments"], dict): + # If it's already a dict, leave it as is + pass + else: + # Handle the case where result["segments"] is neither a list nor a dict + # You can choose to leave it as is or raise an error + pass + return result class WriteTXT(ResultWriter): extension: str = "txt" - def write_result(self, result: dict, file: TextIO): + def write_result(self, result: dict, file: TextIO, options: Union[dict, None]): for segment in result["segments"]: print(segment.text.strip(), file=file, flush=True) @@ -38,12 +66,13 @@ class WriteTXT(ResultWriter): class WriteVTT(ResultWriter): extension: str = "vtt" - def write_result(self, result: dict, file: TextIO): + def write_result(self, result: dict, file: TextIO, options: Union[dict, None]): print("WEBVTT\n", file=file) + result = self.format_segments_in_result(result) for segment in result["segments"]: print( - f"{format_timestamp(segment.start)} --> {format_timestamp(segment.end)}\n" - f"{segment.text.strip().replace('-->', '->')}\n", + 
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" + f"{segment['text'].strip().replace('-->', '->')}\n", file=file, flush=True, ) @@ -52,14 +81,15 @@ class WriteVTT(ResultWriter): class WriteSRT(ResultWriter): extension: str = "srt" - def write_result(self, result: dict, file: TextIO): + def write_result(self, result: dict, file: TextIO, options: Union[dict, None]): + result = self.format_segments_in_result(result) for i, segment in enumerate(result["segments"], start=1): # write srt lines print( f"{i}\n" - f"{format_timestamp(segment.start, always_include_hours=True, decimal_marker=',')} --> " - f"{format_timestamp(segment.end, always_include_hours=True, decimal_marker=',')}\n" - f"{segment.text.strip().replace('-->', '->')}\n", + f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> " + f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n" + f"{segment['text'].strip().replace('-->', '->')}\n", file=file, flush=True, ) @@ -77,20 +107,20 @@ class WriteTSV(ResultWriter): extension: str = "tsv" - def write_result(self, result: dict, file: TextIO): + def write_result(self, result: dict, file: TextIO, options: Union[dict, None]): + result = self.format_segments_in_result(result) print("start", "end", "text", sep="\t", file=file) for segment in result["segments"]: - print(round(1000 * segment.start), file=file, end="\t") - print(round(1000 * segment.end), file=file, end="\t") - print(segment.text.strip().replace("\t", " "), file=file, flush=True) + print(round(1000 * segment["start"]), file=file, end="\t") + print(round(1000 * segment["end"]), file=file, end="\t") + print(segment["text"].strip().replace("\t", " "), file=file, flush=True) class WriteJSON(ResultWriter): extension: str = "json" - def write_result(self, result: dict, file: TextIO): - if "segments" in result: - result["segments"] = [asdict(segment) for segment in result["segments"]] + def write_result(self, result: dict, file: TextIO, options: Union[dict, None]): + result = self.format_segments_in_result(result) json.dump(result, file) diff --git a/app/webservice.py b/app/webservice.py index 7158645..8f4fa6a 100644 --- a/app/webservice.py +++ b/app/webservice.py @@ -35,7 +35,6 @@ assets_path = os.getcwd() + "/swagger-ui-assets" if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/swagger-ui-bundle.js"): app.mount("/assets", StaticFiles(directory=assets_path), name="static") - def swagger_monkey_patch(*args, **kwargs): return get_swagger_ui_html( *args, @@ -45,7 +44,6 @@ if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/ swagger_js_url="/assets/swagger-ui-bundle.js", ) - applications.get_swagger_ui_html = swagger_monkey_patch @@ -56,23 +54,49 @@ async def index(): @app.post("/asr", tags=["Endpoints"]) async def asr( - audio_file: UploadFile = File(...), # noqa: B008 - encode: bool = Query(default=True, description="Encode audio first through ffmpeg"), - task: Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]), - language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES), - initial_prompt: Union[str, None] = Query(default=None), - vad_filter: Annotated[ - bool | None, - Query( - description="Enable the voice activity detection (VAD) to filter out parts of the audio without speech", - include_in_schema=(True if CONFIG.ASR_ENGINE == "faster_whisper" else False), - ), - ] = False, - word_timestamps: bool = Query(default=False, 
description="Word level timestamps"), - output: Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"]), + audio_file: UploadFile = File(...), # noqa: B008 + encode: bool = Query(default=True, description="Encode audio first through ffmpeg"), + task: Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]), + language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES), + initial_prompt: Union[str, None] = Query(default=None), + vad_filter: Annotated[ + bool | None, + Query( + description="Enable the voice activity detection (VAD) to filter out parts of the audio without speech", + include_in_schema=(True if CONFIG.ASR_ENGINE == "faster_whisper" else False), + ), + ] = False, + word_timestamps: bool = Query( + default=False, + description="Word level timestamps", + include_in_schema=(True if CONFIG.ASR_ENGINE == "faster_whisper" else False), + ), + diarize: bool = Query( + default=False, + description="Diarize the input", + include_in_schema=(True if CONFIG.ASR_ENGINE == "whisperx" and CONFIG.HF_TOKEN != "" else False), + ), + min_speakers: Union[int, None] = Query( + default=None, + description="Min speakers in this file", + include_in_schema=(True if CONFIG.ASR_ENGINE == "whisperx" else False), + ), + max_speakers: Union[int, None] = Query( + default=None, + description="Max speakers in this file", + include_in_schema=(True if CONFIG.ASR_ENGINE == "whisperx" else False), + ), + output: Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"]), ): result = asr_model.transcribe( - load_audio(audio_file.file, encode), task, language, initial_prompt, vad_filter, word_timestamps, output + load_audio(audio_file.file, encode), + task, + language, + initial_prompt, + vad_filter, + word_timestamps, + {"diarize": diarize, "min_speakers": min_speakers, "max_speakers": max_speakers}, + output, ) return StreamingResponse( result, @@ -86,12 +110,15 @@ async def asr( @app.post("/detect-language", tags=["Endpoints"]) async def detect_language( - audio_file: UploadFile = File(...), # noqa: B008 - encode: bool = Query(default=True, description="Encode audio first through FFmpeg"), + audio_file: UploadFile = File(...), # noqa: B008 + encode: bool = Query(default=True, description="Encode audio first through FFmpeg"), ): detected_lang_code, confidence = asr_model.language_detection(load_audio(audio_file.file, encode)) - return {"detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code": detected_lang_code, - "confidence": confidence} + return { + "detected_language": tokenizer.LANGUAGES[detected_lang_code], + "language_code": detected_lang_code, + "confidence": confidence, + } @click.command() @@ -110,10 +137,7 @@ async def detect_language( help="Port for the webservice (default: 9000)", ) @click.version_option(version=projectMetadata["Version"]) -def start( - host: str, - port: Optional[int] = None -): +def start(host: str, port: Optional[int] = None): uvicorn.run(app, host=host, port=port) -- GitLab