diff --git a/CHANGELOG.md b/CHANGELOG.md
index 481e1fc98fca90092797330ed71c738ecff28fe2..b2f280ec063df786fc3f5abc70c5a189c5004b55 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,10 @@ Unreleased
- Added detection confidence to language detection endpoint
- Set mel generation to adjust n_dims automatically to match the loaded model
+### Added
+
+- Refactored classes, added comments, implemented abstract methods, and added a factory method for engine selection
+
[1.6.0] (2024-10-06)
--------------------
diff --git a/app/asr_models/asr_model.py b/app/asr_models/asr_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc2205083ed48e5e88267aeb3afa1de2d2e61051
--- /dev/null
+++ b/app/asr_models/asr_model.py
@@ -0,0 +1,71 @@
+import gc
+import time
+from abc import ABC, abstractmethod
+from threading import Lock
+from typing import Union
+
+import torch
+
+from app.config import CONFIG
+
+
+class ASRModel(ABC):
+ """
+ Abstract base class for ASR (Automatic Speech Recognition) models.
+ """
+ model = None
+ model_lock = Lock()
+ last_activity_time = time.time()
+
+ def __init__(self):
+ pass
+
+ @abstractmethod
+ def load_model(self):
+ """
+ Loads the model from the specified path.
+ """
+ pass
+
+ @abstractmethod
+ def transcribe(self,
+ audio,
+ task: Union[str, None],
+ language: Union[str, None],
+ initial_prompt: Union[str, None],
+ vad_filter: Union[bool, None],
+ word_timestamps: Union[bool, None]
+ ):
+ """
+ Perform transcription on the given audio file.
+ """
+ pass
+
+ @abstractmethod
+ def language_detection(self, audio):
+ """
+ Perform language detection on the given audio file.
+ """
+ pass
+
+ def monitor_idleness(self):
+ """
+ Monitors the idleness of the ASR model and releases the model if it has been idle for too long.
+ """
+ if CONFIG.MODEL_IDLE_TIMEOUT <= 0: return
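+ # Poll every 15 seconds and release the model once it has been idle longer than MODEL_IDLE_TIMEOUT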
+ while True:
+ time.sleep(15)
+ if time.time() - self.last_activity_time > CONFIG.MODEL_IDLE_TIMEOUT:
+ with self.model_lock:
+ self.release_model()
+ break
+
+ def release_model(self):
+ """
+ Unloads the model from memory and clears any cached GPU memory.
+ """
+ del self.model
+ torch.cuda.empty_cache()
+ gc.collect()
+ self.model = None
+ print("Model unloaded due to timeout")
diff --git a/app/asr_models/faster_whisper_engine.py b/app/asr_models/faster_whisper_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0e577d6761f81f7c505d938305c0a5d49bb0c0b
--- /dev/null
+++ b/app/asr_models/faster_whisper_engine.py
@@ -0,0 +1,98 @@
+import time
+from io import StringIO
+from threading import Thread
+from typing import BinaryIO, Union
+
+import whisper
+from faster_whisper import WhisperModel
+
+from app.asr_models.asr_model import ASRModel
+from app.config import CONFIG
+from app.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteVTT
+
+
+class FasterWhisperASR(ASRModel):
+
+ def load_model(self):
+
+ self.model = WhisperModel(
+ model_size_or_path=CONFIG.MODEL_NAME,
+ device=CONFIG.DEVICE,
+ compute_type=CONFIG.MODEL_QUANTIZATION,
+ download_root=CONFIG.MODEL_PATH
+ )
+
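+ # Start a daemon thread that unloads the model after MODEL_IDLE_TIMEOUT seconds of inactivity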
+ Thread(target=self.monitor_idleness, daemon=True).start()
+
+ def transcribe(
+ self,
+ audio,
+ task: Union[str, None],
+ language: Union[str, None],
+ initial_prompt: Union[str, None],
+ vad_filter: Union[bool, None],
+ word_timestamps: Union[bool, None],
+ output,
+ ):
+
+ print("faster whisper")
+ self.last_activity_time = time.time()
+
+ with self.model_lock:
+ if self.model is None: self.load_model()
+
+ options_dict = {"task": task}
+ if language:
+ options_dict["language"] = language
+ if initial_prompt:
+ options_dict["initial_prompt"] = initial_prompt
+ if vad_filter:
+ options_dict["vad_filter"] = True
+ if word_timestamps:
+ options_dict["word_timestamps"] = True
+ with self.model_lock:
+ segments = []
+ text = ""
+ segment_generator, info = self.model.transcribe(audio, beam_size=5, **options_dict)
+ for segment in segment_generator:
+ segments.append(segment)
+ text = text + segment.text
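+ # Fall back to the language detected by faster-whisper when no language was requested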
+ result = {"language": options_dict.get("language", info.language), "segments": segments, "text": text}
+
+ output_file = StringIO()
+ self.write_result(result, output_file, output)
+ output_file.seek(0)
+
+ return output_file
+
+ def language_detection(self, audio):
+
+ self.last_activity_time = time.time()
+
+ with self.model_lock:
+ if self.model is None: self.load_model()
+
+ # load audio and pad/trim it to fit 30 seconds
+ audio = whisper.pad_or_trim(audio)
+
+ # detect the spoken language
+ with self.model_lock:
+ segments, info = self.model.transcribe(audio, beam_size=5)
+ detected_lang_code = info.language
+ detected_language_confidence = info.language_probability
+
+ return detected_lang_code, detected_language_confidence
+
+ def write_result(self, result: dict, file: BinaryIO, output: Union[str, None]):
+ if output == "srt":
+ WriteSRT(ResultWriter).write_result(result, file=file)
+ elif output == "vtt":
+ WriteVTT(ResultWriter).write_result(result, file=file)
+ elif output == "tsv":
+ WriteTSV(ResultWriter).write_result(result, file=file)
+ elif output == "json":
+ WriteJSON(ResultWriter).write_result(result, file=file)
+ elif output == "txt":
+ WriteTXT(ResultWriter).write_result(result, file=file)
+ else:
+ return "Please select an output method!"
diff --git a/app/asr_models/openai_whisper_engine.py b/app/asr_models/openai_whisper_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8d844190252e321b48c6f2e4c4a7fbba77f18b5
--- /dev/null
+++ b/app/asr_models/openai_whisper_engine.py
@@ -0,0 +1,96 @@
+import time
+from io import StringIO
+from threading import Thread
+from typing import BinaryIO, Union
+
+import torch
+import whisper
+from whisper.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteVTT
+
+from app.asr_models.asr_model import ASRModel
+from app.config import CONFIG
+
+
+class OpenAIWhisperASR(ASRModel):
+
+ def load_model(self):
+
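+ # Load the model onto the GPU when CUDA is available, otherwise keep it on the CPU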
+ if torch.cuda.is_available():
+ self.model = whisper.load_model(
+ name=CONFIG.MODEL_NAME,
+ download_root=CONFIG.MODEL_PATH
+ ).cuda()
+ else:
+ self.model = whisper.load_model(
+ name=CONFIG.MODEL_NAME,
+ download_root=CONFIG.MODEL_PATH
+ )
+
+ Thread(target=self.monitor_idleness, daemon=True).start()
+
+ def transcribe(
+ self,
+ audio,
+ task: Union[str, None],
+ language: Union[str, None],
+ initial_prompt: Union[str, None],
+ vad_filter: Union[bool, None],
+ word_timestamps: Union[bool, None],
+ output,
+ ):
+ print("whisper")
+ self.last_activity_time = time.time()
+
+ with self.model_lock:
+ if self.model is None: self.load_model()
+
+ options_dict = {"task": task}
+ if language:
+ options_dict["language"] = language
+ if initial_prompt:
+ options_dict["initial_prompt"] = initial_prompt
+ if word_timestamps:
+ options_dict["word_timestamps"] = word_timestamps
+ with self.model_lock:
+ result = self.model.transcribe(audio, **options_dict)
+
+ output_file = StringIO()
+ self.write_result(result, output_file, output)
+ output_file.seek(0)
+
+ return output_file
+
+ def language_detection(self, audio):
+
+ self.last_activity_time = time.time()
+
+ with self.model_lock:
+ if self.model is None: self.load_model()
+
+ # load audio and pad/trim it to fit 30 seconds
+ audio = whisper.pad_or_trim(audio)
+
+ # make log-Mel spectrogram and move to the same device as the model
+ mel = whisper.log_mel_spectrogram(audio, self.model.dims.n_mels).to(self.model.device)
+
+ # detect the spoken language
+ with self.model_lock:
+ _, probs = self.model.detect_language(mel)
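+ # pick the language with the highest probability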
+ detected_lang_code = max(probs, key=probs.get)
+
+ return detected_lang_code, probs[detected_lang_code]
+
+ def write_result(self, result: dict, file: BinaryIO, output: Union[str, None]):
+ options = {"max_line_width": 1000, "max_line_count": 10, "highlight_words": False}
+ if output == "srt":
+ WriteSRT(ResultWriter).write_result(result, file=file, options=options)
+ elif output == "vtt":
+ WriteVTT(ResultWriter).write_result(result, file=file, options=options)
+ elif output == "tsv":
+ WriteTSV(ResultWriter).write_result(result, file=file, options=options)
+ elif output == "json":
+ WriteJSON(ResultWriter).write_result(result, file=file, options=options)
+ elif output == "txt":
+ WriteTXT(ResultWriter).write_result(result, file=file, options=options)
+ else:
+ return "Please select an output method!"
diff --git a/app/factory/asr_model_factory.py b/app/factory/asr_model_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f97fb53ae0a5ca20caec09dc40c388efcdb863a
--- /dev/null
+++ b/app/factory/asr_model_factory.py
@@ -0,0 +1,15 @@
+from app.asr_models.asr_model import ASRModel
+from app.asr_models.faster_whisper_engine import FasterWhisperASR
+from app.asr_models.openai_whisper_engine import OpenAIWhisperASR
+from app.config import CONFIG
+
+
+class ASRModelFactory:
+ @staticmethod
+ def create_asr_model() -> ASRModel:
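+ """Instantiate the ASR engine selected by CONFIG.ASR_ENGINE."""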
+ if CONFIG.ASR_ENGINE == "openai_whisper":
+ return OpenAIWhisperASR()
+ elif CONFIG.ASR_ENGINE == "faster_whisper":
+ return FasterWhisperASR()
+ else:
+ raise ValueError(f"Unsupported ASR engine: {CONFIG.ASR_ENGINE}")
diff --git a/app/faster_whisper/core.py b/app/faster_whisper/core.py
deleted file mode 100644
index f273309b0fa78ea45d7f2c8e2cb1e380bea52e46..0000000000000000000000000000000000000000
--- a/app/faster_whisper/core.py
+++ /dev/null
@@ -1,126 +0,0 @@
-import gc
-import time
-from io import StringIO
-from threading import Lock, Thread
-from typing import BinaryIO, Union
-
-import torch
-import whisper
-from faster_whisper import WhisperModel
-
-from app.config import CONFIG
-from app.faster_whisper.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteVTT
-
-model = None
-model_lock = Lock()
-last_activity_time = time.time()
-
-
-def monitor_idleness():
- global model
- if CONFIG.MODEL_IDLE_TIMEOUT <= 0: return
- while True:
- time.sleep(15)
- if time.time() - last_activity_time > CONFIG.MODEL_IDLE_TIMEOUT:
- with model_lock:
- release_model()
- break
-
-
-def load_model():
- global model, device, model_quantization
-
- model = WhisperModel(
- model_size_or_path=CONFIG.MODEL_NAME,
- device=CONFIG.DEVICE,
- compute_type=CONFIG.MODEL_QUANTIZATION,
- download_root=CONFIG.MODEL_PATH
- )
-
- Thread(target=monitor_idleness, daemon=True).start()
-
-
-load_model()
-
-
-def release_model():
- global model
- del model
- torch.cuda.empty_cache()
- gc.collect()
- model = None
- print("Model unloaded due to timeout")
-
-
-def transcribe(
- audio,
- task: Union[str, None],
- language: Union[str, None],
- initial_prompt: Union[str, None],
- vad_filter: Union[bool, None],
- word_timestamps: Union[bool, None],
- output,
-):
- global last_activity_time
- last_activity_time = time.time()
-
- with model_lock:
- if model is None: load_model()
-
- options_dict = {"task": task}
- if language:
- options_dict["language"] = language
- if initial_prompt:
- options_dict["initial_prompt"] = initial_prompt
- if vad_filter:
- options_dict["vad_filter"] = True
- if word_timestamps:
- options_dict["word_timestamps"] = True
- with model_lock:
- segments = []
- text = ""
- segment_generator, info = model.transcribe(audio, beam_size=5, **options_dict)
- for segment in segment_generator:
- segments.append(segment)
- text = text + segment.text
- result = {"language": options_dict.get("language", info.language), "segments": segments, "text": text}
-
- output_file = StringIO()
- write_result(result, output_file, output)
- output_file.seek(0)
-
- return output_file
-
-
-def language_detection(audio):
- global last_activity_time
- last_activity_time = time.time()
-
- with model_lock:
- if model is None: load_model()
-
- # load audio and pad/trim it to fit 30 seconds
- audio = whisper.pad_or_trim(audio)
-
- # detect the spoken language
- with model_lock:
- segments, info = model.transcribe(audio, beam_size=5)
- detected_lang_code = info.language
- detected_language_confidence = info.language_probability
-
- return detected_lang_code, detected_language_confidence
-
-
-def write_result(result: dict, file: BinaryIO, output: Union[str, None]):
- if output == "srt":
- WriteSRT(ResultWriter).write_result(result, file=file)
- elif output == "vtt":
- WriteVTT(ResultWriter).write_result(result, file=file)
- elif output == "tsv":
- WriteTSV(ResultWriter).write_result(result, file=file)
- elif output == "json":
- WriteJSON(ResultWriter).write_result(result, file=file)
- elif output == "txt":
- WriteTXT(ResultWriter).write_result(result, file=file)
- else:
- return "Please select an output method!"
diff --git a/app/openai_whisper/core.py b/app/openai_whisper/core.py
deleted file mode 100644
index d8e7d5f0a0509f93fab789e77fd8c9141e162d50..0000000000000000000000000000000000000000
--- a/app/openai_whisper/core.py
+++ /dev/null
@@ -1,124 +0,0 @@
-import gc
-import time
-from io import StringIO
-from threading import Lock, Thread
-from typing import BinaryIO, Union
-
-import torch
-import whisper
-from whisper.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteVTT
-
-from app.config import CONFIG
-
-model = None
-model_lock = Lock()
-last_activity_time = time.time()
-
-
-def monitor_idleness():
- global model
- if CONFIG.MODEL_IDLE_TIMEOUT <= 0: return
- while True:
- time.sleep(15)
- if time.time() - last_activity_time > CONFIG.MODEL_IDLE_TIMEOUT:
- with model_lock:
- release_model()
- break
-
-
-def load_model():
- global model
-
- if torch.cuda.is_available():
- model = whisper.load_model(
- name=CONFIG.MODEL_NAME,
- download_root=CONFIG.MODEL_PATH
- ).cuda()
- else:
- model = whisper.load_model(
- name=CONFIG.MODEL_NAME,
- download_root=CONFIG.MODEL_PATH
- )
-
- Thread(target=monitor_idleness, daemon=True).start()
-
-
-load_model()
-
-
-def release_model():
- global model
- del model
- torch.cuda.empty_cache()
- gc.collect()
- model = None
- print("Model unloaded due to timeout")
-
-
-def transcribe(
- audio,
- task: Union[str, None],
- language: Union[str, None],
- initial_prompt: Union[str, None],
- vad_filter: Union[bool, None],
- word_timestamps: Union[bool, None],
- output,
-):
- global last_activity_time
- last_activity_time = time.time()
-
- with model_lock:
- if model is None: load_model()
-
- options_dict = {"task": task}
- if language:
- options_dict["language"] = language
- if initial_prompt:
- options_dict["initial_prompt"] = initial_prompt
- if word_timestamps:
- options_dict["word_timestamps"] = word_timestamps
- with model_lock:
- result = model.transcribe(audio, **options_dict)
-
- output_file = StringIO()
- write_result(result, output_file, output)
- output_file.seek(0)
-
- return output_file
-
-
-def language_detection(audio):
- global last_activity_time
- last_activity_time = time.time()
-
- with model_lock:
- if model is None: load_model()
-
- # load audio and pad/trim it to fit 30 seconds
- audio = whisper.pad_or_trim(audio)
-
- # make log-Mel spectrogram and move to the same device as the model
- mel = whisper.log_mel_spectrogram(audio, model.dims.n_mels).to(model.device)
-
- # detect the spoken language
- with model_lock:
- _, probs = model.detect_language(mel)
- detected_lang_code = max(probs, key=probs.get)
-
- return detected_lang_code, probs[max(probs)]
-
-
-def write_result(result: dict, file: BinaryIO, output: Union[str, None]):
- options = {"max_line_width": 1000, "max_line_count": 10, "highlight_words": False}
- if output == "srt":
- WriteSRT(ResultWriter).write_result(result, file=file, options=options)
- elif output == "vtt":
- WriteVTT(ResultWriter).write_result(result, file=file, options=options)
- elif output == "tsv":
- WriteTSV(ResultWriter).write_result(result, file=file, options=options)
- elif output == "json":
- WriteJSON(ResultWriter).write_result(result, file=file, options=options)
- elif output == "txt":
- WriteTXT(ResultWriter).write_result(result, file=file, options=options)
- else:
- return "Please select an output method!"
diff --git a/app/faster_whisper/utils.py b/app/utils.py
similarity index 67%
rename from app/faster_whisper/utils.py
rename to app/utils.py
index 034c63937a304372999f0b9b60793e04208e5d5e..0b51281a0793aef2063ad475fd8a42b05ab94481 100644
--- a/app/faster_whisper/utils.py
+++ b/app/utils.py
@@ -1,9 +1,13 @@
import json
import os
-from typing import TextIO
+from typing import TextIO, BinaryIO
+import ffmpeg
+import numpy as np
from faster_whisper.utils import format_timestamp
+from app.config import CONFIG
+
class ResultWriter:
extension: str
@@ -85,3 +89,36 @@ class WriteJSON(ResultWriter):
def write_result(self, result: dict, file: TextIO):
json.dump(result, file)
+
+
+def load_audio(file: BinaryIO, encode=True, sr: int = CONFIG.SAMPLE_RATE):
+ """
+ Open an audio file object and read as mono waveform, resampling as necessary.
+ Modified from https://github.com/openai/whisper/blob/main/whisper/audio.py to accept a file object
+ Parameters
+ ----------
+ file: BinaryIO
+ The audio file like object
+ encode: Boolean
+ If true, encode audio stream to WAV before sending to whisper
+ sr: int
+ The sample rate to resample the audio if necessary
+ Returns
+ -------
+ A NumPy array containing the audio waveform, in float32 dtype.
+ """
+ if encode:
+ try:
+ # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
+ # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
+ out, _ = (
+ ffmpeg.input("pipe:", threads=0)
+ .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
+ .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True, input=file.read())
+ )
+ except ffmpeg.Error as e:
+ raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+ else:
+ out = file.read()
+
+ return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
diff --git a/app/webservice.py b/app/webservice.py
index 0e0dd2d21f6ccc3dc69ec2c4eb981e2846021160..7158645a467702a07baf8b997ce825712863ccf1 100644
--- a/app/webservice.py
+++ b/app/webservice.py
@@ -1,12 +1,10 @@
import importlib.metadata
import os
from os import path
-from typing import Annotated, BinaryIO, Optional, Union
+from typing import Annotated, Optional, Union
from urllib.parse import quote
import click
-import ffmpeg
-import numpy as np
import uvicorn
from fastapi import FastAPI, File, Query, UploadFile, applications
from fastapi.openapi.docs import get_swagger_ui_html
@@ -14,13 +12,13 @@ from fastapi.responses import RedirectResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from whisper import tokenizer
-ASR_ENGINE = os.getenv("ASR_ENGINE", "openai_whisper")
-if ASR_ENGINE == "faster_whisper":
- from app.faster_whisper.core import language_detection, transcribe
-else:
- from app.openai_whisper.core import language_detection, transcribe
+from app.config import CONFIG
+from app.factory.asr_model_factory import ASRModelFactory
+from app.utils import load_audio
+
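+# Create the configured ASR engine and load its model at startup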
+asr_model = ASRModelFactory.create_asr_model()
+asr_model.load_model()
-SAMPLE_RATE = 16000
LANGUAGE_CODES = sorted(tokenizer.LANGUAGES.keys())
projectMetadata = importlib.metadata.metadata("whisper-asr-webservice")
@@ -67,20 +65,20 @@ async def asr(
bool | None,
Query(
description="Enable the voice activity detection (VAD) to filter out parts of the audio without speech",
- include_in_schema=(True if ASR_ENGINE == "faster_whisper" else False),
+ include_in_schema=(True if CONFIG.ASR_ENGINE == "faster_whisper" else False),
),
] = False,
word_timestamps: bool = Query(default=False, description="Word level timestamps"),
output: Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"]),
):
- result = transcribe(
+ result = asr_model.transcribe(
load_audio(audio_file.file, encode), task, language, initial_prompt, vad_filter, word_timestamps, output
)
return StreamingResponse(
result,
media_type="text/plain",
headers={
- "Asr-Engine": ASR_ENGINE,
+ "Asr-Engine": CONFIG.ASR_ENGINE,
"Content-Disposition": f'attachment; filename="{quote(audio_file.filename)}.{output}"',
},
)
@@ -91,44 +89,11 @@ async def detect_language(
audio_file: UploadFile = File(...), # noqa: B008
encode: bool = Query(default=True, description="Encode audio first through FFmpeg"),
):
- detected_lang_code, confidence = language_detection(load_audio(audio_file.file, encode))
+ detected_lang_code, confidence = asr_model.language_detection(load_audio(audio_file.file, encode))
return {"detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code": detected_lang_code,
"confidence": confidence}
-def load_audio(file: BinaryIO, encode=True, sr: int = SAMPLE_RATE):
- """
- Open an audio file object and read as mono waveform, resampling as necessary.
- Modified from https://github.com/openai/whisper/blob/main/whisper/audio.py to accept a file object
- Parameters
- ----------
- file: BinaryIO
- The audio file like object
- encode: Boolean
- If true, encode audio stream to WAV before sending to whisper
- sr: int
- The sample rate to resample the audio if necessary
- Returns
- -------
- A NumPy array containing the audio waveform, in float32 dtype.
- """
- if encode:
- try:
- # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
- # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
- out, _ = (
- ffmpeg.input("pipe:", threads=0)
- .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
- .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True, input=file.read())
- )
- except ffmpeg.Error as e:
- raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
- else:
- out = file.read()
-
- return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
-
-
@click.command()
@click.option(
"-h",