From 0cd35b893d7beb2f715966d2bdf8fe3db40b72aa Mon Sep 17 00:00:00 2001
From: charnesp <charles.nespoulous@gmail.com>
Date: Sun, 16 Feb 2025 18:33:31 +0100
Subject: [PATCH] Integrate Whisperx (#267)

* Initial commit for whisperx

* whisperx working

* Correct Dockerfile

* Add possibility to prioritize self-hosted runners

---------

Co-authored-by: Charles N <charles@chouette.vision>
---
 .devcontainer/devcontainer.json         |  44 ++++++++++
 .devcontainer/docker-compose.yml        |  30 +++++++
 .github/workflows/docker-publish.yml    |   6 +-
 .gitignore                              |   4 +-
 Dockerfile                              |  13 +++
 Dockerfile.gpu                          |  12 +++
 README.md                               |   1 +
 app/asr_models/asr_model.py             |  27 ++++--
 app/asr_models/faster_whisper_engine.py |   1 +
 app/asr_models/mbain_whisperx_engine.py | 111 ++++++++++++++++++++++++
 app/asr_models/openai_whisper_engine.py |  33 ++++---
 app/config.py                           |   3 +
 app/factory/asr_model_factory.py        |   3 +
 app/utils.py                            |  66 ++++++++++----
 app/webservice.py                       |  74 ++++++++++------
 15 files changed, 356 insertions(+), 72 deletions(-)
 create mode 100644 .devcontainer/devcontainer.json
 create mode 100644 .devcontainer/docker-compose.yml
 create mode 100644 app/asr_models/mbain_whisperx_engine.py

diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
new file mode 100644
index 0000000..36e62bd
--- /dev/null
+++ b/.devcontainer/devcontainer.json
@@ -0,0 +1,44 @@
+// For format details, see https://aka.ms/devcontainer.json. For config options, see the
+// README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-docker-compose
+{
+	"name": "Existing Docker Compose (Extend)",
+
+	// Update the 'dockerComposeFile' list if you have more compose files or use different names.
+	// The .devcontainer/docker-compose.yml file contains any overrides you need/want to make.
+	"dockerComposeFile": [
+		"../docker-compose.yml",
+		"docker-compose.yml"
+	],
+
+	// The 'service' property is the name of the service for the container that VS Code should
+	// use. Update this value and .devcontainer/docker-compose.yml to the real service name.
+	"service": "whisper-asr-webservice",
+
+	// The optional 'workspaceFolder' property is the path VS Code should open by default when
+	// connected. This is typically a file mount in .devcontainer/docker-compose.yml
+	"workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",
+
+	// "overrideCommand": "/bin/sh -c 'while sleep 1000; do :; done'"
+	"overrideCommand": true
+
+	// Features to add to the dev container. More info: https://containers.dev/features.
+	// "features": {},
+
+	// Use 'forwardPorts' to make a list of ports inside the container available locally.
+	// "forwardPorts": [],
+
+	// Uncomment the next line if you want to start specific services in your Docker Compose config.
+	// "runServices": [],
+
+	// Uncomment the next line if you want to keep your containers running after VS Code shuts down.
+	// "shutdownAction": "none",
+
+	// Uncomment the next line to run commands after the container is created.
+	// "postCreateCommand": "cat /etc/os-release",
+
+	// Configure tool-specific properties.
+	// "customizations": {},
+
+	// Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root.
+	// "remoteUser": "devcontainer"
+}
diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml
new file mode 100644
index 0000000..c3e29c5
--- /dev/null
+++ b/.devcontainer/docker-compose.yml
@@ -0,0 +1,30 @@
+version: '3.4'
+services:
+  # Update this to the name of the service you want to work with in your docker-compose.yml file
+  whisper-asr-webservice:
+    # Uncomment if you want to override the service's Dockerfile to one in the .devcontainer 
+    # folder. Note that the path of the Dockerfile and context is relative to the *primary* 
+    # docker-compose.yml file (the first in the devcontainer.json "dockerComposeFile"
+    # array). The sample below assumes your primary file is in the root of your project.
+    #
+    # build:
+    #   context: .
+    #   dockerfile: .devcontainer/Dockerfile
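+    # dev.env is gitignored; it supplies ASR_ENGINE and HF_TOKEN for local development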
+    env_file: .devcontainer/dev.env
+    environment:
+      ASR_ENGINE: ${ASR_ENGINE}
+      HF_TOKEN: ${HF_TOKEN}
+
+    volumes:
+      # Update this to wherever you want VS Code to mount the folder of your project
+      - ..:/workspaces:cached
+
+    # Uncomment the next four lines if you will use a ptrace-based debugger, e.g. for C++, Go, or Rust.
+    # cap_add:
+    #   - SYS_PTRACE
+    # security_opt:
+    #   - seccomp:unconfined
+
+    # Overrides default command so things don't shut down after the process ends.
+    command: sleep infinity
+ 
diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml
index 571d2b8..574ea12 100644
--- a/.github/workflows/docker-publish.yml
+++ b/.github/workflows/docker-publish.yml
@@ -12,7 +12,7 @@ env:
   REPO_NAME: ${{secrets.REPO_NAME}}
 jobs:
   build:
-    runs-on: ubuntu-latest
+    runs-on: [self-hosted, ubuntu-latest]
     strategy:
       matrix:
         include:
@@ -22,6 +22,10 @@ jobs:
             tag_extension: -gpu
             platforms: linux/amd64
     steps:
+    - name: Remove unnecessary files
+      run: |
+          sudo rm -rf /usr/share/dotnet
+          sudo rm -rf "$AGENT_TOOLSDIRECTORY"
     - name: Checkout
       uses: actions/checkout@v3
     - name: Set up QEMU
diff --git a/.gitignore b/.gitignore
index 4dbf939..3a43ec4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -41,4 +41,6 @@ pip-wheel-metadata
 
 poetry/core/*
 
-public
\ No newline at end of file
+public
+
+.devcontainer/dev.env
diff --git a/Dockerfile b/Dockerfile
index bdae06f..08cfa3a 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -8,6 +8,8 @@ RUN export DEBIAN_FRONTEND=noninteractive \
     pkg-config \
     yasm \
     ca-certificates \
+    gcc \
+    python3-dev \
     && rm -rf /var/lib/apt/lists/*
 
 RUN git clone https://github.com/FFmpeg/FFmpeg.git --depth 1 --branch n6.1.1 --single-branch /FFmpeg-6.1.1
@@ -42,6 +44,12 @@ FROM swaggerapi/swagger-ui:v5.9.1 AS swagger-ui
 
 FROM python:3.10-bookworm
 
+RUN export DEBIAN_FRONTEND=noninteractive \
+    && apt-get -qq update \
+    && apt-get -qq install --no-install-recommends \
+    libsndfile1 \
+    && rm -rf /var/lib/apt/lists/*
+
 ENV POETRY_VENV=/app/.venv
 
 RUN python3 -m venv $POETRY_VENV \
@@ -61,6 +69,11 @@ COPY --from=swagger-ui /usr/share/nginx/html/swagger-ui-bundle.js swagger-ui-ass
 RUN poetry config virtualenvs.in-project true
 RUN poetry install
 
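+# Extra Python dependencies for whisperX; pyannote.audio provides speaker diarization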
+RUN $POETRY_VENV/bin/pip install pandas transformers nltk pyannote.audio
+RUN git clone --depth 1 https://github.com/m-bain/whisperX.git \
+    && cd whisperX \
+    && $POETRY_VENV/bin/pip install -e .
+
 EXPOSE 9000
 
 ENTRYPOINT ["whisper-asr-webservice"]
diff --git a/Dockerfile.gpu b/Dockerfile.gpu
index b605707..709681f 100644
--- a/Dockerfile.gpu
+++ b/Dockerfile.gpu
@@ -43,6 +43,13 @@ FROM swaggerapi/swagger-ui:v5.9.1 AS swagger-ui
 FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
 
 ENV PYTHON_VERSION=3.10
+
+RUN export DEBIAN_FRONTEND=noninteractive \
+    && apt-get -qq update \
+    && apt-get -qq install --no-install-recommends \
+    libsndfile1 \
+    && rm -rf /var/lib/apt/lists/*
+
 ENV POETRY_VENV=/app/.venv
 
 RUN export DEBIAN_FRONTEND=noninteractive \
@@ -79,6 +86,11 @@ COPY --from=swagger-ui /usr/share/nginx/html/swagger-ui-bundle.js swagger-ui-ass
 RUN poetry install
 RUN $POETRY_VENV/bin/pip install torch==1.13.1+cu117 -f https://download.pytorch.org/whl/torch
 
+RUN $POETRY_VENV/bin/pip install pandas transformers nltk pyannote.audio
+RUN git clone --depth 1 https://github.com/m-bain/whisperX.git \
+    && cd whisperX \
+    && $POETRY_VENV/bin/pip install -e .
+
 EXPOSE 9000
 
 CMD whisper-asr-webservice
diff --git a/README.md b/README.md
index b521283..21dcb00 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,7 @@ Current release (v1.7.1) supports following whisper models:
 
 - [openai/whisper](https://github.com/openai/whisper)@[v20240930](https://github.com/openai/whisper/releases/tag/v20240930)
 - [SYSTRAN/faster-whisper](https://github.com/SYSTRAN/faster-whisper)@[v1.1.0](https://github.com/SYSTRAN/faster-whisper/releases/tag/v1.1.0)
+- [whisperX](https://github.com/m-bain/whisperX)@[v3.1.1](https://github.com/m-bain/whisperX/releases/tag/v3.1.1)
 
 ## Quick Usage
 
diff --git a/app/asr_models/asr_model.py b/app/asr_models/asr_model.py
index fc22050..8ec9890 100644
--- a/app/asr_models/asr_model.py
+++ b/app/asr_models/asr_model.py
@@ -13,7 +13,10 @@ class ASRModel(ABC):
     """
     Abstract base class for ASR (Automatic Speech Recognition) models.
     """
+
     model = None
+    diarize_model = None  # used for WhisperX
+    x_models = dict()  # used for WhisperX
     model_lock = Lock()
     last_activity_time = time.time()
 
@@ -28,14 +31,17 @@ class ASRModel(ABC):
         pass
 
     @abstractmethod
-    def transcribe(self,
-                   audio,
-                   task: Union[str, None],
-                   language: Union[str, None],
-                   initial_prompt: Union[str, None],
-                   vad_filter: Union[bool, None],
-                   word_timestamps: Union[bool, None]
-                   ):
+    def transcribe(
+        self,
+        audio,
+        task: Union[str, None],
+        language: Union[str, None],
+        initial_prompt: Union[str, None],
+        vad_filter: Union[bool, None],
+        word_timestamps: Union[bool, None],
+        options: Union[dict, None],
+        output,
+    ):
         """
         Perform transcription on the given audio file.
         """
@@ -52,7 +58,8 @@ class ASRModel(ABC):
         """
         Monitors the idleness of the ASR model and releases the model if it has been idle for too long.
         """
-        if CONFIG.MODEL_IDLE_TIMEOUT <= 0: return
+        if CONFIG.MODEL_IDLE_TIMEOUT <= 0:
+            return
         while True:
             time.sleep(15)
             if time.time() - self.last_activity_time > CONFIG.MODEL_IDLE_TIMEOUT:
@@ -68,4 +75,6 @@ class ASRModel(ABC):
         torch.cuda.empty_cache()
         gc.collect()
         self.model = None
+        self.diarize_model = None
+        self.x_models = dict()
         print("Model unloaded due to timeout")
diff --git a/app/asr_models/faster_whisper_engine.py b/app/asr_models/faster_whisper_engine.py
index a0b8620..3cf21f8 100644
--- a/app/asr_models/faster_whisper_engine.py
+++ b/app/asr_models/faster_whisper_engine.py
@@ -32,6 +32,7 @@ class FasterWhisperASR(ASRModel):
             initial_prompt: Union[str, None],
             vad_filter: Union[bool, None],
             word_timestamps: Union[bool, None],
+            options: Union[dict, None],
             output,
     ):
         self.last_activity_time = time.time()
diff --git a/app/asr_models/mbain_whisperx_engine.py b/app/asr_models/mbain_whisperx_engine.py
new file mode 100644
index 0000000..0e104bf
--- /dev/null
+++ b/app/asr_models/mbain_whisperx_engine.py
@@ -0,0 +1,111 @@
+from typing import BinaryIO, Union
+from io import StringIO
+import whisperx
+import whisper
+from whisperx.utils import SubtitlesWriter, ResultWriter
+
+from app.asr_models.asr_model import ASRModel
+from app.config import CONFIG
+from app.utils import WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON
+
+
+class WhisperXASR(ASRModel):
+    def __init__(self):
+        self.x_models = dict()
+
+    def load_model(self):
+
+        asr_options = {"without_timestamps": False}
+        self.model = whisperx.load_model(
+            CONFIG.MODEL_NAME, device=CONFIG.DEVICE, compute_type="float32", asr_options=asr_options
+        )
+
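+        # Diarization relies on gated pyannote models on Hugging Face, so the pipeline is only loaded when a token is set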
+        if CONFIG.HF_TOKEN != "":
+            self.diarize_model = whisperx.DiarizationPipeline(use_auth_token=CONFIG.HF_TOKEN, device=CONFIG.DEVICE)
+
+    def transcribe(
+        self,
+        audio,
+        task: Union[str, None],
+        language: Union[str, None],
+        initial_prompt: Union[str, None],
+        vad_filter: Union[bool, None],
+        word_timestamps: Union[bool, None],
+        options: Union[dict, None],
+        output,
+    ):
+        options_dict = {"task": task}
+        if language:
+            options_dict["language"] = language
+        if initial_prompt:
+            options_dict["initial_prompt"] = initial_prompt
+        with self.model_lock:
+            if self.model is None:
+                self.load_model()
+            result = self.model.transcribe(audio, **options_dict)
+
+        # Load the required model and cache it
+        # If we transcribe audio in many different languages, this may lead to OOM problems
+        if result["language"] in self.x_models:
+            model_x, metadata = self.x_models[result["language"]]
+        else:
+            self.x_models[result["language"]] = whisperx.load_align_model(
+                language_code=result["language"], device=CONFIG.DEVICE
+            )
+            model_x, metadata = self.x_models[result["language"]]
+
+        # Align whisper output
+        result = whisperx.align(
+            result["segments"], model_x, metadata, audio, CONFIG.DEVICE, return_char_alignments=False
+        )
+
+        if options.get("diarize", False):
+            if CONFIG.HF_TOKEN == "":
+                print("Warning! HF_TOKEN is not set. Diarization may not work as expected.")
+            min_speakers = options.get("min_speakers", None)
+            max_speakers = options.get("max_speakers", None)
+            # add min/max number of speakers if known
+            diarize_segments = self.diarize_model(audio, min_speakers, max_speakers)
+            result = whisperx.assign_word_speakers(diarize_segments, result)
+
+        output_file = StringIO()
+        self.write_result(result, output_file, output)
+        output_file.seek(0)
+
+        return output_file
+
+    def language_detection(self, audio):
+        # load audio and pad/trim it to fit 30 seconds
+        audio = whisper.pad_or_trim(audio)
+
+        # make log-Mel spectrogram and move to the same device as the model
+        mel = whisper.log_mel_spectrogram(audio).to(self.model.device)
+
+        # detect the spoken language
+        with self.model_lock:
+            if self.model is None:
+                self.load_model()
+            _, probs = self.model.detect_language(mel)
+        detected_lang_code = max(probs, key=probs.get)
+
+        return detected_lang_code
+
+    def write_result(self, result: dict, file: BinaryIO, output: Union[str, None]):
+        if output == "srt":
+            if CONFIG.HF_TOKEN != "":
+                WriteSRT(SubtitlesWriter).write_result(result, file=file, options={})
+            else:
+                WriteSRT(ResultWriter).write_result(result, file=file, options={})
+        elif output == "vtt":
+            if CONFIG.HF_TOKEN != "":
+                WriteVTT(SubtitlesWriter).write_result(result, file=file, options={})
+            else:
+                WriteVTT(ResultWriter).write_result(result, file=file, options={})
+        elif output == "tsv":
+            WriteTSV(ResultWriter).write_result(result, file=file, options={})
+        elif output == "json":
+            WriteJSON(ResultWriter).write_result(result, file=file, options={})
+        elif output == "txt":
+            WriteTXT(ResultWriter).write_result(result, file=file, options={})
+        else:
+            return 'Please select an output method!'
diff --git a/app/asr_models/openai_whisper_engine.py b/app/asr_models/openai_whisper_engine.py
index c6c37dc..205aff8 100644
--- a/app/asr_models/openai_whisper_engine.py
+++ b/app/asr_models/openai_whisper_engine.py
@@ -16,32 +16,28 @@ class OpenAIWhisperASR(ASRModel):
     def load_model(self):
 
         if torch.cuda.is_available():
-            self.model = whisper.load_model(
-                name=CONFIG.MODEL_NAME,
-                download_root=CONFIG.MODEL_PATH
-            ).cuda()
+            self.model = whisper.load_model(name=CONFIG.MODEL_NAME, download_root=CONFIG.MODEL_PATH).cuda()
         else:
-            self.model = whisper.load_model(
-                name=CONFIG.MODEL_NAME,
-                download_root=CONFIG.MODEL_PATH
-            )
+            self.model = whisper.load_model(name=CONFIG.MODEL_NAME, download_root=CONFIG.MODEL_PATH)
 
         Thread(target=self.monitor_idleness, daemon=True).start()
 
     def transcribe(
-            self,
-            audio,
-            task: Union[str, None],
-            language: Union[str, None],
-            initial_prompt: Union[str, None],
-            vad_filter: Union[bool, None],
-            word_timestamps: Union[bool, None],
-            output,
+        self,
+        audio,
+        task: Union[str, None],
+        language: Union[str, None],
+        initial_prompt: Union[str, None],
+        vad_filter: Union[bool, None],
+        word_timestamps: Union[bool, None],
+        options: Union[dict, None],
+        output,
     ):
         self.last_activity_time = time.time()
 
         with self.model_lock:
-            if self.model is None: self.load_model()
+            if self.model is None:
+                self.load_model()
 
         options_dict = {"task": task}
         if language:
@@ -64,7 +60,8 @@ class OpenAIWhisperASR(ASRModel):
         self.last_activity_time = time.time()
 
         with self.model_lock:
-            if self.model is None: self.load_model()
+            if self.model is None:
+                self.load_model()
 
         # load audio and pad/trim it to fit 30 seconds
         audio = whisper.pad_or_trim(audio)
diff --git a/app/config.py b/app/config.py
index 6420269..31779aa 100644
--- a/app/config.py
+++ b/app/config.py
@@ -11,6 +11,9 @@ class CONFIG:
     # Determine the ASR engine ('faster_whisper' or 'openai_whisper')
     ASR_ENGINE = os.getenv("ASR_ENGINE", "openai_whisper")
 
+    # Hugging Face token (used by the whisperx engine for speaker diarization)
+    HF_TOKEN = os.getenv("HF_TOKEN", "")
+
     # Determine the computation device (GPU or CPU)
     DEVICE = os.getenv("ASR_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
 
diff --git a/app/factory/asr_model_factory.py b/app/factory/asr_model_factory.py
index 0f97fb5..10588dd 100644
--- a/app/factory/asr_model_factory.py
+++ b/app/factory/asr_model_factory.py
@@ -1,6 +1,7 @@
 from app.asr_models.asr_model import ASRModel
 from app.asr_models.faster_whisper_engine import FasterWhisperASR
 from app.asr_models.openai_whisper_engine import OpenAIWhisperASR
+from app.asr_models.mbain_whisperx_engine import WhisperXASR
 from app.config import CONFIG
 
 
@@ -11,5 +12,7 @@ class ASRModelFactory:
             return OpenAIWhisperASR()
         elif CONFIG.ASR_ENGINE == "faster_whisper":
             return FasterWhisperASR()
+        elif CONFIG.ASR_ENGINE == "whisperx":
+            return WhisperXASR()
         else:
             raise ValueError(f"Unsupported ASR engine: {CONFIG.ASR_ENGINE}")
diff --git a/app/utils.py b/app/utils.py
index 154dbbb..31c52fa 100644
--- a/app/utils.py
+++ b/app/utils.py
@@ -1,7 +1,7 @@
 import json
 import os
-from dataclasses import asdict
-from typing import TextIO, BinaryIO
+from dataclasses import asdict, is_dataclass
+from typing import TextIO, BinaryIO, Union
 
 import ffmpeg
 import numpy as np
@@ -23,14 +23,42 @@ class ResultWriter:
         with open(output_path, "w", encoding="utf-8") as f:
             self.write_result(result, file=f)
 
-    def write_result(self, result: dict, file: TextIO):
+    def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
         raise NotImplementedError
+
+    def format_segments_in_result(self, result: dict):
+        # Normalize result["segments"] to a list of plain dicts so the writers can use key access.
+        if "segments" in result:
+            if isinstance(result["segments"], list) and result["segments"]:
+                # Convert dataclass segment objects to dicts; a list of dicts is already in the right shape.
+                if is_dataclass(result["segments"][0]):
+                    result["segments"] = [asdict(segment) for segment in result["segments"]]
+            # Empty lists, dicts, and any other shapes are left untouched.
+        return result
 
 
 class WriteTXT(ResultWriter):
     extension: str = "txt"
 
-    def write_result(self, result: dict, file: TextIO):
+    def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
         for segment in result["segments"]:
             print(segment.text.strip(), file=file, flush=True)
 
@@ -38,12 +66,13 @@ class WriteTXT(ResultWriter):
 class WriteVTT(ResultWriter):
     extension: str = "vtt"
 
-    def write_result(self, result: dict, file: TextIO):
+    def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
         print("WEBVTT\n", file=file)
+        result = self.format_segments_in_result(result)
         for segment in result["segments"]:
             print(
-                f"{format_timestamp(segment.start)} --> {format_timestamp(segment.end)}\n"
-                f"{segment.text.strip().replace('-->', '->')}\n",
+                f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
+                f"{segment['text'].strip().replace('-->', '->')}\n",
                 file=file,
                 flush=True,
             )
@@ -52,14 +81,15 @@ class WriteVTT(ResultWriter):
 class WriteSRT(ResultWriter):
     extension: str = "srt"
 
-    def write_result(self, result: dict, file: TextIO):
+    def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
+        result = self.format_segments_in_result(result)
         for i, segment in enumerate(result["segments"], start=1):
             # write srt lines
             print(
                 f"{i}\n"
-                f"{format_timestamp(segment.start, always_include_hours=True, decimal_marker=',')} --> "
-                f"{format_timestamp(segment.end, always_include_hours=True, decimal_marker=',')}\n"
-                f"{segment.text.strip().replace('-->', '->')}\n",
+                f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
+                f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
+                f"{segment['text'].strip().replace('-->', '->')}\n",
                 file=file,
                 flush=True,
             )
@@ -77,20 +107,20 @@ class WriteTSV(ResultWriter):
 
     extension: str = "tsv"
 
-    def write_result(self, result: dict, file: TextIO):
+    def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
+        result = self.format_segments_in_result(result)
         print("start", "end", "text", sep="\t", file=file)
         for segment in result["segments"]:
-            print(round(1000 * segment.start), file=file, end="\t")
-            print(round(1000 * segment.end), file=file, end="\t")
-            print(segment.text.strip().replace("\t", " "), file=file, flush=True)
+            print(round(1000 * segment["start"]), file=file, end="\t")
+            print(round(1000 * segment["end"]), file=file, end="\t")
+            print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
 
 
 class WriteJSON(ResultWriter):
     extension: str = "json"
 
-    def write_result(self, result: dict, file: TextIO):
-        if "segments" in result:
-            result["segments"] = [asdict(segment) for segment in result["segments"]]
+    def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
+        result = self.format_segments_in_result(result) 
         json.dump(result, file)
 
 
diff --git a/app/webservice.py b/app/webservice.py
index 7158645..8f4fa6a 100644
--- a/app/webservice.py
+++ b/app/webservice.py
@@ -35,7 +35,6 @@ assets_path = os.getcwd() + "/swagger-ui-assets"
 if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/swagger-ui-bundle.js"):
     app.mount("/assets", StaticFiles(directory=assets_path), name="static")
 
-
     def swagger_monkey_patch(*args, **kwargs):
         return get_swagger_ui_html(
             *args,
@@ -45,7 +44,6 @@ if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/
             swagger_js_url="/assets/swagger-ui-bundle.js",
         )
 
-
     applications.get_swagger_ui_html = swagger_monkey_patch
 
 
@@ -56,23 +54,49 @@ async def index():
 
 @app.post("/asr", tags=["Endpoints"])
 async def asr(
-        audio_file: UploadFile = File(...),  # noqa: B008
-        encode: bool = Query(default=True, description="Encode audio first through ffmpeg"),
-        task: Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]),
-        language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES),
-        initial_prompt: Union[str, None] = Query(default=None),
-        vad_filter: Annotated[
-            bool | None,
-            Query(
-                description="Enable the voice activity detection (VAD) to filter out parts of the audio without speech",
-                include_in_schema=(True if CONFIG.ASR_ENGINE == "faster_whisper" else False),
-            ),
-        ] = False,
-        word_timestamps: bool = Query(default=False, description="Word level timestamps"),
-        output: Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"]),
+    audio_file: UploadFile = File(...),  # noqa: B008
+    encode: bool = Query(default=True, description="Encode audio first through ffmpeg"),
+    task: Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]),
+    language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES),
+    initial_prompt: Union[str, None] = Query(default=None),
+    vad_filter: Annotated[
+        bool | None,
+        Query(
+            description="Enable the voice activity detection (VAD) to filter out parts of the audio without speech",
+            include_in_schema=(True if CONFIG.ASR_ENGINE == "faster_whisper" else False),
+        ),
+    ] = False,
+    word_timestamps: bool = Query(
+        default=False,
+        description="Word level timestamps",
+        include_in_schema=(True if CONFIG.ASR_ENGINE == "faster_whisper" else False),
+    ),
+    diarize: bool = Query(
+        default=False,
+        description="Diarize the input",
+        include_in_schema=(True if CONFIG.ASR_ENGINE == "whisperx" and CONFIG.HF_TOKEN != "" else False),
+    ),
+    min_speakers: Union[int, None] = Query(
+        default=None,
+        description="Min speakers in this file",
+        include_in_schema=(True if CONFIG.ASR_ENGINE == "whisperx" else False),
+    ),
+    max_speakers: Union[int, None] = Query(
+        default=None,
+        description="Max speakers in this file",
+        include_in_schema=(True if CONFIG.ASR_ENGINE == "whisperx" else False),
+    ),
+    output: Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"]),
 ):
     result = asr_model.transcribe(
-        load_audio(audio_file.file, encode), task, language, initial_prompt, vad_filter, word_timestamps, output
+        load_audio(audio_file.file, encode),
+        task,
+        language,
+        initial_prompt,
+        vad_filter,
+        word_timestamps,
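+        # diarization options consumed by the whisperx engine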
+        {"diarize": diarize, "min_speakers": min_speakers, "max_speakers": max_speakers},
+        output,
     )
     return StreamingResponse(
         result,
@@ -86,12 +110,15 @@ async def asr(
 
 @app.post("/detect-language", tags=["Endpoints"])
 async def detect_language(
-        audio_file: UploadFile = File(...),  # noqa: B008
-        encode: bool = Query(default=True, description="Encode audio first through FFmpeg"),
+    audio_file: UploadFile = File(...),  # noqa: B008
+    encode: bool = Query(default=True, description="Encode audio first through FFmpeg"),
 ):
     detected_lang_code, confidence = asr_model.language_detection(load_audio(audio_file.file, encode))
-    return {"detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code": detected_lang_code,
-            "confidence": confidence}
+    return {
+        "detected_language": tokenizer.LANGUAGES[detected_lang_code],
+        "language_code": detected_lang_code,
+        "confidence": confidence,
+    }
 
 
 @click.command()
@@ -110,10 +137,7 @@ async def detect_language(
     help="Port for the webservice (default: 9000)",
 )
 @click.version_option(version=projectMetadata["Version"])
-def start(
-        host: str,
-        port: Optional[int] = None
-):
+def start(host: str, port: Optional[int] = None):
     uvicorn.run(app, host=host, port=port)
 
 
-- 
GitLab