diff --git a/CHANGELOG.md b/CHANGELOG.md
index 481e1fc98fca90092797330ed71c738ecff28fe2..b2f280ec063df786fc3f5abc70c5a189c5004b55 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,10 @@ Unreleased
 - Added detection confidence to langauge detection endpoint
 - Set mel generation to adjust n_dims automatically to match the loaded model
 
+### Added
+
+- Refactored classes, added comments, implemented abstract methods, and added a factory method for engine selection
+
 [1.6.0] (2024-10-06)
 --------------------
 
diff --git a/app/asr_models/asr_model.py b/app/asr_models/asr_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc2205083ed48e5e88267aeb3afa1de2d2e61051
--- /dev/null
+++ b/app/asr_models/asr_model.py
@@ -0,0 +1,71 @@
+import gc
+import time
+from abc import ABC, abstractmethod
+from threading import Lock
+from typing import Union
+
+import torch
+
+from app.config import CONFIG
+
+
+class ASRModel(ABC):
+    """
+    Abstract base class for ASR (Automatic Speech Recognition) models.
+    """
+    model = None
+    model_lock = Lock()
+    last_activity_time = time.time()
+
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def load_model(self):
+        """
+        Loads the model from the specified path.
+        """
+        pass
+
+    @abstractmethod
+    def transcribe(self,
+                   audio,
+                   task: Union[str, None],
+                   language: Union[str, None],
+                   initial_prompt: Union[str, None],
+                   vad_filter: Union[bool, None],
+                   word_timestamps: Union[bool, None]
+                   ):
+        """
+        Perform transcription on the given audio file.
+        """
+        pass
+
+    @abstractmethod
+    def language_detection(self, audio):
+        """
+        Perform language detection on the given audio file.
+        """
+        pass
+
+    def monitor_idleness(self):
+        """
+        Monitors the idleness of the ASR model and releases the model if it has been idle for too long.
+        """
+        if CONFIG.MODEL_IDLE_TIMEOUT <= 0: return
+        while True:
+            time.sleep(15)
+            if time.time() - self.last_activity_time > CONFIG.MODEL_IDLE_TIMEOUT:
+                with self.model_lock:
+                    self.release_model()
+                break
+
+    def release_model(self):
+        """
+        Unloads the model from memory and clears any cached GPU memory.
+ """ + del self.model + torch.cuda.empty_cache() + gc.collect() + self.model = None + print("Model unloaded due to timeout") diff --git a/app/asr_models/faster_whisper_engine.py b/app/asr_models/faster_whisper_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..f0e577d6761f81f7c505d938305c0a5d49bb0c0b --- /dev/null +++ b/app/asr_models/faster_whisper_engine.py @@ -0,0 +1,98 @@ +import time +from io import StringIO +from threading import Thread +from typing import BinaryIO, Union + +import whisper +from faster_whisper import WhisperModel + +from app.asr_models.asr_model import ASRModel +from app.config import CONFIG +from app.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteVTT + + +class FasterWhisperASR(ASRModel): + + def load_model(self): + + self.model = WhisperModel( + model_size_or_path=CONFIG.MODEL_NAME, + device=CONFIG.DEVICE, + compute_type=CONFIG.MODEL_QUANTIZATION, + download_root=CONFIG.MODEL_PATH + ) + + Thread(target=self.monitor_idleness, daemon=True).start() + + def transcribe( + self, + audio, + task: Union[str, None], + language: Union[str, None], + initial_prompt: Union[str, None], + vad_filter: Union[bool, None], + word_timestamps: Union[bool, None], + output, + ): + + print("faster whisper") + self.last_activity_time = time.time() + + with self.model_lock: + if self.model is None: self.load_model() + + options_dict = {"task": task} + if language: + options_dict["language"] = language + if initial_prompt: + options_dict["initial_prompt"] = initial_prompt + if vad_filter: + options_dict["vad_filter"] = True + if word_timestamps: + options_dict["word_timestamps"] = True + with self.model_lock: + segments = [] + text = "" + segment_generator, info = self.model.transcribe(audio, beam_size=5, **options_dict) + for segment in segment_generator: + segments.append(segment) + text = text + segment.text + result = {"language": options_dict.get("language", info.language), "segments": segments, "text": text} + + output_file = StringIO() + self.write_result(result, output_file, output) + output_file.seek(0) + + return output_file + + def language_detection(self, audio): + + self.last_activity_time = time.time() + + with self.model_lock: + if self.model is None: self.load_model() + + # load audio and pad/trim it to fit 30 seconds + audio = whisper.pad_or_trim(audio) + + # detect the spoken language + with self.model_lock: + segments, info = self.model.transcribe(audio, beam_size=5) + detected_lang_code = info.language + detected_language_confidence = info.language_probability + + return detected_lang_code, detected_language_confidence + + def write_result(self, result: dict, file: BinaryIO, output: Union[str, None]): + if output == "srt": + WriteSRT(ResultWriter).write_result(result, file=file) + elif output == "vtt": + WriteVTT(ResultWriter).write_result(result, file=file) + elif output == "tsv": + WriteTSV(ResultWriter).write_result(result, file=file) + elif output == "json": + WriteJSON(ResultWriter).write_result(result, file=file) + elif output == "txt": + WriteTXT(ResultWriter).write_result(result, file=file) + else: + return "Please select an output method!" 
diff --git a/app/asr_models/openai_whisper_engine.py b/app/asr_models/openai_whisper_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8d844190252e321b48c6f2e4c4a7fbba77f18b5
--- /dev/null
+++ b/app/asr_models/openai_whisper_engine.py
@@ -0,0 +1,96 @@
+import time
+from io import StringIO
+from threading import Thread
+from typing import BinaryIO, Union
+
+import torch
+import whisper
+from whisper.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteVTT
+
+from app.asr_models.asr_model import ASRModel
+from app.config import CONFIG
+
+
+class OpenAIWhisperASR(ASRModel):
+
+    def load_model(self):
+
+        if torch.cuda.is_available():
+            self.model = whisper.load_model(
+                name=CONFIG.MODEL_NAME,
+                download_root=CONFIG.MODEL_PATH
+            ).cuda()
+        else:
+            self.model = whisper.load_model(
+                name=CONFIG.MODEL_NAME,
+                download_root=CONFIG.MODEL_PATH
+            )
+
+        Thread(target=self.monitor_idleness, daemon=True).start()
+
+    def transcribe(
+        self,
+        audio,
+        task: Union[str, None],
+        language: Union[str, None],
+        initial_prompt: Union[str, None],
+        vad_filter: Union[bool, None],
+        word_timestamps: Union[bool, None],
+        output,
+    ):
+        print("whisper")
+        self.last_activity_time = time.time()
+
+        with self.model_lock:
+            if self.model is None: self.load_model()
+
+        options_dict = {"task": task}
+        if language:
+            options_dict["language"] = language
+        if initial_prompt:
+            options_dict["initial_prompt"] = initial_prompt
+        if word_timestamps:
+            options_dict["word_timestamps"] = word_timestamps
+        with self.model_lock:
+            result = self.model.transcribe(audio, **options_dict)
+
+        output_file = StringIO()
+        self.write_result(result, output_file, output)
+        output_file.seek(0)
+
+        return output_file
+
+    def language_detection(self, audio):
+
+        self.last_activity_time = time.time()
+
+        with self.model_lock:
+            if self.model is None: self.load_model()
+
+        # load audio and pad/trim it to fit 30 seconds
+        audio = whisper.pad_or_trim(audio)
+
+        # make log-Mel spectrogram and move to the same device as the model
+        mel = whisper.log_mel_spectrogram(audio, self.model.dims.n_mels).to(self.model.device)
+
+        # detect the spoken language
+        with self.model_lock:
+            _, probs = self.model.detect_language(mel)
+            detected_lang_code = max(probs, key=probs.get)
+
+        return detected_lang_code, probs[detected_lang_code]
+
+    def write_result(self, result: dict, file: BinaryIO, output: Union[str, None]):
+        options = {"max_line_width": 1000, "max_line_count": 10, "highlight_words": False}
+        if output == "srt":
+            WriteSRT(ResultWriter).write_result(result, file=file, options=options)
+        elif output == "vtt":
+            WriteVTT(ResultWriter).write_result(result, file=file, options=options)
+        elif output == "tsv":
+            WriteTSV(ResultWriter).write_result(result, file=file, options=options)
+        elif output == "json":
+            WriteJSON(ResultWriter).write_result(result, file=file, options=options)
+        elif output == "txt":
+            WriteTXT(ResultWriter).write_result(result, file=file, options=options)
+        else:
+            return "Please select an output method!"
diff --git a/app/factory/asr_model_factory.py b/app/factory/asr_model_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f97fb53ae0a5ca20caec09dc40c388efcdb863a
--- /dev/null
+++ b/app/factory/asr_model_factory.py
@@ -0,0 +1,15 @@
+from app.asr_models.asr_model import ASRModel
+from app.asr_models.faster_whisper_engine import FasterWhisperASR
+from app.asr_models.openai_whisper_engine import OpenAIWhisperASR
+from app.config import CONFIG
+
+
+class ASRModelFactory:
+    @staticmethod
+    def create_asr_model() -> ASRModel:
+        if CONFIG.ASR_ENGINE == "openai_whisper":
+            return OpenAIWhisperASR()
+        elif CONFIG.ASR_ENGINE == "faster_whisper":
+            return FasterWhisperASR()
+        else:
+            raise ValueError(f"Unsupported ASR engine: {CONFIG.ASR_ENGINE}")
diff --git a/app/faster_whisper/core.py b/app/faster_whisper/core.py
deleted file mode 100644
index f273309b0fa78ea45d7f2c8e2cb1e380bea52e46..0000000000000000000000000000000000000000
--- a/app/faster_whisper/core.py
+++ /dev/null
@@ -1,126 +0,0 @@
-import gc
-import time
-from io import StringIO
-from threading import Lock, Thread
-from typing import BinaryIO, Union
-
-import torch
-import whisper
-from faster_whisper import WhisperModel
-
-from app.config import CONFIG
-from app.faster_whisper.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteVTT
-
-model = None
-model_lock = Lock()
-last_activity_time = time.time()
-
-
-def monitor_idleness():
-    global model
-    if CONFIG.MODEL_IDLE_TIMEOUT <= 0: return
-    while True:
-        time.sleep(15)
-        if time.time() - last_activity_time > CONFIG.MODEL_IDLE_TIMEOUT:
-            with model_lock:
-                release_model()
-            break
-
-
-def load_model():
-    global model, device, model_quantization
-
-    model = WhisperModel(
-        model_size_or_path=CONFIG.MODEL_NAME,
-        device=CONFIG.DEVICE,
-        compute_type=CONFIG.MODEL_QUANTIZATION,
-        download_root=CONFIG.MODEL_PATH
-    )
-
-    Thread(target=monitor_idleness, daemon=True).start()
-
-
-load_model()
-
-
-def release_model():
-    global model
-    del model
-    torch.cuda.empty_cache()
-    gc.collect()
-    model = None
-    print("Model unloaded due to timeout")
-
-
-def transcribe(
-    audio,
-    task: Union[str, None],
-    language: Union[str, None],
-    initial_prompt: Union[str, None],
-    vad_filter: Union[bool, None],
-    word_timestamps: Union[bool, None],
-    output,
-):
-    global last_activity_time
-    last_activity_time = time.time()
-
-    with model_lock:
-        if model is None: load_model()
-
-    options_dict = {"task": task}
-    if language:
-        options_dict["language"] = language
-    if initial_prompt:
-        options_dict["initial_prompt"] = initial_prompt
-    if vad_filter:
-        options_dict["vad_filter"] = True
-    if word_timestamps:
-        options_dict["word_timestamps"] = True
-    with model_lock:
-        segments = []
-        text = ""
-        segment_generator, info = model.transcribe(audio, beam_size=5, **options_dict)
-        for segment in segment_generator:
-            segments.append(segment)
-            text = text + segment.text
-        result = {"language": options_dict.get("language", info.language), "segments": segments, "text": text}
-
-    output_file = StringIO()
-    write_result(result, output_file, output)
-    output_file.seek(0)
-
-    return output_file
-
-
-def language_detection(audio):
-    global last_activity_time
-    last_activity_time = time.time()
-
-    with model_lock:
-        if model is None: load_model()
-
-    # load audio and pad/trim it to fit 30 seconds
-    audio = whisper.pad_or_trim(audio)
-
-    # detect the spoken language
-    with model_lock:
-        segments, info = model.transcribe(audio, beam_size=5)
-        detected_lang_code = info.language
-        detected_language_confidence = info.language_probability
-
-    return detected_lang_code, detected_language_confidence
-
-
-def write_result(result: dict, file: BinaryIO, output: Union[str, None]):
-    if output == "srt":
-        WriteSRT(ResultWriter).write_result(result, file=file)
-    elif output == "vtt":
-        WriteVTT(ResultWriter).write_result(result, file=file)
-    elif output == "tsv":
-        WriteTSV(ResultWriter).write_result(result, file=file)
-    elif output == "json":
-        WriteJSON(ResultWriter).write_result(result, file=file)
-    elif output == "txt":
-        WriteTXT(ResultWriter).write_result(result, file=file)
-    else:
-        return "Please select an output method!"
diff --git a/app/openai_whisper/core.py b/app/openai_whisper/core.py
deleted file mode 100644
index d8e7d5f0a0509f93fab789e77fd8c9141e162d50..0000000000000000000000000000000000000000
--- a/app/openai_whisper/core.py
+++ /dev/null
@@ -1,124 +0,0 @@
-import gc
-import time
-from io import StringIO
-from threading import Lock, Thread
-from typing import BinaryIO, Union
-
-import torch
-import whisper
-from whisper.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteVTT
-
-from app.config import CONFIG
-
-model = None
-model_lock = Lock()
-last_activity_time = time.time()
-
-
-def monitor_idleness():
-    global model
-    if CONFIG.MODEL_IDLE_TIMEOUT <= 0: return
-    while True:
-        time.sleep(15)
-        if time.time() - last_activity_time > CONFIG.MODEL_IDLE_TIMEOUT:
-            with model_lock:
-                release_model()
-            break
-
-
-def load_model():
-    global model
-
-    if torch.cuda.is_available():
-        model = whisper.load_model(
-            name=CONFIG.MODEL_NAME,
-            download_root=CONFIG.MODEL_PATH
-        ).cuda()
-    else:
-        model = whisper.load_model(
-            name=CONFIG.MODEL_NAME,
-            download_root=CONFIG.MODEL_PATH
-        )
-
-    Thread(target=monitor_idleness, daemon=True).start()
-
-
-load_model()
-
-
-def release_model():
-    global model
-    del model
-    torch.cuda.empty_cache()
-    gc.collect()
-    model = None
-    print("Model unloaded due to timeout")
-
-
-def transcribe(
-    audio,
-    task: Union[str, None],
-    language: Union[str, None],
-    initial_prompt: Union[str, None],
-    vad_filter: Union[bool, None],
-    word_timestamps: Union[bool, None],
-    output,
-):
-    global last_activity_time
-    last_activity_time = time.time()
-
-    with model_lock:
-        if model is None: load_model()
-
-    options_dict = {"task": task}
-    if language:
-        options_dict["language"] = language
-    if initial_prompt:
-        options_dict["initial_prompt"] = initial_prompt
-    if word_timestamps:
-        options_dict["word_timestamps"] = word_timestamps
-    with model_lock:
-        result = model.transcribe(audio, **options_dict)
-
-    output_file = StringIO()
-    write_result(result, output_file, output)
-    output_file.seek(0)
-
-    return output_file
-
-
-def language_detection(audio):
-    global last_activity_time
-    last_activity_time = time.time()
-
-    with model_lock:
-        if model is None: load_model()
-
-    # load audio and pad/trim it to fit 30 seconds
-    audio = whisper.pad_or_trim(audio)
-
-    # make log-Mel spectrogram and move to the same device as the model
-    mel = whisper.log_mel_spectrogram(audio, model.dims.n_mels).to(model.device)
-
-    # detect the spoken language
-    with model_lock:
-        _, probs = model.detect_language(mel)
-        detected_lang_code = max(probs, key=probs.get)
-
-    return detected_lang_code, probs[max(probs)]
-
-
-def write_result(result: dict, file: BinaryIO, output: Union[str, None]):
-    options = {"max_line_width": 1000, "max_line_count": 10, "highlight_words": False}
-    if output == "srt":
-        WriteSRT(ResultWriter).write_result(result, file=file, options=options)
-    elif output == "vtt":
-        WriteVTT(ResultWriter).write_result(result, file=file, options=options)
-    elif output == "tsv":
-        WriteTSV(ResultWriter).write_result(result, file=file, options=options)
-    elif output == "json":
-        WriteJSON(ResultWriter).write_result(result, file=file, options=options)
-    elif output == "txt":
-        WriteTXT(ResultWriter).write_result(result, file=file, options=options)
-    else:
-        return "Please select an output method!"
diff --git a/app/faster_whisper/utils.py b/app/utils.py
similarity index 67%
rename from app/faster_whisper/utils.py
rename to app/utils.py
index 034c63937a304372999f0b9b60793e04208e5d5e..0b51281a0793aef2063ad475fd8a42b05ab94481 100644
--- a/app/faster_whisper/utils.py
+++ b/app/utils.py
@@ -1,9 +1,13 @@
 import json
 import os
-from typing import TextIO
+from typing import TextIO, BinaryIO
 
+import ffmpeg
+import numpy as np
 from faster_whisper.utils import format_timestamp
+
+from app.config import CONFIG
 
 
 class ResultWriter:
     extension: str
@@ -85,3 +89,36 @@ class WriteJSON(ResultWriter):
 
     def write_result(self, result: dict, file: TextIO):
         json.dump(result, file)
+
+
+def load_audio(file: BinaryIO, encode=True, sr: int = CONFIG.SAMPLE_RATE):
+    """
+    Open an audio file object and read as mono waveform, resampling as necessary.
+    Modified from https://github.com/openai/whisper/blob/main/whisper/audio.py to accept a file object
+    Parameters
+    ----------
+    file: BinaryIO
+        The audio file like object
+    encode: Boolean
+        If true, encode audio stream to WAV before sending to whisper
+    sr: int
+        The sample rate to resample the audio if necessary
+    Returns
+    -------
+    A NumPy array containing the audio waveform, in float32 dtype.
+    """
+    if encode:
+        try:
+            # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
+            # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
+            out, _ = (
+                ffmpeg.input("pipe:", threads=0)
+                .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
+                .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True, input=file.read())
+            )
+        except ffmpeg.Error as e:
+            raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+    else:
+        out = file.read()
+
+    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
diff --git a/app/webservice.py b/app/webservice.py
index 0e0dd2d21f6ccc3dc69ec2c4eb981e2846021160..7158645a467702a07baf8b997ce825712863ccf1 100644
--- a/app/webservice.py
+++ b/app/webservice.py
@@ -1,12 +1,10 @@
 import importlib.metadata
 import os
 from os import path
-from typing import Annotated, BinaryIO, Optional, Union
+from typing import Annotated, Optional, Union
 from urllib.parse import quote
 
 import click
-import ffmpeg
-import numpy as np
 import uvicorn
 from fastapi import FastAPI, File, Query, UploadFile, applications
 from fastapi.openapi.docs import get_swagger_ui_html
@@ -14,13 +12,13 @@ from fastapi.responses import RedirectResponse, StreamingResponse
 from fastapi.staticfiles import StaticFiles
 from whisper import tokenizer
 
-ASR_ENGINE = os.getenv("ASR_ENGINE", "openai_whisper")
-if ASR_ENGINE == "faster_whisper":
-    from app.faster_whisper.core import language_detection, transcribe
-else:
-    from app.openai_whisper.core import language_detection, transcribe
+from app.config import CONFIG
+from app.factory.asr_model_factory import ASRModelFactory
+from app.utils import load_audio
+
+asr_model = ASRModelFactory.create_asr_model()
+asr_model.load_model()
 
-SAMPLE_RATE = 16000
 LANGUAGE_CODES = sorted(tokenizer.LANGUAGES.keys())
 
 projectMetadata = importlib.metadata.metadata("whisper-asr-webservice")
@@ -67,20 +65,20 @@ async def asr(
         bool | None,
         Query(
             description="Enable the voice activity detection (VAD) to filter out parts of the audio without speech",
-            include_in_schema=(True if ASR_ENGINE == "faster_whisper" else False),
+            include_in_schema=(True if CONFIG.ASR_ENGINE == "faster_whisper" else False),
         ),
     ] = False,
     word_timestamps: bool = Query(default=False, description="Word level timestamps"),
     output: Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"]),
 ):
-    result = transcribe(
+    result = asr_model.transcribe(
         load_audio(audio_file.file, encode), task, language, initial_prompt, vad_filter, word_timestamps, output
     )
     return StreamingResponse(
         result,
         media_type="text/plain",
         headers={
-            "Asr-Engine": ASR_ENGINE,
+            "Asr-Engine": CONFIG.ASR_ENGINE,
             "Content-Disposition": f'attachment; filename="{quote(audio_file.filename)}.{output}"',
         },
     )
@@ -91,44 +89,11 @@ async def detect_language(
     audio_file: UploadFile = File(...),  # noqa: B008
     encode: bool = Query(default=True, description="Encode audio first through FFmpeg"),
 ):
-    detected_lang_code, confidence = language_detection(load_audio(audio_file.file, encode))
+    detected_lang_code, confidence = asr_model.language_detection(load_audio(audio_file.file, encode))
     return {"detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code": detected_lang_code,
             "confidence": confidence}
 
 
-def load_audio(file: BinaryIO, encode=True, sr: int = SAMPLE_RATE):
-    """
-    Open an audio file object and read as mono waveform, resampling as necessary.
-    Modified from https://github.com/openai/whisper/blob/main/whisper/audio.py to accept a file object
-    Parameters
-    ----------
-    file: BinaryIO
-        The audio file like object
-    encode: Boolean
-        If true, encode audio stream to WAV before sending to whisper
-    sr: int
-        The sample rate to resample the audio if necessary
-    Returns
-    -------
-    A NumPy array containing the audio waveform, in float32 dtype.
-    """
-    if encode:
-        try:
-            # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
-            # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
-            out, _ = (
-                ffmpeg.input("pipe:", threads=0)
-                .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
-                .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True, input=file.read())
-            )
-        except ffmpeg.Error as e:
-            raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
-    else:
-        out = file.read()
-
-    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
-
-
 @click.command()
 @click.option(
     "-h",
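For reference, a minimal usage sketch of the refactored API introduced by this diff: the factory reads CONFIG.ASR_ENGINE to pick an engine class, and the webservice then calls load_model(), transcribe(), and language_detection() on the returned instance. The audio path "audio.mp3" and the environment setup are placeholders, not part of the patch.

# Illustrative sketch only; assumes ASR_ENGINE has been set in the app config.
from app.factory.asr_model_factory import ASRModelFactory
from app.utils import load_audio

asr_model = ASRModelFactory.create_asr_model()  # OpenAIWhisperASR or FasterWhisperASR
asr_model.load_model()

with open("audio.mp3", "rb") as f:  # "audio.mp3" is a placeholder path
    result = asr_model.transcribe(
        load_audio(f, encode=True),
        task="transcribe", language=None, initial_prompt=None,
        vad_filter=False, word_timestamps=False, output="txt",
    )
print(result.read())  # transcribe() returns a StringIO with the formatted output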