Unverified commit 0cd35b89, authored by charnesp, committed by GitHub

Integrate Whisperx (#267)


* Initial commit for whisperx

* whisperx working

* Correct Dockerfile

* Add possibility to prioritize self-hosted runners

---------

Co-authored-by: Charles N <charles@chouette.vision>
parent f2f39fe1
.devcontainer/devcontainer.json:

// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-docker-compose
{
  "name": "Existing Docker Compose (Extend)",
  // Update the 'dockerComposeFile' list if you have more compose files or use different names.
  // The .devcontainer/docker-compose.yml file contains any overrides you need/want to make.
  "dockerComposeFile": [
    "../docker-compose.yml",
    "docker-compose.yml"
  ],
  // The 'service' property is the name of the service for the container that VS Code should
  // use. Update this value and .devcontainer/docker-compose.yml to the real service name.
  "service": "whisper-asr-webservice",
  // The optional 'workspaceFolder' property is the path VS Code should open by default when
  // connected. This is typically a file mount in .devcontainer/docker-compose.yml
  "workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",
  // "overrideCommand": "/bin/sh -c 'while sleep 1000; do :; done'"
  "overrideCommand": true
  // Features to add to the dev container. More info: https://containers.dev/features.
  // "features": {},
  // Use 'forwardPorts' to make a list of ports inside the container available locally.
  // "forwardPorts": [],
  // Uncomment the next line if you want to start specific services in your Docker Compose config.
  // "runServices": [],
  // Uncomment the next line if you want to keep your containers running after VS Code shuts down.
  // "shutdownAction": "none",
  // Uncomment the next line to run commands after the container is created.
  // "postCreateCommand": "cat /etc/os-release",
  // Configure tool-specific properties.
  // "customizations": {},
  // Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root.
  // "remoteUser": "devcontainer"
}
.devcontainer/docker-compose.yml:

version: '3.4'
services:
  # Update this to the name of the service you want to work with in your docker-compose.yml file
  whisper-asr-webservice:
    # Uncomment if you want to override the service's Dockerfile to one in the .devcontainer
    # folder. Note that the path of the Dockerfile and context is relative to the *primary*
    # docker-compose.yml file (the first in the devcontainer.json "dockerComposeFile"
    # array). The sample below assumes your primary file is in the root of your project.
    #
    # build:
    #   context: .
    #   dockerfile: .devcontainer/Dockerfile
    env_file: .devcontainer/dev.env
    environment:
      ASR_ENGINE: ${ASR_ENGINE}
      HF_TOKEN: ${HF_TOKEN}
    volumes:
      # Update this to wherever you want VS Code to mount the folder of your project
      - ..:/workspaces:cached
    # Uncomment the next four lines if you will use a ptrace-based debugger like C++, Go, and Rust.
    # cap_add:
    #   - SYS_PTRACE
    # security_opt:
    #   - seccomp:unconfined
    # Overrides the default command so things don't shut down after the process ends.
    command: sleep infinity
GitHub Actions build workflow:

@@ -12,7 +12,7 @@ env:
   REPO_NAME: ${{secrets.REPO_NAME}}
 jobs:
   build:
-    runs-on: ubuntu-latest
+    runs-on: [self-hosted, ubuntu-latest]
     strategy:
       matrix:
         include:
@@ -22,6 +22,10 @@ jobs:
           tag_extension: -gpu
           platforms: linux/amd64
     steps:
+      - name: Remove unnecessary files
+        run: |
+          sudo rm -rf /usr/share/dotnet
+          sudo rm -rf "$AGENT_TOOLSDIRECTORY"
       - name: Checkout
         uses: actions/checkout@v3
       - name: Set up QEMU
.gitignore:

@@ -42,3 +42,5 @@ pip-wheel-metadata
 poetry/core/*
 public
+.devcontainer/dev.env
Dockerfile:

@@ -8,6 +8,8 @@ RUN export DEBIAN_FRONTEND=noninteractive \
     pkg-config \
     yasm \
     ca-certificates \
+    gcc \
+    python3-dev \
     && rm -rf /var/lib/apt/lists/*
 RUN git clone https://github.com/FFmpeg/FFmpeg.git --depth 1 --branch n6.1.1 --single-branch /FFmpeg-6.1.1
@@ -42,6 +44,12 @@ FROM swaggerapi/swagger-ui:v5.9.1 AS swagger-ui
 FROM python:3.10-bookworm
+RUN export DEBIAN_FRONTEND=noninteractive \
+    && apt-get -qq update \
+    && apt-get -qq install --no-install-recommends \
+    libsndfile1 \
+    && rm -rf /var/lib/apt/lists/*
 ENV POETRY_VENV=/app/.venv
 RUN python3 -m venv $POETRY_VENV \
@@ -61,6 +69,11 @@ COPY --from=swagger-ui /usr/share/nginx/html/swagger-ui-bundle.js swagger-ui-ass
 RUN poetry config virtualenvs.in-project true
 RUN poetry install
+RUN $POETRY_VENV/bin/pip install pandas transformers nltk pyannote.audio
+RUN git clone --depth 1 https://github.com/m-bain/whisperX.git \
+    && cd whisperX \
+    && $POETRY_VENV/bin/pip install -e .
 EXPOSE 9000
 ENTRYPOINT ["whisper-asr-webservice"]
Dockerfile.gpu:

@@ -43,6 +43,13 @@ FROM swaggerapi/swagger-ui:v5.9.1 AS swagger-ui
 FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
 ENV PYTHON_VERSION=3.10
+RUN export DEBIAN_FRONTEND=noninteractive \
+    && apt-get -qq update \
+    && apt-get -qq install --no-install-recommends \
+    libsndfile1 \
+    && rm -rf /var/lib/apt/lists/*
 ENV POETRY_VENV=/app/.venv
 RUN export DEBIAN_FRONTEND=noninteractive \
@@ -79,6 +86,11 @@ COPY --from=swagger-ui /usr/share/nginx/html/swagger-ui-bundle.js swagger-ui-ass
 RUN poetry install
 RUN $POETRY_VENV/bin/pip install torch==1.13.1+cu117 -f https://download.pytorch.org/whl/torch
+RUN $POETRY_VENV/bin/pip install pandas transformers nltk pyannote.audio
+RUN git clone --depth 1 https://github.com/m-bain/whisperX.git \
+    && cd whisperX \
+    && $POETRY_VENV/bin/pip install -e .
 EXPOSE 9000
 CMD whisper-asr-webservice
README.md:

@@ -13,6 +13,7 @@ Current release (v1.7.1) supports following whisper models:
 - [openai/whisper](https://github.com/openai/whisper)@[v20240930](https://github.com/openai/whisper/releases/tag/v20240930)
 - [SYSTRAN/faster-whisper](https://github.com/SYSTRAN/faster-whisper)@[v1.1.0](https://github.com/SYSTRAN/faster-whisper/releases/tag/v1.1.0)
+- [whisperX](https://github.com/m-bain/whisperX)@[v3.1.1](https://github.com/m-bain/whisperX/releases/tag/v3.1.1)
 ## Quick Usage
app/asr_models/asr_model.py:

@@ -13,7 +13,10 @@ class ASRModel(ABC):
     """
     Abstract base class for ASR (Automatic Speech Recognition) models.
     """
     model = None
+    diarize_model = None  # used for WhisperX
+    x_models = dict()  # used for WhisperX
     model_lock = Lock()
     last_activity_time = time.time()
@@ -28,13 +31,16 @@
         pass
     @abstractmethod
-    def transcribe(self,
+    def transcribe(
+        self,
         audio,
         task: Union[str, None],
         language: Union[str, None],
         initial_prompt: Union[str, None],
         vad_filter: Union[bool, None],
-        word_timestamps: Union[bool, None]
+        word_timestamps: Union[bool, None],
+        options: Union[dict, None],
+        output,
     ):
         """
         Perform transcription on the given audio file.
@@ -52,7 +58,8 @@
         """
         Monitors the idleness of the ASR model and releases the model if it has been idle for too long.
         """
-        if CONFIG.MODEL_IDLE_TIMEOUT <= 0: return
+        if CONFIG.MODEL_IDLE_TIMEOUT <= 0:
+            return
         while True:
             time.sleep(15)
             if time.time() - self.last_activity_time > CONFIG.MODEL_IDLE_TIMEOUT:
@@ -68,4 +75,6 @@
         torch.cuda.empty_cache()
         gc.collect()
         self.model = None
+        self.diarize_model = None
+        self.x_models = dict()
         print("Model unloaded due to timeout")
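For reference, a minimal sketch (not part of this commit) of what the widened contract asks of a concrete engine: the abstract transcribe now also receives an options dict (used by WhisperX for the diarization flags) and the output format, even if an engine ignores them. The EchoASR class below is purely illustrative and does not exist in the repository.

# Illustrative only: a hypothetical engine satisfying the new abstract signature.
from io import StringIO
from typing import Union

from app.asr_models.asr_model import ASRModel


class EchoASR(ASRModel):  # hypothetical engine, for illustration only
    def load_model(self):
        self.model = object()  # placeholder model handle

    def transcribe(
        self,
        audio,
        task: Union[str, None],
        language: Union[str, None],
        initial_prompt: Union[str, None],
        vad_filter: Union[bool, None],
        word_timestamps: Union[bool, None],
        options: Union[dict, None],  # new argument: engine-specific flags (e.g. diarize, min/max speakers)
        output,
    ):
        # Real engines render the result via the writers in app.utils; a StringIO is returned either way.
        return StringIO("transcript placeholder\n")

    def language_detection(self, audio):
        return "en"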
app/asr_models/faster_whisper_engine.py:

@@ -32,6 +32,7 @@ class FasterWhisperASR(ASRModel):
         initial_prompt: Union[str, None],
         vad_filter: Union[bool, None],
         word_timestamps: Union[bool, None],
+        options: Union[dict, None],
         output,
     ):
         self.last_activity_time = time.time()
app/asr_models/mbain_whisperx_engine.py (new file):

from typing import BinaryIO, Union
from io import StringIO

import whisperx
import whisper
from whisperx.utils import SubtitlesWriter, ResultWriter

from app.asr_models.asr_model import ASRModel
from app.config import CONFIG
from app.utils import WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON


class WhisperXASR(ASRModel):
    def __init__(self):
        self.x_models = dict()

    def load_model(self):
        asr_options = {"without_timestamps": False}
        self.model = whisperx.load_model(
            CONFIG.MODEL_NAME, device=CONFIG.DEVICE, compute_type="float32", asr_options=asr_options
        )

        if CONFIG.HF_TOKEN != "":
            self.diarize_model = whisperx.DiarizationPipeline(use_auth_token=CONFIG.HF_TOKEN, device=CONFIG.DEVICE)

    def transcribe(
        self,
        audio,
        task: Union[str, None],
        language: Union[str, None],
        initial_prompt: Union[str, None],
        vad_filter: Union[bool, None],
        word_timestamps: Union[bool, None],
        options: Union[dict, None],
        output,
    ):
        options_dict = {"task": task}
        if language:
            options_dict["language"] = language
        if initial_prompt:
            options_dict["initial_prompt"] = initial_prompt
        with self.model_lock:
            if self.model is None:
                self.load_model()
            result = self.model.transcribe(audio, **options_dict)

        # Load the required alignment model and cache it.
        # If we transcribe audio in many different languages, this may lead to OOM problems.
        if result["language"] in self.x_models:
            model_x, metadata = self.x_models[result["language"]]
        else:
            self.x_models[result["language"]] = whisperx.load_align_model(
                language_code=result["language"], device=CONFIG.DEVICE
            )
            model_x, metadata = self.x_models[result["language"]]

        # Align whisper output
        result = whisperx.align(
            result["segments"], model_x, metadata, audio, CONFIG.DEVICE, return_char_alignments=False
        )

        if options.get("diarize", False):
            if CONFIG.HF_TOKEN == "":
                print("Warning! HF_TOKEN is not set. Diarization may not work as expected.")
            min_speakers = options.get("min_speakers", None)
            max_speakers = options.get("max_speakers", None)
            # add min/max number of speakers if known
            diarize_segments = self.diarize_model(audio, min_speakers, max_speakers)
            result = whisperx.assign_word_speakers(diarize_segments, result)

        output_file = StringIO()
        self.write_result(result, output_file, output)
        output_file.seek(0)

        return output_file

    def language_detection(self, audio):
        # load audio and pad/trim it to fit 30 seconds
        audio = whisper.pad_or_trim(audio)

        # make log-Mel spectrogram and move to the same device as the model
        mel = whisper.log_mel_spectrogram(audio).to(self.model.device)

        # detect the spoken language
        with self.model_lock:
            if self.model is None:
                self.load_model()
            _, probs = self.model.detect_language(mel)
        detected_lang_code = max(probs, key=probs.get)

        return detected_lang_code

    def write_result(self, result: dict, file: BinaryIO, output: Union[str, None]):
        if output == "srt":
            if CONFIG.HF_TOKEN != "":
                WriteSRT(SubtitlesWriter).write_result(result, file=file, options={})
            else:
                WriteSRT(ResultWriter).write_result(result, file=file, options={})
        elif output == "vtt":
            if CONFIG.HF_TOKEN != "":
                WriteVTT(SubtitlesWriter).write_result(result, file=file, options={})
            else:
                WriteVTT(ResultWriter).write_result(result, file=file, options={})
        elif output == "tsv":
            WriteTSV(ResultWriter).write_result(result, file=file, options={})
        elif output == "json":
            WriteJSON(ResultWriter).write_result(result, file=file, options={})
        elif output == "txt":
            WriteTXT(ResultWriter).write_result(result, file=file, options={})
        else:
            return 'Please select an output method!'
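A hedged usage sketch (not in the diff) of driving the new engine directly. It assumes ASR_ENGINE=whisperx, a valid HF_TOKEN so the pyannote diarization pipeline loads, a local sample.wav, and that load_audio is the same app.utils helper the webservice calls.

import os

os.environ.setdefault("ASR_ENGINE", "whisperx")

from app.asr_models.mbain_whisperx_engine import WhisperXASR
from app.utils import load_audio  # assumed to be the helper used by the /asr endpoint

engine = WhisperXASR()
with open("sample.wav", "rb") as f:
    audio = load_audio(f, True)  # decode through FFmpeg, as the webservice does

# options mirrors the dict built by the /asr endpoint
result = engine.transcribe(
    audio,
    task="transcribe",
    language=None,
    initial_prompt=None,
    vad_filter=None,
    word_timestamps=None,
    options={"diarize": True, "min_speakers": 1, "max_speakers": 2},
    output="srt",
)
print(result.read())  # transcribe() returns a StringIO holding the rendered SRT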
app/asr_models/openai_whisper_engine.py:

@@ -16,15 +16,9 @@ class OpenAIWhisperASR(ASRModel):
     def load_model(self):
         if torch.cuda.is_available():
-            self.model = whisper.load_model(
-                name=CONFIG.MODEL_NAME,
-                download_root=CONFIG.MODEL_PATH
-            ).cuda()
+            self.model = whisper.load_model(name=CONFIG.MODEL_NAME, download_root=CONFIG.MODEL_PATH).cuda()
         else:
-            self.model = whisper.load_model(
-                name=CONFIG.MODEL_NAME,
-                download_root=CONFIG.MODEL_PATH
-            )
+            self.model = whisper.load_model(name=CONFIG.MODEL_NAME, download_root=CONFIG.MODEL_PATH)
         Thread(target=self.monitor_idleness, daemon=True).start()
@@ -36,12 +30,14 @@ class OpenAIWhisperASR(ASRModel):
         initial_prompt: Union[str, None],
         vad_filter: Union[bool, None],
         word_timestamps: Union[bool, None],
+        options: Union[dict, None],
         output,
     ):
         self.last_activity_time = time.time()
         with self.model_lock:
-            if self.model is None: self.load_model()
+            if self.model is None:
+                self.load_model()
         options_dict = {"task": task}
         if language:
@@ -64,7 +60,8 @@ class OpenAIWhisperASR(ASRModel):
         self.last_activity_time = time.time()
         with self.model_lock:
-            if self.model is None: self.load_model()
+            if self.model is None:
+                self.load_model()
         # load audio and pad/trim it to fit 30 seconds
         audio = whisper.pad_or_trim(audio)
app/config.py:

@@ -11,6 +11,9 @@ class CONFIG:
     # Determine the ASR engine ('faster_whisper' or 'openai_whisper')
     ASR_ENGINE = os.getenv("ASR_ENGINE", "openai_whisper")
+    # Retrieve Huggingface Token
+    HF_TOKEN = os.getenv("HF_TOKEN", "")
     # Determine the computation device (GPU or CPU)
     DEVICE = os.getenv("ASR_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
ASR model factory:

 from app.asr_models.asr_model import ASRModel
 from app.asr_models.faster_whisper_engine import FasterWhisperASR
 from app.asr_models.openai_whisper_engine import OpenAIWhisperASR
+from app.asr_models.mbain_whisperx_engine import WhisperXASR
 from app.config import CONFIG
@@ -11,5 +12,7 @@ class ASRModelFactory:
             return OpenAIWhisperASR()
         elif CONFIG.ASR_ENGINE == "faster_whisper":
             return FasterWhisperASR()
+        elif CONFIG.ASR_ENGINE == "whisperx":
+            return WhisperXASR()
         else:
             raise ValueError(f"Unsupported ASR engine: {CONFIG.ASR_ENGINE}")
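A quick sketch of selecting the new engine through the factory. The import path and the factory method name below are assumptions inferred from the surrounding code, not taken from this diff; adjust them to the actual module layout.

import os

os.environ["ASR_ENGINE"] = "whisperx"  # must be set before the config/factory modules are imported

from app.asr_model_factory import ASRModelFactory  # import path assumed

asr_model = ASRModelFactory.create_asr_model()  # method name assumed
print(type(asr_model).__name__)  # expected: WhisperXASR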
app/utils.py:

 import json
 import os
-from dataclasses import asdict
-from typing import TextIO, BinaryIO
+from dataclasses import asdict, is_dataclass
+from typing import TextIO, BinaryIO, Union
 import ffmpeg
 import numpy as np
@@ -23,14 +23,42 @@ class ResultWriter:
         with open(output_path, "w", encoding="utf-8") as f:
             self.write_result(result, file=f)
-    def write_result(self, result: dict, file: TextIO):
+    def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
         raise NotImplementedError
+    def format_segments_in_result(self, result: dict):
+        if "segments" in result:
+            # Check if result["segments"] is a list
+            if isinstance(result["segments"], list):
+                # Check if the list is empty
+                if not result["segments"]:
+                    # Handle the empty list case, you can choose to leave it as is or set it to an empty list
+                    pass
+                else:
+                    # Check if the first item in the list is a dataclass instance
+                    if is_dataclass(result["segments"][0]):
+                        result["segments"] = [asdict(segment) for segment in result["segments"]]
+                    # If it's already a list of dicts, leave it as is
+                    elif isinstance(result["segments"][0], dict):
+                        pass
+                    else:
+                        # Handle the case where the list contains neither dataclass instances nor dicts
+                        # You can choose to leave it as is or raise an error
+                        pass
+            elif isinstance(result["segments"], dict):
+                # If it's already a dict, leave it as is
+                pass
+            else:
+                # Handle the case where result["segments"] is neither a list nor a dict
+                # You can choose to leave it as is or raise an error
+                pass
+        return result
 class WriteTXT(ResultWriter):
     extension: str = "txt"
-    def write_result(self, result: dict, file: TextIO):
+    def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
         for segment in result["segments"]:
             print(segment.text.strip(), file=file, flush=True)
@@ -38,12 +66,13 @@ class WriteTXT(ResultWriter):
 class WriteVTT(ResultWriter):
     extension: str = "vtt"
-    def write_result(self, result: dict, file: TextIO):
+    def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
         print("WEBVTT\n", file=file)
+        result = self.format_segments_in_result(result)
         for segment in result["segments"]:
             print(
-                f"{format_timestamp(segment.start)} --> {format_timestamp(segment.end)}\n"
-                f"{segment.text.strip().replace('-->', '->')}\n",
+                f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
+                f"{segment['text'].strip().replace('-->', '->')}\n",
                 file=file,
                 flush=True,
             )
@@ -52,14 +81,15 @@
 class WriteSRT(ResultWriter):
     extension: str = "srt"
-    def write_result(self, result: dict, file: TextIO):
+    def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
+        result = self.format_segments_in_result(result)
         for i, segment in enumerate(result["segments"], start=1):
             # write srt lines
             print(
                 f"{i}\n"
-                f"{format_timestamp(segment.start, always_include_hours=True, decimal_marker=',')} --> "
-                f"{format_timestamp(segment.end, always_include_hours=True, decimal_marker=',')}\n"
-                f"{segment.text.strip().replace('-->', '->')}\n",
+                f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
+                f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
+                f"{segment['text'].strip().replace('-->', '->')}\n",
                 file=file,
                 flush=True,
             )
@@ -77,20 +107,20 @@
     extension: str = "tsv"
-    def write_result(self, result: dict, file: TextIO):
+    def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
+        result = self.format_segments_in_result(result)
         print("start", "end", "text", sep="\t", file=file)
         for segment in result["segments"]:
-            print(round(1000 * segment.start), file=file, end="\t")
-            print(round(1000 * segment.end), file=file, end="\t")
-            print(segment.text.strip().replace("\t", " "), file=file, flush=True)
+            print(round(1000 * segment["start"]), file=file, end="\t")
+            print(round(1000 * segment["end"]), file=file, end="\t")
+            print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
 class WriteJSON(ResultWriter):
     extension: str = "json"
-    def write_result(self, result: dict, file: TextIO):
-        if "segments" in result:
-            result["segments"] = [asdict(segment) for segment in result["segments"]]
+    def write_result(self, result: dict, file: TextIO, options: Union[dict, None]):
+        result = self.format_segments_in_result(result)
         json.dump(result, file)
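An illustrative check (not in the diff) of what format_segments_in_result buys the writers: engines whose segments are dataclass instances and WhisperX's plain-dict segments now render through the same code path. The Segment dataclass below is a stand-in for illustration only, and the writer is constructed the same way the WhisperX engine calls it (WriteTSV(ResultWriter)); it assumes the project dependencies are installed so app.utils imports cleanly.

from dataclasses import dataclass
from io import StringIO

from app.utils import WriteTSV, ResultWriter


@dataclass
class Segment:  # stand-in segment type, for illustration; engines define their own
    start: float
    end: float
    text: str


dataclass_result = {"segments": [Segment(0.0, 1.2, "hello world")]}
dict_result = {"segments": [{"start": 0.0, "end": 1.2, "text": "hello world"}]}

for result in (dataclass_result, dict_result):
    buf = StringIO()
    WriteTSV(ResultWriter).write_result(result, file=buf, options={})
    print(buf.getvalue())  # identical TSV output for both segment shapes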
app/webservice.py:

@@ -35,7 +35,6 @@ assets_path = os.getcwd() + "/swagger-ui-assets"
 if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/swagger-ui-bundle.js"):
     app.mount("/assets", StaticFiles(directory=assets_path), name="static")
     def swagger_monkey_patch(*args, **kwargs):
         return get_swagger_ui_html(
             *args,
@@ -45,7 +44,6 @@ if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/
             swagger_js_url="/assets/swagger-ui-bundle.js",
         )
     applications.get_swagger_ui_html = swagger_monkey_patch
@@ -68,11 +66,37 @@ async def asr(
             include_in_schema=(True if CONFIG.ASR_ENGINE == "faster_whisper" else False),
         ),
     ] = False,
-    word_timestamps: bool = Query(default=False, description="Word level timestamps"),
+    word_timestamps: bool = Query(
+        default=False,
+        description="Word level timestamps",
+        include_in_schema=(True if CONFIG.ASR_ENGINE == "faster_whisper" else False),
+    ),
+    diarize: bool = Query(
+        default=False,
+        description="Diarize the input",
+        include_in_schema=(True if CONFIG.ASR_ENGINE == "whisperx" and CONFIG.HF_TOKEN != "" else False),
+    ),
+    min_speakers: Union[int, None] = Query(
+        default=None,
+        description="Min speakers in this file",
+        include_in_schema=(True if CONFIG.ASR_ENGINE == "whisperx" else False),
+    ),
+    max_speakers: Union[int, None] = Query(
+        default=None,
+        description="Max speakers in this file",
+        include_in_schema=(True if CONFIG.ASR_ENGINE == "whisperx" else False),
+    ),
     output: Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"]),
 ):
     result = asr_model.transcribe(
-        load_audio(audio_file.file, encode), task, language, initial_prompt, vad_filter, word_timestamps, output
+        load_audio(audio_file.file, encode),
+        task,
+        language,
+        initial_prompt,
+        vad_filter,
+        word_timestamps,
+        {"diarize": diarize, "min_speakers": min_speakers, "max_speakers": max_speakers},
+        output,
     )
     return StreamingResponse(
         result,
@@ -90,8 +114,11 @@ async def detect_language(
     encode: bool = Query(default=True, description="Encode audio first through FFmpeg"),
 ):
     detected_lang_code, confidence = asr_model.language_detection(load_audio(audio_file.file, encode))
-    return {"detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code": detected_lang_code,
-            "confidence": confidence}
+    return {
+        "detected_language": tokenizer.LANGUAGES[detected_lang_code],
+        "language_code": detected_lang_code,
+        "confidence": confidence,
+    }
 @click.command()
@@ -110,10 +137,7 @@ async def detect_language(
     help="Port for the webservice (default: 9000)",
 )
 @click.version_option(version=projectMetadata["Version"])
-def start(
-    host: str,
-    port: Optional[int] = None
-):
+def start(host: str, port: Optional[int] = None):
     uvicorn.run(app, host=host, port=port)
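A hedged client-side example of the extended /asr endpoint, assuming the service is listening on localhost:9000 with ASR_ENGINE=whisperx and HF_TOKEN set, and that audio.wav exists locally. The diarize, min_speakers, and max_speakers query parameters are the ones added above and are routed into the options dict passed to transcribe().

import requests

with open("audio.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:9000/asr",
        params={
            "task": "transcribe",
            "encode": True,
            "output": "srt",
            "diarize": True,
            "min_speakers": 1,
            "max_speakers": 2,
        },
        files={"audio_file": f},
    )

print(resp.status_code)
print(resp.text)  # rendered SRT streamed back via StreamingResponse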