diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
new file mode 100644
index 0000000000000000000000000000000000000000..36e62bdd0942c2ea7f7af0728d7e7bf35de7f363
--- /dev/null
+++ b/.devcontainer/devcontainer.json
@@ -0,0 +1,44 @@
+// For format details, see https://aka.ms/devcontainer.json. For config options, see the
+// README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-docker-compose
+{
+ "name": "Existing Docker Compose (Extend)",
+
+ // Update the 'dockerComposeFile' list if you have more compose files or use different names.
+ // The .devcontainer/docker-compose.yml file contains any overrides you need/want to make.
+ "dockerComposeFile": [
+ "../docker-compose.yml",
+ "docker-compose.yml"
+ ],
+
+ // The 'service' property is the name of the service for the container that VS Code should
+ // use. Update this value and .devcontainer/docker-compose.yml to the real service name.
+ "service": "whisper-asr-webservice",
+
+ // The optional 'workspaceFolder' property is the path VS Code should open by default when
+ // connected. This is typically a file mount in .devcontainer/docker-compose.yml
+ "workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",
+
+ // "overrideCommand": "/bin/sh -c 'while sleep 1000; do :; done'"
+ "overrideCommand": true
+
+ // Features to add to the dev container. More info: https://containers.dev/features.
+ // "features": {},
+
+ // Use 'forwardPorts' to make a list of ports inside the container available locally.
+ // "forwardPorts": [],
+
+ // Uncomment the next line if you want to start specific services in your Docker Compose config.
+ // "runServices": [],
+
+ // Uncomment the next line if you want to keep your containers running after VS Code shuts down.
+ // "shutdownAction": "none",
+
+ // Uncomment the next line to run commands after the container is created.
+ // "postCreateCommand": "cat /etc/os-release",
+
+ // Configure tool-specific properties.
+ // "customizations": {},
+
+ // Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root.
+ // "remoteUser": "devcontainer"
+}
diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c3e29c5e06776f94279b8abf25a044dfa2bd00f7
--- /dev/null
+++ b/.devcontainer/docker-compose.yml
@@ -0,0 +1,30 @@
+version: '3.4'
+services:
+ # Update this to the name of the service you want to work with in your docker-compose.yml file
+ whisper-asr-webservice:
+ # Uncomment if you want to override the service's Dockerfile to one in the .devcontainer
+ # folder. Note that the path of the Dockerfile and context is relative to the *primary*
+ # docker-compose.yml file (the first in the devcontainer.json "dockerComposeFile"
+ # array). The sample below assumes your primary file is in the root of your project.
+ #
+ # build:
+ # context: .
+ # dockerfile: .devcontainer/Dockerfile
+ env_file: .devcontainer/dev.env
+ environment:
+ ASR_ENGINE: ${ASR_ENGINE}
+ HF_TOKEN: ${HF_TOKEN}
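+ # Example .devcontainer/dev.env (gitignored; the values below are placeholders):
+ #   ASR_ENGINE=whisperx
+ #   HF_TOKEN=<your huggingface token>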
+
+ volumes:
+ # Update this to wherever you want VS Code to mount the folder of your project
+ - ..:/workspaces:cached
+
+ # Uncomment the next four lines if you will use a ptrace-based debugger (e.g. for C++, Go, or Rust).
+ # cap_add:
+ # - SYS_PTRACE
+ # security_opt:
+ # - seccomp:unconfined
+
+ # Overrides default command so things don't shut down after the process ends.
+ command: sleep infinity
+
diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml
index 571d2b8dd956c95e2ba76d5a77141ea7f188e25a..574ea12319ae7ea5771ef380c2bdfdfb6adaa31e 100644
--- a/.github/workflows/docker-publish.yml
+++ b/.github/workflows/docker-publish.yml
@@ -12,7 +12,7 @@ env:
REPO_NAME: ${{secrets.REPO_NAME}}
jobs:
build:
- runs-on: ubuntu-latest
+ runs-on: [self-hosted, ubuntu-latest]
strategy:
matrix:
include:
@@ -22,6 +22,10 @@ jobs:
tag_extension: -gpu
platforms: linux/amd64
steps:
+ - name: Remove unnecessary files
+ run: |
+ sudo rm -rf /usr/share/dotnet
+ sudo rm -rf "$AGENT_TOOLSDIRECTORY"
- name: Checkout
uses: actions/checkout@v3
- name: Set up QEMU
diff --git a/.gitignore b/.gitignore
index 4dbf939ed64792ff9f11539e92510cb55e62111f..3a43ec4111fae3e500b9474f717b689d3d77dfbd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -41,4 +41,6 @@ pip-wheel-metadata
poetry/core/*
-public
\ No newline at end of file
+public
+
+.devcontainer/dev.env
diff --git a/Dockerfile b/Dockerfile
index bdae06f38c89a63ac1766658e277e868adfe89c2..08cfa3a9a930ba718be15852c700ea1af9b661bd 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -8,6 +8,8 @@ RUN export DEBIAN_FRONTEND=noninteractive \
pkg-config \
yasm \
ca-certificates \
+ gcc \
+ python3-dev \
&& rm -rf /var/lib/apt/lists/*
RUN git clone https://github.com/FFmpeg/FFmpeg.git --depth 1 --branch n6.1.1 --single-branch /FFmpeg-6.1.1
@@ -42,6 +44,12 @@ FROM swaggerapi/swagger-ui:v5.9.1 AS swagger-ui
FROM python:3.10-bookworm
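+# libsndfile1 provides the native audio I/O library needed at runtime by the audio packages installed for whisperX (e.g. soundfile).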
+RUN export DEBIAN_FRONTEND=noninteractive \
+ && apt-get -qq update \
+ && apt-get -qq install --no-install-recommends \
+ libsndfile1 \
+ && rm -rf /var/lib/apt/lists/*
+
ENV POETRY_VENV=/app/.venv
RUN python3 -m venv $POETRY_VENV \
@@ -61,6 +69,11 @@ COPY --from=swagger-ui /usr/share/nginx/html/swagger-ui-bundle.js swagger-ui-ass
RUN poetry config virtualenvs.in-project true
RUN poetry install
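+# WhisperX and its extra dependencies are installed straight into the Poetry virtualenv; they are not part of the Poetry-managed project dependencies.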
+RUN $POETRY_VENV/bin/pip install pandas transformers nltk pyannote.audio
+RUN git clone --depth 1 https://github.com/m-bain/whisperX.git \
+ && cd whisperX \
+ && $POETRY_VENV/bin/pip install -e .
+
EXPOSE 9000
ENTRYPOINT ["whisper-asr-webservice"]
diff --git a/Dockerfile.gpu b/Dockerfile.gpu
index b605707d45663f6edee906d86e1cd55cd78e2c09..709681f3da6ac3357a924d660253d2e807338a0d 100644
--- a/Dockerfile.gpu
+++ b/Dockerfile.gpu
@@ -43,6 +43,13 @@ FROM swaggerapi/swagger-ui:v5.9.1 AS swagger-ui
FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
ENV PYTHON_VERSION=3.10
+
+RUN export DEBIAN_FRONTEND=noninteractive \
+ && apt-get -qq update \
+ && apt-get -qq install --no-install-recommends \
+ libsndfile1 \
+ && rm -rf /var/lib/apt/lists/*
+
ENV POETRY_VENV=/app/.venv
RUN export DEBIAN_FRONTEND=noninteractive \
@@ -79,6 +86,11 @@ COPY --from=swagger-ui /usr/share/nginx/html/swagger-ui-bundle.js swagger-ui-ass
RUN poetry install
RUN $POETRY_VENV/bin/pip install torch==1.13.1+cu117 -f https://download.pytorch.org/whl/torch
+RUN $POETRY_VENV/bin/pip install pandas transformers nltk pyannote.audio
+RUN git clone --depth 1 https://github.com/m-bain/whisperX.git \
+ && cd whisperX \
+ && $POETRY_VENV/bin/pip install -e .
+
EXPOSE 9000
CMD whisper-asr-webservice
diff --git a/README.md b/README.md
index b5212835417a04f35bd3969d18ebdf43b7fc0ba7..21dcb00f1e5d8b9df8761be65c5b0e7c331c27f7 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,7 @@ Current release (v1.7.1) supports following whisper models:
- [openai/whisper](https://github.com/openai/whisper)@[v20240930](https://github.com/openai/whisper/releases/tag/v20240930)
- [SYSTRAN/faster-whisper](https://github.com/SYSTRAN/faster-whisper)@[v1.1.0](https://github.com/SYSTRAN/faster-whisper/releases/tag/v1.1.0)
+- [whisperX](https://github.com/m-bain/whisperX)@[v3.1.1](https://github.com/m-bain/whisperX/releases/tag/v3.1.1)
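+
+To use it, select the engine with `ASR_ENGINE=whisperx`; speaker diarization additionally requires a Hugging Face token passed as `HF_TOKEN`. A minimal example (substitute the image you build or pull for this service):
+
+```sh
+docker run -d -p 9000:9000 -e ASR_ENGINE=whisperx -e HF_TOKEN=<your_hf_token> <whisper-asr-webservice-image>
+```
+
+With whisperX selected, the `/asr` endpoint also accepts `diarize`, `min_speakers` and `max_speakers` query parameters.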
## Quick Usage
diff --git a/app/asr_models/asr_model.py b/app/asr_models/asr_model.py
index fc2205083ed48e5e88267aeb3afa1de2d2e61051..8ec9890cbf9e02aa08b757fb2bf453a56eaaa92d 100644
--- a/app/asr_models/asr_model.py
+++ b/app/asr_models/asr_model.py
@@ -13,7 +13,10 @@ class ASRModel(ABC):
"""
Abstract base class for ASR (Automatic Speech Recognition) models.
"""
+
model = None
+ diarize_model = None # used for WhisperX
+ x_models = dict() # used for WhisperX
model_lock = Lock()
last_activity_time = time.time()
@@ -28,14 +31,17 @@ class ASRModel(ABC):
pass
@abstractmethod
- def transcribe(self,
- audio,
- task: Union[str, None],
- language: Union[str, None],
- initial_prompt: Union[str, None],
- vad_filter: Union[bool, None],
- word_timestamps: Union[bool, None]
- ):
+ def transcribe(
+ self,
+ audio,
+ task: Union[str, None],
+ language: Union[str, None],
+ initial_prompt: Union[str, None],
+ vad_filter: Union[bool, None],
+ word_timestamps: Union[bool, None],
+ options: Union[dict, None],
+ output,
+ ):
"""
Perform transcription on the given audio file.
"""
@@ -52,7 +58,8 @@ class ASRModel(ABC):
"""
Monitors the idleness of the ASR model and releases the model if it has been idle for too long.
"""
- if CONFIG.MODEL_IDLE_TIMEOUT <= 0: return
+ if CONFIG.MODEL_IDLE_TIMEOUT <= 0:
+ return
while True:
time.sleep(15)
if time.time() - self.last_activity_time > CONFIG.MODEL_IDLE_TIMEOUT:
@@ -68,4 +75,6 @@ class ASRModel(ABC):
torch.cuda.empty_cache()
gc.collect()
self.model = None
+ self.diarize_model = None
+ self.x_models = dict()
print("Model unloaded due to timeout")
diff --git a/app/asr_models/faster_whisper_engine.py b/app/asr_models/faster_whisper_engine.py
index a0b8620281373058002b0d893d5a00ddbe4b2ef4..3cf21f8e92a71005777ad529e8c777350d143cf7 100644
--- a/app/asr_models/faster_whisper_engine.py
+++ b/app/asr_models/faster_whisper_engine.py
@@ -32,6 +32,7 @@ class FasterWhisperASR(ASRModel):
initial_prompt: Union[str, None],
vad_filter: Union[bool, None],
word_timestamps: Union[bool, None],
+ options: Union[dict, None],
output,
):
self.last_activity_time = time.time()
diff --git a/app/asr_models/mbain_whisperx_engine.py b/app/asr_models/mbain_whisperx_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e104bf6298e36e0601cd3d8ffeb2c425f0d8c02
--- /dev/null
+++ b/app/asr_models/mbain_whisperx_engine.py
@@ -0,0 +1,111 @@
+from typing import BinaryIO, Union
+from io import StringIO
+import whisperx
+import whisper
+from whisperx.utils import SubtitlesWriter, ResultWriter
+
+from app.asr_models.asr_model import ASRModel
+from app.config import CONFIG
+from app.utils import WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON
+
+
+class WhisperXASR(ASRModel):
+ def __init__(self):
+ self.x_models = dict()
+
+ def load_model(self):
+
+ asr_options = {"without_timestamps": False}
+ self.model = whisperx.load_model(
+ CONFIG.MODEL_NAME, device=CONFIG.DEVICE, compute_type="float32", asr_options=asr_options
+ )
+
+ if CONFIG.HF_TOKEN != "":
+ self.diarize_model = whisperx.DiarizationPipeline(use_auth_token=CONFIG.HF_TOKEN, device=CONFIG.DEVICE)
+
+ def transcribe(
+ self,
+ audio,
+ task: Union[str, None],
+ language: Union[str, None],
+ initial_prompt: Union[str, None],
+ vad_filter: Union[bool, None],
+ word_timestamps: Union[bool, None],
+ options: Union[dict, None],
+ output,
+ ):
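+ """Transcribe with whisperX, word-align the segments, optionally diarize speakers, and return the formatted output."""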
+ options_dict = {"task": task}
+ if language:
+ options_dict["language"] = language
+ if initial_prompt:
+ options_dict["initial_prompt"] = initial_prompt
+ with self.model_lock:
+ if self.model is None:
+ self.load_model()
+ result = self.model.transcribe(audio, **options_dict)
+
+ # Load the required model and cache it
+ # If audio in many different languages is transcribed, caching an alignment model per language may lead to OOM problems.
+ if result["language"] in self.x_models:
+ model_x, metadata = self.x_models[result["language"]]
+ else:
+ self.x_models[result["language"]] = whisperx.load_align_model(
+ language_code=result["language"], device=CONFIG.DEVICE
+ )
+ model_x, metadata = self.x_models[result["language"]]
+
+ # Align whisper output
+ result = whisperx.align(
+ result["segments"], model_x, metadata, audio, CONFIG.DEVICE, return_char_alignments=False
+ )
+
+ if options.get("diarize", False):
+ if CONFIG.HF_TOKEN == "":
+ print("Warning! HF_TOKEN is not set. Diarization may not work as expected.")
+ min_speakers = options.get("min_speakers", None)
+ max_speakers = options.get("max_speakers", None)
+ # add min/max number of speakers if known
+ diarize_segments = self.diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
+ result = whisperx.assign_word_speakers(diarize_segments, result)
+
+ output_file = StringIO()
+ self.write_result(result, output_file, output)
+ output_file.seek(0)
+
+ return output_file
+
+ def language_detection(self, audio):
+ # load audio and pad/trim it to fit 30 seconds
+ audio = whisper.pad_or_trim(audio)
+
+ # make log-Mel spectrogram and move to the same device as the model
+ mel = whisper.log_mel_spectrogram(audio).to(self.model.device)
+
+ # detect the spoken language
+ with self.model_lock:
+ if self.model is None:
+ self.load_model()
+ _, probs = self.model.detect_language(mel)
+ detected_lang_code = max(probs, key=probs.get)
+
+ return detected_lang_code, probs[detected_lang_code]
+
+ def write_result(self, result: dict, file: BinaryIO, output: Union[str, None]):
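+ # srt/vtt use whisperx's SubtitlesWriter when an HF token is configured, and the plain ResultWriter otherwise.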
+ if output == "srt":
+ if CONFIG.HF_TOKEN != "":
+ WriteSRT(SubtitlesWriter).write_result(result, file=file, options={})
+ else:
+ WriteSRT(ResultWriter).write_result(result, file=file, options={})
+ elif output == "vtt":
+ if CONFIG.HF_TOKEN != "":
+ WriteVTT(SubtitlesWriter).write_result(result, file=file, options={})
+ else:
+ WriteVTT(ResultWriter).write_result(result, file=file, options={})
+ elif output == "tsv":
+ WriteTSV(ResultWriter).write_result(result, file=file, options={})
+ elif output == "json":
+ WriteJSON(ResultWriter).write_result(result, file=file, options={})
+ elif output == "txt":
+ WriteTXT(ResultWriter).write_result(result, file=file, options={})
+ else:
+ return 'Please select an output method!'
diff --git a/app/asr_models/openai_whisper_engine.py b/app/asr_models/openai_whisper_engine.py
index c6c37dc9346ec4380d12507259f87b41fdea695e..205aff89080db8646e2361a2d9c46f6be4bcfe52 100644
--- a/app/asr_models/openai_whisper_engine.py
+++ b/app/asr_models/openai_whisper_engine.py
@@ -16,32 +16,28 @@ class OpenAIWhisperASR(ASRModel):
def load_model(self):
if torch.cuda.is_available():
- self.model = whisper.load_model(
- name=CONFIG.MODEL_NAME,
- download_root=CONFIG.MODEL_PATH
- ).cuda()
+ self.model = whisper.load_model(name=CONFIG.MODEL_NAME, download_root=CONFIG.MODEL_PATH).cuda()
else:
- self.model = whisper.load_model(
- name=CONFIG.MODEL_NAME,
- download_root=CONFIG.MODEL_PATH
- )
+ self.model = whisper.load_model(name=CONFIG.MODEL_NAME, download_root=CONFIG.MODEL_PATH)
Thread(target=self.monitor_idleness, daemon=True).start()
def transcribe(
- self,
- audio,
- task: Union[str, None],
- language: Union[str, None],
- initial_prompt: Union[str, None],
- vad_filter: Union[bool, None],
- word_timestamps: Union[bool, None],
- output,
+ self,
+ audio,
+ task: Union[str, None],
+ language: Union[str, None],
+ initial_prompt: Union[str, None],
+ vad_filter: Union[bool, None],
+ word_timestamps: Union[bool, None],
+ options: Union[dict, None],
+ output,
):
self.last_activity_time = time.time()
with self.model_lock:
- if self.model is None: self.load_model()
+ if self.model is None:
+ self.load_model()
options_dict = {"task": task}
if language:
@@ -64,7 +60,8 @@ class OpenAIWhisperASR(ASRModel):
self.last_activity_time = time.time()
with self.model_lock:
- if self.model is None: self.load_model()
+ if self.model is None:
+ self.load_model()
# load audio and pad/trim it to fit 30 seconds
audio = whisper.pad_or_trim(audio)
diff --git a/app/config.py b/app/config.py
index 64202697feb58f622ddf3f8a0270cc38ba851262..31779aa41621786e54cf9dbb6dceb6a6f32860c2 100644
--- a/app/config.py
+++ b/app/config.py
@@ -11,6 +11,9 @@ class CONFIG:
# Determine the ASR engine ('faster_whisper' or 'openai_whisper')
ASR_ENGINE = os.getenv("ASR_ENGINE", "openai_whisper")
+ # Hugging Face token, used by the whisperX engine for speaker diarization
+ HF_TOKEN = os.getenv("HF_TOKEN", "")
+
# Determine the computation device (GPU or CPU)
DEVICE = os.getenv("ASR_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
diff --git a/app/factory/asr_model_factory.py b/app/factory/asr_model_factory.py
index 0f97fb53ae0a5ca20caec09dc40c388efcdb863a..10588dd248a047a11ab052bdbc1f43c32cfda387 100644
--- a/app/factory/asr_model_factory.py
+++ b/app/factory/asr_model_factory.py
@@ -1,6 +1,7 @@
from app.asr_models.asr_model import ASRModel
from app.asr_models.faster_whisper_engine import FasterWhisperASR
from app.asr_models.openai_whisper_engine import OpenAIWhisperASR
+from app.asr_models.mbain_whisperx_engine import WhisperXASR
from app.config import CONFIG
@@ -11,5 +12,7 @@ class ASRModelFactory:
return OpenAIWhisperASR()
elif CONFIG.ASR_ENGINE == "faster_whisper":
return FasterWhisperASR()
+ elif CONFIG.ASR_ENGINE == "whisperx":
+ return WhisperXASR()
else:
raise ValueError(f"Unsupported ASR engine: {CONFIG.ASR_ENGINE}")
diff --git a/app/utils.py b/app/utils.py
index 154dbbb0129e1767fda9f337cbb1e8fe91c4b6e4..31c52fad18a4c9fa4ecadfae2497eb233b95b752 100644
--- a/app/utils.py
+++ b/app/utils.py
@@ -1,7 +1,7 @@
import json
import os
-from dataclasses import asdict
-from typing import TextIO, BinaryIO
+from dataclasses import asdict, is_dataclass
+from typing import TextIO, BinaryIO, Union
import ffmpeg
import numpy as np
@@ -23,14 +23,42 @@ class ResultWriter:
with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f)
- def write_result(self, result: dict, file: TextIO):
+ def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
raise NotImplementedError
+
+ def format_segments_in_result(self, result: dict):
+ if "segments" in result:
+ # Check if result["segments"] is a list
+ if isinstance(result["segments"], list):
+ # Check if the list is empty
+ if not result["segments"]:
+ # Handle the empty list case, you can choose to leave it as is or set it to an empty list
+ pass
+ else:
+ # Check if the first item in the list is a dataclass instance
+ if is_dataclass(result["segments"][0]):
+ result["segments"] = [asdict(segment) for segment in result["segments"]]
+ # If it's already a list of dicts, leave it as is
+ elif isinstance(result["segments"][0], dict):
+ pass
+ else:
+ # Handle the case where the list contains neither dataclass instances nor dicts
+ # You can choose to leave it as is or raise an error
+ pass
+ elif isinstance(result["segments"], dict):
+ # If it's already a dict, leave it as is
+ pass
+ else:
+ # Handle the case where result["segments"] is neither a list nor a dict
+ # You can choose to leave it as is or raise an error
+ pass
+ return result
class WriteTXT(ResultWriter):
extension: str = "txt"
- def write_result(self, result: dict, file: TextIO):
+ def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
for segment in result["segments"]:
print(segment.text.strip(), file=file, flush=True)
@@ -38,12 +66,13 @@ class WriteTXT(ResultWriter):
class WriteVTT(ResultWriter):
extension: str = "vtt"
- def write_result(self, result: dict, file: TextIO):
+ def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
print("WEBVTT\n", file=file)
+ result = self.format_segments_in_result(result)
for segment in result["segments"]:
print(
- f"{format_timestamp(segment.start)} --> {format_timestamp(segment.end)}\n"
- f"{segment.text.strip().replace('-->', '->')}\n",
+ f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
+ f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
@@ -52,14 +81,15 @@ class WriteVTT(ResultWriter):
class WriteSRT(ResultWriter):
extension: str = "srt"
- def write_result(self, result: dict, file: TextIO):
+ def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
+ result = self.format_segments_in_result(result)
for i, segment in enumerate(result["segments"], start=1):
# write srt lines
print(
f"{i}\n"
- f"{format_timestamp(segment.start, always_include_hours=True, decimal_marker=',')} --> "
- f"{format_timestamp(segment.end, always_include_hours=True, decimal_marker=',')}\n"
- f"{segment.text.strip().replace('-->', '->')}\n",
+ f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
+ f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
+ f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
@@ -77,20 +107,20 @@ class WriteTSV(ResultWriter):
extension: str = "tsv"
- def write_result(self, result: dict, file: TextIO):
+ def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
+ result = self.format_segments_in_result(result)
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
- print(round(1000 * segment.start), file=file, end="\t")
- print(round(1000 * segment.end), file=file, end="\t")
- print(segment.text.strip().replace("\t", " "), file=file, flush=True)
+ print(round(1000 * segment["start"]), file=file, end="\t")
+ print(round(1000 * segment["end"]), file=file, end="\t")
+ print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
class WriteJSON(ResultWriter):
extension: str = "json"
- def write_result(self, result: dict, file: TextIO):
- if "segments" in result:
- result["segments"] = [asdict(segment) for segment in result["segments"]]
+ def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
+ result = self.format_segments_in_result(result)
json.dump(result, file)
diff --git a/app/webservice.py b/app/webservice.py
index 7158645a467702a07baf8b997ce825712863ccf1..8f4fa6a3f65dc714b9dad4a122d8a0dadc816b3d 100644
--- a/app/webservice.py
+++ b/app/webservice.py
@@ -35,7 +35,6 @@ assets_path = os.getcwd() + "/swagger-ui-assets"
if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/swagger-ui-bundle.js"):
app.mount("/assets", StaticFiles(directory=assets_path), name="static")
-
def swagger_monkey_patch(*args, **kwargs):
return get_swagger_ui_html(
*args,
@@ -45,7 +44,6 @@ if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/
swagger_js_url="/assets/swagger-ui-bundle.js",
)
-
applications.get_swagger_ui_html = swagger_monkey_patch
@@ -56,23 +54,49 @@ async def index():
@app.post("/asr", tags=["Endpoints"])
async def asr(
- audio_file: UploadFile = File(...), # noqa: B008
- encode: bool = Query(default=True, description="Encode audio first through ffmpeg"),
- task: Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]),
- language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES),
- initial_prompt: Union[str, None] = Query(default=None),
- vad_filter: Annotated[
- bool | None,
- Query(
- description="Enable the voice activity detection (VAD) to filter out parts of the audio without speech",
- include_in_schema=(True if CONFIG.ASR_ENGINE == "faster_whisper" else False),
- ),
- ] = False,
- word_timestamps: bool = Query(default=False, description="Word level timestamps"),
- output: Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"]),
+ audio_file: UploadFile = File(...), # noqa: B008
+ encode: bool = Query(default=True, description="Encode audio first through ffmpeg"),
+ task: Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]),
+ language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES),
+ initial_prompt: Union[str, None] = Query(default=None),
+ vad_filter: Annotated[
+ bool | None,
+ Query(
+ description="Enable the voice activity detection (VAD) to filter out parts of the audio without speech",
+ include_in_schema=(True if CONFIG.ASR_ENGINE == "faster_whisper" else False),
+ ),
+ ] = False,
+ word_timestamps: bool = Query(
+ default=False,
+ description="Word level timestamps",
+ include_in_schema=(True if CONFIG.ASR_ENGINE == "faster_whisper" else False),
+ ),
+ diarize: bool = Query(
+ default=False,
+ description="Diarize the input",
+ include_in_schema=(True if CONFIG.ASR_ENGINE == "whisperx" and CONFIG.HF_TOKEN != "" else False),
+ ),
+ min_speakers: Union[int, None] = Query(
+ default=None,
+ description="Min speakers in this file",
+ include_in_schema=(True if CONFIG.ASR_ENGINE == "whisperx" else False),
+ ),
+ max_speakers: Union[int, None] = Query(
+ default=None,
+ description="Max speakers in this file",
+ include_in_schema=(True if CONFIG.ASR_ENGINE == "whisperx" else False),
+ ),
+ output: Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"]),
):
result = asr_model.transcribe(
- load_audio(audio_file.file, encode), task, language, initial_prompt, vad_filter, word_timestamps, output
+ load_audio(audio_file.file, encode),
+ task,
+ language,
+ initial_prompt,
+ vad_filter,
+ word_timestamps,
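+ # engine-specific options (currently read by the whisperX engine for diarization)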
+ {"diarize": diarize, "min_speakers": min_speakers, "max_speakers": max_speakers},
+ output,
)
return StreamingResponse(
result,
@@ -86,12 +110,15 @@ async def asr(
@app.post("/detect-language", tags=["Endpoints"])
async def detect_language(
- audio_file: UploadFile = File(...), # noqa: B008
- encode: bool = Query(default=True, description="Encode audio first through FFmpeg"),
+ audio_file: UploadFile = File(...), # noqa: B008
+ encode: bool = Query(default=True, description="Encode audio first through FFmpeg"),
):
detected_lang_code, confidence = asr_model.language_detection(load_audio(audio_file.file, encode))
- return {"detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code": detected_lang_code,
- "confidence": confidence}
+ return {
+ "detected_language": tokenizer.LANGUAGES[detected_lang_code],
+ "language_code": detected_lang_code,
+ "confidence": confidence,
+ }
@click.command()
@@ -110,10 +137,7 @@ async def detect_language(
help="Port for the webservice (default: 9000)",
)
@click.version_option(version=projectMetadata["Version"])
-def start(
- host: str,
- port: Optional[int] = None
-):
+def start(host: str, port: Optional[int] = None):
uvicorn.run(app, host=host, port=port)