Commit f6507969 authored by Ahmet Öner
Refactor classes, add comments, implement abstract methods, and add a factory method for engine selection
parent 53779e92
CHANGELOG.md

@@ -10,6 +10,10 @@ Unreleased
- Added detection confidence to language detection endpoint
- Set mel generation to adjust n_dims automatically to match the loaded model
### Added
- Refactored classes, added comments, implemented abstract methods, and added a factory method for engine selection
[1.6.0] (2024-10-06)
--------------------
app/asr_models/asr_model.py (new file)

import gc
import time
from abc import ABC, abstractmethod
from threading import Lock
from typing import Union

import torch

from app.config import CONFIG


class ASRModel(ABC):
    """
    Abstract base class for ASR (Automatic Speech Recognition) models.
    """

    model = None
    model_lock = Lock()
    last_activity_time = time.time()

    def __init__(self):
        pass

    @abstractmethod
    def load_model(self):
        """
        Loads the model from the specified path.
        """
        pass

    @abstractmethod
    def transcribe(self,
                   audio,
                   task: Union[str, None],
                   language: Union[str, None],
                   initial_prompt: Union[str, None],
                   vad_filter: Union[bool, None],
                   word_timestamps: Union[bool, None],
                   output: Union[str, None]
                   ):
        """
        Perform transcription on the given audio file.
        """
        pass

    @abstractmethod
    def language_detection(self, audio):
        """
        Perform language detection on the given audio file.
        """
        pass

    def monitor_idleness(self):
        """
        Monitors the idleness of the ASR model and releases the model if it has been idle for too long.
        """
        if CONFIG.MODEL_IDLE_TIMEOUT <= 0:
            return
        while True:
            time.sleep(15)
            if time.time() - self.last_activity_time > CONFIG.MODEL_IDLE_TIMEOUT:
                with self.model_lock:
                    self.release_model()
                    break

    def release_model(self):
        """
        Unloads the model from memory and clears any cached GPU memory.
        """
        del self.model
        torch.cuda.empty_cache()
        gc.collect()
        self.model = None
        print("Model unloaded due to timeout")
app/asr_models/faster_whisper_engine.py (new file)

import time
from io import StringIO
from threading import Thread
from typing import BinaryIO, Union

import whisper
from faster_whisper import WhisperModel

from app.asr_models.asr_model import ASRModel
from app.config import CONFIG
from app.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteVTT


class FasterWhisperASR(ASRModel):

    def load_model(self):
        self.model = WhisperModel(
            model_size_or_path=CONFIG.MODEL_NAME,
            device=CONFIG.DEVICE,
            compute_type=CONFIG.MODEL_QUANTIZATION,
            download_root=CONFIG.MODEL_PATH
        )
        Thread(target=self.monitor_idleness, daemon=True).start()

    def transcribe(
        self,
        audio,
        task: Union[str, None],
        language: Union[str, None],
        initial_prompt: Union[str, None],
        vad_filter: Union[bool, None],
        word_timestamps: Union[bool, None],
        output,
    ):
        print("faster whisper")
        self.last_activity_time = time.time()
        with self.model_lock:
            if self.model is None:
                self.load_model()
        options_dict = {"task": task}
        if language:
            options_dict["language"] = language
        if initial_prompt:
            options_dict["initial_prompt"] = initial_prompt
        if vad_filter:
            options_dict["vad_filter"] = True
        if word_timestamps:
            options_dict["word_timestamps"] = True
        with self.model_lock:
            segments = []
            text = ""
            segment_generator, info = self.model.transcribe(audio, beam_size=5, **options_dict)
            for segment in segment_generator:
                segments.append(segment)
                text = text + segment.text
            result = {"language": options_dict.get("language", info.language), "segments": segments, "text": text}
        output_file = StringIO()
        self.write_result(result, output_file, output)
        output_file.seek(0)
        return output_file

    def language_detection(self, audio):
        self.last_activity_time = time.time()
        with self.model_lock:
            if self.model is None:
                self.load_model()
        # load audio and pad/trim it to fit 30 seconds
        audio = whisper.pad_or_trim(audio)
        # detect the spoken language
        with self.model_lock:
            segments, info = self.model.transcribe(audio, beam_size=5)
            detected_lang_code = info.language
            detected_language_confidence = info.language_probability
        return detected_lang_code, detected_language_confidence

    def write_result(self, result: dict, file: BinaryIO, output: Union[str, None]):
        if output == "srt":
            WriteSRT(ResultWriter).write_result(result, file=file)
        elif output == "vtt":
            WriteVTT(ResultWriter).write_result(result, file=file)
        elif output == "tsv":
            WriteTSV(ResultWriter).write_result(result, file=file)
        elif output == "json":
            WriteJSON(ResultWriter).write_result(result, file=file)
        elif output == "txt":
            WriteTXT(ResultWriter).write_result(result, file=file)
        else:
            return "Please select an output method!"
app/asr_models/openai_whisper_engine.py (new file)

import time
from io import StringIO
from threading import Thread
from typing import BinaryIO, Union

import torch
import whisper
from whisper.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteVTT

from app.asr_models.asr_model import ASRModel
from app.config import CONFIG


class OpenAIWhisperASR(ASRModel):

    def load_model(self):
        if torch.cuda.is_available():
            self.model = whisper.load_model(
                name=CONFIG.MODEL_NAME,
                download_root=CONFIG.MODEL_PATH
            ).cuda()
        else:
            self.model = whisper.load_model(
                name=CONFIG.MODEL_NAME,
                download_root=CONFIG.MODEL_PATH
            )
        Thread(target=self.monitor_idleness, daemon=True).start()

    def transcribe(
        self,
        audio,
        task: Union[str, None],
        language: Union[str, None],
        initial_prompt: Union[str, None],
        vad_filter: Union[bool, None],
        word_timestamps: Union[bool, None],
        output,
    ):
        print("whisper")
        self.last_activity_time = time.time()
        with self.model_lock:
            if self.model is None:
                self.load_model()
        options_dict = {"task": task}
        if language:
            options_dict["language"] = language
        if initial_prompt:
            options_dict["initial_prompt"] = initial_prompt
        if word_timestamps:
            options_dict["word_timestamps"] = word_timestamps
        with self.model_lock:
            result = self.model.transcribe(audio, **options_dict)
        output_file = StringIO()
        self.write_result(result, output_file, output)
        output_file.seek(0)
        return output_file

    def language_detection(self, audio):
        self.last_activity_time = time.time()
        with self.model_lock:
            if self.model is None:
                self.load_model()
        # load audio and pad/trim it to fit 30 seconds
        audio = whisper.pad_or_trim(audio)
        # make log-Mel spectrogram and move to the same device as the model
        mel = whisper.log_mel_spectrogram(audio, self.model.dims.n_mels).to(self.model.device)
        # detect the spoken language
        with self.model_lock:
            _, probs = self.model.detect_language(mel)
        detected_lang_code = max(probs, key=probs.get)
        return detected_lang_code, probs[detected_lang_code]

    def write_result(self, result: dict, file: BinaryIO, output: Union[str, None]):
        options = {"max_line_width": 1000, "max_line_count": 10, "highlight_words": False}
        if output == "srt":
            WriteSRT(ResultWriter).write_result(result, file=file, options=options)
        elif output == "vtt":
            WriteVTT(ResultWriter).write_result(result, file=file, options=options)
        elif output == "tsv":
            WriteTSV(ResultWriter).write_result(result, file=file, options=options)
        elif output == "json":
            WriteJSON(ResultWriter).write_result(result, file=file, options=options)
        elif output == "txt":
            WriteTXT(ResultWriter).write_result(result, file=file, options=options)
        else:
            return "Please select an output method!"
app/factory/asr_model_factory.py (new file)

from app.asr_models.asr_model import ASRModel
from app.asr_models.faster_whisper_engine import FasterWhisperASR
from app.asr_models.openai_whisper_engine import OpenAIWhisperASR
from app.config import CONFIG


class ASRModelFactory:
    @staticmethod
    def create_asr_model() -> ASRModel:
        if CONFIG.ASR_ENGINE == "openai_whisper":
            return OpenAIWhisperASR()
        elif CONFIG.ASR_ENGINE == "faster_whisper":
            return FasterWhisperASR()
        else:
            raise ValueError(f"Unsupported ASR engine: {CONFIG.ASR_ENGINE}")
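Selection is driven entirely by CONFIG.ASR_ENGINE, which replaces the old ASR_ENGINE environment-variable lookup in webservice.py (see the diff below). A sketch of the call-site pattern, including the failure path:

from app.factory.asr_model_factory import ASRModelFactory

try:
    asr_model = ASRModelFactory.create_asr_model()
    asr_model.load_model()
except ValueError as err:
    # Raised for any engine name other than "openai_whisper" or "faster_whisper"
    print(f"Bad ASR engine setting: {err}")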
app/faster_whisper/core.py (deleted in this commit)

import gc
import time
from io import StringIO
from threading import Lock, Thread
from typing import BinaryIO, Union

import torch
import whisper
from faster_whisper import WhisperModel

from app.config import CONFIG
from app.faster_whisper.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteVTT

model = None
model_lock = Lock()
last_activity_time = time.time()


def monitor_idleness():
    global model
    if CONFIG.MODEL_IDLE_TIMEOUT <= 0:
        return
    while True:
        time.sleep(15)
        if time.time() - last_activity_time > CONFIG.MODEL_IDLE_TIMEOUT:
            with model_lock:
                release_model()
                break


def load_model():
    global model, device, model_quantization
    model = WhisperModel(
        model_size_or_path=CONFIG.MODEL_NAME,
        device=CONFIG.DEVICE,
        compute_type=CONFIG.MODEL_QUANTIZATION,
        download_root=CONFIG.MODEL_PATH
    )
    Thread(target=monitor_idleness, daemon=True).start()


load_model()


def release_model():
    global model
    del model
    torch.cuda.empty_cache()
    gc.collect()
    model = None
    print("Model unloaded due to timeout")


def transcribe(
    audio,
    task: Union[str, None],
    language: Union[str, None],
    initial_prompt: Union[str, None],
    vad_filter: Union[bool, None],
    word_timestamps: Union[bool, None],
    output,
):
    global last_activity_time
    last_activity_time = time.time()
    with model_lock:
        if model is None:
            load_model()
    options_dict = {"task": task}
    if language:
        options_dict["language"] = language
    if initial_prompt:
        options_dict["initial_prompt"] = initial_prompt
    if vad_filter:
        options_dict["vad_filter"] = True
    if word_timestamps:
        options_dict["word_timestamps"] = True
    with model_lock:
        segments = []
        text = ""
        segment_generator, info = model.transcribe(audio, beam_size=5, **options_dict)
        for segment in segment_generator:
            segments.append(segment)
            text = text + segment.text
        result = {"language": options_dict.get("language", info.language), "segments": segments, "text": text}
    output_file = StringIO()
    write_result(result, output_file, output)
    output_file.seek(0)
    return output_file


def language_detection(audio):
    global last_activity_time
    last_activity_time = time.time()
    with model_lock:
        if model is None:
            load_model()
    # load audio and pad/trim it to fit 30 seconds
    audio = whisper.pad_or_trim(audio)
    # detect the spoken language
    with model_lock:
        segments, info = model.transcribe(audio, beam_size=5)
        detected_lang_code = info.language
        detected_language_confidence = info.language_probability
    return detected_lang_code, detected_language_confidence


def write_result(result: dict, file: BinaryIO, output: Union[str, None]):
    if output == "srt":
        WriteSRT(ResultWriter).write_result(result, file=file)
    elif output == "vtt":
        WriteVTT(ResultWriter).write_result(result, file=file)
    elif output == "tsv":
        WriteTSV(ResultWriter).write_result(result, file=file)
    elif output == "json":
        WriteJSON(ResultWriter).write_result(result, file=file)
    elif output == "txt":
        WriteTXT(ResultWriter).write_result(result, file=file)
    else:
        return "Please select an output method!"
app/openai_whisper/core.py (deleted in this commit)

import gc
import time
from io import StringIO
from threading import Lock, Thread
from typing import BinaryIO, Union

import torch
import whisper
from whisper.utils import ResultWriter, WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteVTT

from app.config import CONFIG

model = None
model_lock = Lock()
last_activity_time = time.time()


def monitor_idleness():
    global model
    if CONFIG.MODEL_IDLE_TIMEOUT <= 0:
        return
    while True:
        time.sleep(15)
        if time.time() - last_activity_time > CONFIG.MODEL_IDLE_TIMEOUT:
            with model_lock:
                release_model()
                break


def load_model():
    global model
    if torch.cuda.is_available():
        model = whisper.load_model(
            name=CONFIG.MODEL_NAME,
            download_root=CONFIG.MODEL_PATH
        ).cuda()
    else:
        model = whisper.load_model(
            name=CONFIG.MODEL_NAME,
            download_root=CONFIG.MODEL_PATH
        )
    Thread(target=monitor_idleness, daemon=True).start()


load_model()


def release_model():
    global model
    del model
    torch.cuda.empty_cache()
    gc.collect()
    model = None
    print("Model unloaded due to timeout")


def transcribe(
    audio,
    task: Union[str, None],
    language: Union[str, None],
    initial_prompt: Union[str, None],
    vad_filter: Union[bool, None],
    word_timestamps: Union[bool, None],
    output,
):
    global last_activity_time
    last_activity_time = time.time()
    with model_lock:
        if model is None:
            load_model()
    options_dict = {"task": task}
    if language:
        options_dict["language"] = language
    if initial_prompt:
        options_dict["initial_prompt"] = initial_prompt
    if word_timestamps:
        options_dict["word_timestamps"] = word_timestamps
    with model_lock:
        result = model.transcribe(audio, **options_dict)
    output_file = StringIO()
    write_result(result, output_file, output)
    output_file.seek(0)
    return output_file


def language_detection(audio):
    global last_activity_time
    last_activity_time = time.time()
    with model_lock:
        if model is None:
            load_model()
    # load audio and pad/trim it to fit 30 seconds
    audio = whisper.pad_or_trim(audio)
    # make log-Mel spectrogram and move to the same device as the model
    mel = whisper.log_mel_spectrogram(audio, model.dims.n_mels).to(model.device)
    # detect the spoken language
    with model_lock:
        _, probs = model.detect_language(mel)
    detected_lang_code = max(probs, key=probs.get)
    return detected_lang_code, probs[max(probs)]


def write_result(result: dict, file: BinaryIO, output: Union[str, None]):
    options = {"max_line_width": 1000, "max_line_count": 10, "highlight_words": False}
    if output == "srt":
        WriteSRT(ResultWriter).write_result(result, file=file, options=options)
    elif output == "vtt":
        WriteVTT(ResultWriter).write_result(result, file=file, options=options)
    elif output == "tsv":
        WriteTSV(ResultWriter).write_result(result, file=file, options=options)
    elif output == "json":
        WriteJSON(ResultWriter).write_result(result, file=file, options=options)
    elif output == "txt":
        WriteTXT(ResultWriter).write_result(result, file=file, options=options)
    else:
        return "Please select an output method!"
app/utils.py

import json
import os
from typing import TextIO
from typing import TextIO, BinaryIO

import ffmpeg
import numpy as np
from faster_whisper.utils import format_timestamp

from app.config import CONFIG


class ResultWriter:
    extension: str

@@ -85,3 +89,36 @@ class WriteJSON(ResultWriter):

    def write_result(self, result: dict, file: TextIO):
        json.dump(result, file)


def load_audio(file: BinaryIO, encode=True, sr: int = CONFIG.SAMPLE_RATE):
    """
    Open an audio file object and read as mono waveform, resampling as necessary.
    Modified from https://github.com/openai/whisper/blob/main/whisper/audio.py to accept a file object

    Parameters
    ----------
    file: BinaryIO
        The audio file like object
    encode: Boolean
        If true, encode audio stream to WAV before sending to whisper
    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """
    if encode:
        try:
            # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
            # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
            out, _ = (
                ffmpeg.input("pipe:", threads=0)
                .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
                .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True, input=file.read())
            )
        except ffmpeg.Error as e:
            raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
    else:
        out = file.read()

    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
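A hypothetical call of the helper above (the input file name is illustrative; the ffmpeg CLI must be available on PATH):

from app.utils import load_audio

# Decode any ffmpeg-readable container to a mono float32 waveform at CONFIG.SAMPLE_RATE.
with open("sample.mp3", "rb") as f:
    waveform = load_audio(f, encode=True)
print(waveform.dtype, waveform.shape)  # float32 (n_samples,)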
app/webservice.py

import importlib.metadata
import os
from os import path
from typing import Annotated, BinaryIO, Optional, Union
from typing import Annotated, Optional, Union
from urllib.parse import quote

import click
import ffmpeg
import numpy as np
import uvicorn
from fastapi import FastAPI, File, Query, UploadFile, applications
from fastapi.openapi.docs import get_swagger_ui_html

@@ -14,13 +12,13 @@ from fastapi.responses import RedirectResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from whisper import tokenizer

ASR_ENGINE = os.getenv("ASR_ENGINE", "openai_whisper")
if ASR_ENGINE == "faster_whisper":
    from app.faster_whisper.core import language_detection, transcribe
else:
    from app.openai_whisper.core import language_detection, transcribe
from app.config import CONFIG
from app.factory.asr_model_factory import ASRModelFactory
from app.utils import load_audio

asr_model = ASRModelFactory.create_asr_model()
asr_model.load_model()

SAMPLE_RATE = 16000
LANGUAGE_CODES = sorted(tokenizer.LANGUAGES.keys())
projectMetadata = importlib.metadata.metadata("whisper-asr-webservice")

@@ -67,20 +65,20 @@ async def asr(
        bool | None,
        Query(
            description="Enable the voice activity detection (VAD) to filter out parts of the audio without speech",
            include_in_schema=(True if ASR_ENGINE == "faster_whisper" else False),
            include_in_schema=(True if CONFIG.ASR_ENGINE == "faster_whisper" else False),
        ),
    ] = False,
    word_timestamps: bool = Query(default=False, description="Word level timestamps"),
    output: Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"]),
):
    result = transcribe(
    result = asr_model.transcribe(
        load_audio(audio_file.file, encode), task, language, initial_prompt, vad_filter, word_timestamps, output
    )
    return StreamingResponse(
        result,
        media_type="text/plain",
        headers={
            "Asr-Engine": ASR_ENGINE,
            "Asr-Engine": CONFIG.ASR_ENGINE,
            "Content-Disposition": f'attachment; filename="{quote(audio_file.filename)}.{output}"',
        },
    )
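The handler's route decorator sits above this hunk; assuming the service's conventional POST /asr route and default port (both illustrative here), a hypothetical client call with the requests package looks like this:

import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:9000/asr",
        files={"audio_file": f},
        params={"task": "transcribe", "output": "srt", "encode": True},
    )
print(resp.headers.get("Asr-Engine"))  # engine name echoed in the response header
print(resp.text)                       # transcript in the requested output format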
@@ -91,44 +89,11 @@ async def detect_language(
    audio_file: UploadFile = File(...),  # noqa: B008
    encode: bool = Query(default=True, description="Encode audio first through FFmpeg"),
):
    detected_lang_code, confidence = language_detection(load_audio(audio_file.file, encode))
    detected_lang_code, confidence = asr_model.language_detection(load_audio(audio_file.file, encode))
    return {"detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code": detected_lang_code,
            "confidence": confidence}


def load_audio(file: BinaryIO, encode=True, sr: int = SAMPLE_RATE):
    """
    Open an audio file object and read as mono waveform, resampling as necessary.
    Modified from https://github.com/openai/whisper/blob/main/whisper/audio.py to accept a file object

    Parameters
    ----------
    file: BinaryIO
        The audio file like object
    encode: Boolean
        If true, encode audio stream to WAV before sending to whisper
    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """
    if encode:
        try:
            # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
            # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
            out, _ = (
                ffmpeg.input("pipe:", threads=0)
                .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
                .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True, input=file.read())
            )
        except ffmpeg.Error as e:
            raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
    else:
        out = file.read()

    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
@click.command()
@click.option(
    "-h",
......