Unverified commit 0cd35b89 authored by charnesp, committed by GitHub

Integrate Whisperx (#267)


* Initial commit for whisperx

* whisperx working

* Correct Dockerfile

* Add possibility to prioritize self-hosted runners

---------

Co-authored-by: Charles N <charles@chouette.vision>
parent f2f39fe1
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-docker-compose
{
"name": "Existing Docker Compose (Extend)",
// Update the 'dockerComposeFile' list if you have more compose files or use different names.
// The .devcontainer/docker-compose.yml file contains any overrides you need/want to make.
"dockerComposeFile": [
"../docker-compose.yml",
"docker-compose.yml"
],
// The 'service' property is the name of the service for the container that VS Code should
// use. Update this value and .devcontainer/docker-compose.yml to the real service name.
"service": "whisper-asr-webservice",
// The optional 'workspaceFolder' property is the path VS Code should open by default when
// connected. This is typically a file mount in .devcontainer/docker-compose.yml
"workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",
// "overrideCommand": "/bin/sh -c 'while sleep 1000; do :; done'"
"overrideCommand": true
// Features to add to the dev container. More info: https://containers.dev/features.
// "features": {},
// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],
// Uncomment the next line if you want to start specific services in your Docker Compose config.
// "runServices": [],
// Uncomment the next line if you want to keep your containers running after VS Code shuts down.
// "shutdownAction": "none",
// Uncomment the next line to run commands after the container is created.
// "postCreateCommand": "cat /etc/os-release",
// Configure tool-specific properties.
// "customizations": {},
// Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "devcontainer"
}
version: '3.4'
services:
# Update this to the name of the service you want to work with in your docker-compose.yml file
whisper-asr-webservice:
# Uncomment if you want to override the service's Dockerfile to one in the .devcontainer
# folder. Note that the path of the Dockerfile and context is relative to the *primary*
# docker-compose.yml file (the first in the devcontainer.json "dockerComposeFile"
# array). The sample below assumes your primary file is in the root of your project.
#
# build:
# context: .
# dockerfile: .devcontainer/Dockerfile
env_file: .devcontainer/dev.env
environment:
ASR_ENGINE: ${ASR_ENGINE}
HF_TOKEN: ${HF_TOKEN}
volumes:
# Update this to wherever you want VS Code to mount the folder of your project
- ..:/workspaces:cached
# Uncomment the next four lines if you will use a ptrace-based debugger like C++, Go, and Rust.
# cap_add:
# - SYS_PTRACE
# security_opt:
# - seccomp:unconfined
# Overrides default command so things don't shut down after the process ends.
command: sleep infinity
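Note: `.devcontainer/dev.env` is git-ignored (see the `.gitignore` hunk below), so each developer is expected to create it locally; the compose override only forwards `ASR_ENGINE` and `HF_TOKEN` from it into the container.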
@@ -12,7 +12,7 @@ env:
REPO_NAME: ${{secrets.REPO_NAME}}
jobs:
build:
runs-on: ubuntu-latest
runs-on: [self-hosted, ubuntu-latest]
strategy:
matrix:
include:
@@ -22,6 +22,10 @@ jobs:
tag_extension: -gpu
platforms: linux/amd64
steps:
- name: Remove unnecessary files
run: |
sudo rm -rf /usr/share/dotnet
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
- name: Checkout
uses: actions/checkout@v3
- name: Set up QEMU
......
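Note: `runs-on: [self-hosted, ubuntu-latest]` routes the job to a self-hosted runner carrying both labels; GitHub Actions treats the label array as an AND filter, not as a priority list with a hosted fallback. The "Remove unnecessary files" step above presumably frees disk space that the large GPU image build would otherwise exhaust on a stock runner.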
@@ -42,3 +42,5 @@ pip-wheel-metadata
poetry/core/*
public
.devcontainer/dev.env
@@ -8,6 +8,8 @@ RUN export DEBIAN_FRONTEND=noninteractive \
pkg-config \
yasm \
ca-certificates \
gcc \
python3-dev \
&& rm -rf /var/lib/apt/lists/*
RUN git clone https://github.com/FFmpeg/FFmpeg.git --depth 1 --branch n6.1.1 --single-branch /FFmpeg-6.1.1
@@ -42,6 +44,12 @@ FROM swaggerapi/swagger-ui:v5.9.1 AS swagger-ui
FROM python:3.10-bookworm
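# libsndfile1 below provides the native backend for the soundfile package, which pyannote.audio relies on at runtime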
RUN export DEBIAN_FRONTEND=noninteractive \
&& apt-get -qq update \
&& apt-get -qq install --no-install-recommends \
libsndfile1 \
&& rm -rf /var/lib/apt/lists/*
ENV POETRY_VENV=/app/.venv
RUN python3 -m venv $POETRY_VENV \
@@ -61,6 +69,11 @@ COPY --from=swagger-ui /usr/share/nginx/html/swagger-ui-bundle.js swagger-ui-ass
RUN poetry config virtualenvs.in-project true
RUN poetry install
RUN $POETRY_VENV/bin/pip install pandas transformers nltk pyannote.audio
RUN git clone --depth 1 https://github.com/m-bain/whisperX.git \
&& cd whisperX \
&& $POETRY_VENV/bin/pip install -e .
EXPOSE 9000
ENTRYPOINT ["whisper-asr-webservice"]
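Note: the `git clone --depth 1` above tracks whisperX's default branch rather than a tagged release, so image builds are not actually pinned to the v3.1.1 version the README lists below.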
@@ -43,6 +43,13 @@ FROM swaggerapi/swagger-ui:v5.9.1 AS swagger-ui
FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
ENV PYTHON_VERSION=3.10
RUN export DEBIAN_FRONTEND=noninteractive \
&& apt-get -qq update \
&& apt-get -qq install --no-install-recommends \
libsndfile1 \
&& rm -rf /var/lib/apt/lists/*
ENV POETRY_VENV=/app/.venv
RUN export DEBIAN_FRONTEND=noninteractive \
@@ -79,6 +86,11 @@ COPY --from=swagger-ui /usr/share/nginx/html/swagger-ui-bundle.js swagger-ui-ass
RUN poetry install
RUN $POETRY_VENV/bin/pip install torch==1.13.1+cu117 -f https://download.pytorch.org/whl/torch
RUN $POETRY_VENV/bin/pip install pandas transformers nltk pyannote.audio
RUN git clone --depth 1 https://github.com/m-bain/whisperX.git \
&& cd whisperX \
&& $POETRY_VENV/bin/pip install -e .
EXPOSE 9000
CMD whisper-asr-webservice
@@ -13,6 +13,7 @@ Current release (v1.7.1) supports following whisper models:
- [openai/whisper](https://github.com/openai/whisper)@[v20240930](https://github.com/openai/whisper/releases/tag/v20240930)
- [SYSTRAN/faster-whisper](https://github.com/SYSTRAN/faster-whisper)@[v1.1.0](https://github.com/SYSTRAN/faster-whisper/releases/tag/v1.1.0)
- [whisperX](https://github.com/m-bain/whisperX)@[v3.1.1](https://github.com/m-bain/whisperX/releases/tag/v3.1.1)
## Quick Usage
......
@@ -13,7 +13,10 @@ class ASRModel(ABC):
"""
Abstract base class for ASR (Automatic Speech Recognition) models.
"""
model = None
diarize_model = None # used for WhisperX
x_models = dict() # used for WhisperX
model_lock = Lock()
last_activity_time = time.time()
@@ -28,13 +31,16 @@
pass
@abstractmethod
def transcribe(self,
def transcribe(
self,
audio,
task: Union[str, None],
language: Union[str, None],
initial_prompt: Union[str, None],
vad_filter: Union[bool, None],
word_timestamps: Union[bool, None]
word_timestamps: Union[bool, None],
options: Union[dict, None],
output,
):
"""
Perform transcription on the given audio file.
@@ -52,7 +58,8 @@
"""
Monitors the idleness of the ASR model and releases the model if it has been idle for too long.
"""
if CONFIG.MODEL_IDLE_TIMEOUT <= 0: return
if CONFIG.MODEL_IDLE_TIMEOUT <= 0:
return
while True:
time.sleep(15)
if time.time() - self.last_activity_time > CONFIG.MODEL_IDLE_TIMEOUT:
@@ -68,4 +75,6 @@
torch.cuda.empty_cache()
gc.collect()
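# Also drop the whisperX-specific state: the diarization pipeline and the per-language alignment model cache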
self.model = None
self.diarize_model = None
self.x_models = dict()
print("Model unloaded due to timeout")
@@ -32,6 +32,7 @@ class FasterWhisperASR(ASRModel):
initial_prompt: Union[str, None],
vad_filter: Union[bool, None],
word_timestamps: Union[bool, None],
options: Union[dict, None],
output,
):
self.last_activity_time = time.time()
......
from typing import BinaryIO, Union
from io import StringIO
import whisperx
import whisper
from whisperx.utils import SubtitlesWriter, ResultWriter
from app.asr_models.asr_model import ASRModel
from app.config import CONFIG
from app.utils import WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON
class WhisperXASR(ASRModel):
def __init__(self):
self.x_models = dict()
def load_model(self):
asr_options = {"without_timestamps": False}
self.model = whisperx.load_model(
CONFIG.MODEL_NAME, device=CONFIG.DEVICE, compute_type="float32", asr_options=asr_options
)
if CONFIG.HF_TOKEN != "":
self.diarize_model = whisperx.DiarizationPipeline(use_auth_token=CONFIG.HF_TOKEN, device=CONFIG.DEVICE)
def transcribe(
self,
audio,
task: Union[str, None],
language: Union[str, None],
initial_prompt: Union[str, None],
vad_filter: Union[bool, None],
word_timestamps: Union[bool, None],
options: Union[dict, None],
output,
):
options_dict = {"task": task}
if language:
options_dict["language"] = language
if initial_prompt:
options_dict["initial_prompt"] = initial_prompt
with self.model_lock:
if self.model is None:
self.load_model()
result = self.model.transcribe(audio, **options_dict)
# Load the alignment model for the detected language and cache it.
# If we transcribe audio in many different languages, this cache may lead to OOM problems
if result["language"] in self.x_models:
model_x, metadata = self.x_models[result["language"]]
else:
self.x_models[result["language"]] = whisperx.load_align_model(
language_code=result["language"], device=CONFIG.DEVICE
)
model_x, metadata = self.x_models[result["language"]]
# Align whisper output
result = whisperx.align(
result["segments"], model_x, metadata, audio, CONFIG.DEVICE, return_char_alignments=False
)
if options and options.get("diarize", False):
    if CONFIG.HF_TOKEN == "":
        print("Warning! HF_TOKEN is not set. Skipping diarization.")
    else:
        # pass min/max speaker bounds as keyword arguments if known
        diarize_segments = self.diarize_model(
            audio,
            min_speakers=options.get("min_speakers"),
            max_speakers=options.get("max_speakers"),
        )
        result = whisperx.assign_word_speakers(diarize_segments, result)
output_file = StringIO()
self.write_result(result, output_file, output)
output_file.seek(0)
return output_file
def language_detection(self, audio):
# load audio and pad/trim it to fit 30 seconds
audio = whisper.pad_or_trim(audio)
# make log-Mel spectrogram and move to the same device as the model
mel = whisper.log_mel_spectrogram(audio).to(self.model.device)
# detect the spoken language
with self.model_lock:
if self.model is None:
self.load_model()
_, probs = self.model.detect_language(mel)
detected_lang_code = max(probs, key=probs.get)
# return confidence as well, matching what the /detect-language endpoint unpacks
return detected_lang_code, probs[detected_lang_code]
def write_result(self, result: dict, file: BinaryIO, output: Union[str, None]):
if output == "srt":
if CONFIG.HF_TOKEN != "":
WriteSRT(SubtitlesWriter).write_result(result, file=file, options={})
else:
WriteSRT(ResultWriter).write_result(result, file=file, options={})
elif output == "vtt":
if CONFIG.HF_TOKEN != "":
WriteVTT(SubtitlesWriter).write_result(result, file=file, options={})
else:
WriteVTT(ResultWriter).write_result(result, file=file, options={})
elif output == "tsv":
WriteTSV(ResultWriter).write_result(result, file=file, options={})
elif output == "json":
WriteJSON(ResultWriter).write_result(result, file=file, options={})
elif output == "txt":
WriteTXT(ResultWriter).write_result(result, file=file, options={})
else:
return 'Please select an output method!'
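For orientation, the transcribe → align → diarize flow implemented above can be sketched standalone. Every call mirrors one already used in this engine; the model name, audio path, token, and speaker counts are illustrative placeholders:

```python
# Hedged sketch of the WhisperX pipeline used by WhisperXASR above.
# "base", "audio.wav", "hf_...", and the speaker counts are placeholders.
import whisperx

device = "cpu"
audio = whisperx.load_audio("audio.wav")

# 1. Transcribe with the faster-whisper backend
model = whisperx.load_model("base", device=device, compute_type="float32",
                            asr_options={"without_timestamps": False})
result = model.transcribe(audio, task="transcribe")

# 2. Align words with a per-language alignment model
model_x, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], model_x, metadata, audio, device,
                        return_char_alignments=False)

# 3. Optionally assign speakers (needs a Hugging Face token for pyannote)
diarize_model = whisperx.DiarizationPipeline(use_auth_token="hf_...", device=device)
diarize_segments = diarize_model(audio, min_speakers=2, max_speakers=4)
result = whisperx.assign_word_speakers(diarize_segments, result)
```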
@@ -16,15 +16,9 @@ class OpenAIWhisperASR(ASRModel):
def load_model(self):
if torch.cuda.is_available():
self.model = whisper.load_model(
name=CONFIG.MODEL_NAME,
download_root=CONFIG.MODEL_PATH
).cuda()
self.model = whisper.load_model(name=CONFIG.MODEL_NAME, download_root=CONFIG.MODEL_PATH).cuda()
else:
self.model = whisper.load_model(
name=CONFIG.MODEL_NAME,
download_root=CONFIG.MODEL_PATH
)
self.model = whisper.load_model(name=CONFIG.MODEL_NAME, download_root=CONFIG.MODEL_PATH)
Thread(target=self.monitor_idleness, daemon=True).start()
@@ -36,12 +30,14 @@ class OpenAIWhisperASR(ASRModel):
initial_prompt: Union[str, None],
vad_filter: Union[bool, None],
word_timestamps: Union[bool, None],
options: Union[dict, None],
output,
):
self.last_activity_time = time.time()
with self.model_lock:
if self.model is None: self.load_model()
if self.model is None:
self.load_model()
options_dict = {"task": task}
if language:
@@ -64,7 +60,8 @@
self.last_activity_time = time.time()
with self.model_lock:
if self.model is None: self.load_model()
if self.model is None:
self.load_model()
# load audio and pad/trim it to fit 30 seconds
audio = whisper.pad_or_trim(audio)
......
@@ -11,6 +11,9 @@ class CONFIG:
# Determine the ASR engine ('openai_whisper', 'faster_whisper' or 'whisperx')
ASR_ENGINE = os.getenv("ASR_ENGINE", "openai_whisper")
# Retrieve Huggingface Token
HF_TOKEN = os.getenv("HF_TOKEN", "")
# Determine the computation device (GPU or CPU)
DEVICE = os.getenv("ASR_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
......
from app.asr_models.asr_model import ASRModel
from app.asr_models.faster_whisper_engine import FasterWhisperASR
from app.asr_models.openai_whisper_engine import OpenAIWhisperASR
from app.asr_models.mbain_whisperx_engine import WhisperXASR
from app.config import CONFIG
@@ -11,5 +12,7 @@ class ASRModelFactory:
return OpenAIWhisperASR()
elif CONFIG.ASR_ENGINE == "faster_whisper":
return FasterWhisperASR()
elif CONFIG.ASR_ENGINE == "whisperx":
return WhisperXASR()
else:
raise ValueError(f"Unsupported ASR engine: {CONFIG.ASR_ENGINE}")
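For reference, selecting the new engine reduces to setting `ASR_ENGINE` before `app.config` is first imported; a minimal sketch using the module paths from the imports above:

```python
# Minimal sketch: ASR_ENGINE must be set before app.config is imported,
# since CONFIG reads the environment once at import time.
import os
os.environ["ASR_ENGINE"] = "whisperx"

from app.config import CONFIG
from app.asr_models.mbain_whisperx_engine import WhisperXASR

assert CONFIG.ASR_ENGINE == "whisperx"
asr_model = WhisperXASR()  # what the factory returns for this engine
```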
import json
import os
from dataclasses import asdict
from typing import TextIO, BinaryIO
from dataclasses import asdict, is_dataclass
from typing import TextIO, BinaryIO, Union
import ffmpeg
import numpy as np
@@ -23,14 +23,42 @@ class ResultWriter:
with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f)
def write_result(self, result: dict, file: TextIO):
def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
raise NotImplementedError
def format_segments_in_result(self, result: dict):
    # Normalize "segments" to a list of plain dicts: faster-whisper yields
    # dataclass instances, whisperX already yields dicts, and anything else
    # (an empty list, a dict, or unknown items) is left untouched.
    segments = result.get("segments")
    if isinstance(segments, list) and segments and is_dataclass(segments[0]):
        result["segments"] = [asdict(segment) for segment in segments]
    return result
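To illustrate what this normalization does, a dataclass segment (as faster-whisper produces) ends up as a plain dict; the `Segment` dataclass below is an illustrative stand-in, not the library's own type:

```python
from dataclasses import asdict, dataclass, is_dataclass

@dataclass
class Segment:  # stand-in for a faster-whisper segment, not the real class
    start: float
    end: float
    text: str

segments = [Segment(0.0, 1.5, "hello")]
if segments and is_dataclass(segments[0]):
    segments = [asdict(s) for s in segments]
assert segments == [{"start": 0.0, "end": 1.5, "text": "hello"}]
```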
class WriteTXT(ResultWriter):
extension: str = "txt"
def write_result(self, result: dict, file: TextIO):
def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
result = self.format_segments_in_result(result)
for segment in result["segments"]:
    print(segment["text"].strip(), file=file, flush=True)
@@ -38,12 +66,13 @@ class WriteTXT(ResultWriter):
class WriteVTT(ResultWriter):
extension: str = "vtt"
def write_result(self, result: dict, file: TextIO):
def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
print("WEBVTT\n", file=file)
result = self.format_segments_in_result(result)
for segment in result["segments"]:
print(
f"{format_timestamp(segment.start)} --> {format_timestamp(segment.end)}\n"
f"{segment.text.strip().replace('-->', '->')}\n",
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
@@ -52,14 +81,15 @@
class WriteSRT(ResultWriter):
extension: str = "srt"
def write_result(self, result: dict, file: TextIO):
def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
result = self.format_segments_in_result(result)
for i, segment in enumerate(result["segments"], start=1):
# write srt lines
print(
f"{i}\n"
f"{format_timestamp(segment.start, always_include_hours=True, decimal_marker=',')} --> "
f"{format_timestamp(segment.end, always_include_hours=True, decimal_marker=',')}\n"
f"{segment.text.strip().replace('-->', '->')}\n",
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
@@ -77,20 +107,20 @@
extension: str = "tsv"
def write_result(self, result: dict, file: TextIO):
def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
result = self.format_segments_in_result(result)
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment.start), file=file, end="\t")
print(round(1000 * segment.end), file=file, end="\t")
print(segment.text.strip().replace("\t", " "), file=file, flush=True)
print(round(1000 * segment["start"]), file=file, end="\t")
print(round(1000 * segment["end"]), file=file, end="\t")
print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
class WriteJSON(ResultWriter):
extension: str = "json"
def write_result(self, result: dict, file: TextIO):
if "segments" in result:
result["segments"] = [asdict(segment) for segment in result["segments"]]
def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
result = self.format_segments_in_result(result)
json.dump(result, file)
......
@@ -35,7 +35,6 @@ assets_path = os.getcwd() + "/swagger-ui-assets"
if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/swagger-ui-bundle.js"):
app.mount("/assets", StaticFiles(directory=assets_path), name="static")
def swagger_monkey_patch(*args, **kwargs):
return get_swagger_ui_html(
*args,
@@ -45,7 +44,6 @@ if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/
swagger_js_url="/assets/swagger-ui-bundle.js",
)
applications.get_swagger_ui_html = swagger_monkey_patch
@@ -68,11 +66,37 @@ async def asr(
include_in_schema=(True if CONFIG.ASR_ENGINE == "faster_whisper" else False),
),
] = False,
word_timestamps: bool = Query(default=False, description="Word level timestamps"),
word_timestamps: bool = Query(
default=False,
description="Word level timestamps",
include_in_schema=(True if CONFIG.ASR_ENGINE == "faster_whisper" else False),
),
diarize: bool = Query(
default=False,
description="Diarize the input",
include_in_schema=(True if CONFIG.ASR_ENGINE == "whisperx" and CONFIG.HF_TOKEN != "" else False),
),
min_speakers: Union[int, None] = Query(
default=None,
description="Min speakers in this file",
include_in_schema=(True if CONFIG.ASR_ENGINE == "whisperx" else False),
),
max_speakers: Union[int, None] = Query(
default=None,
description="Max speakers in this file",
include_in_schema=(True if CONFIG.ASR_ENGINE == "whisperx" else False),
),
output: Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"]),
):
result = asr_model.transcribe(
load_audio(audio_file.file, encode), task, language, initial_prompt, vad_filter, word_timestamps, output
load_audio(audio_file.file, encode),
task,
language,
initial_prompt,
vad_filter,
word_timestamps,
{"diarize": diarize, "min_speakers": min_speakers, "max_speakers": max_speakers},
output,
)
return StreamingResponse(
result,
@@ -90,8 +114,11 @@ async def detect_language(
encode: bool = Query(default=True, description="Encode audio first through FFmpeg"),
):
detected_lang_code, confidence = asr_model.language_detection(load_audio(audio_file.file, encode))
return {"detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code": detected_lang_code,
"confidence": confidence}
return {
"detected_language": tokenizer.LANGUAGES[detected_lang_code],
"language_code": detected_lang_code,
"confidence": confidence,
}
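A hedged client-side sketch of the new diarization parameters; the `/asr` route path and local port 9000 are assumptions based on the `EXPOSE 9000` lines in the Dockerfiles above, and `requests` is not a project dependency:

```python
import requests

# Assumes the service listens on the port exposed in the Dockerfiles and
# that this endpoint is mounted at /asr; adjust both for your deployment.
with open("meeting.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:9000/asr",
        params={
            "task": "transcribe",
            "output": "srt",
            "diarize": True,    # needs ASR_ENGINE=whisperx and HF_TOKEN set
            "min_speakers": 2,  # optional hints for the diarizer
            "max_speakers": 4,
        },
        files={"audio_file": f},
    )
print(resp.text)
```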
@click.command()
@@ -110,10 +137,7 @@ async def detect_language(
help="Port for the webservice (default: 9000)",
)
@click.version_option(version=projectMetadata["Version"])
def start(
host: str,
port: Optional[int] = None
):
def start(host: str, port: Optional[int] = None):
uvicorn.run(app, host=host, port=port)
......