Commit 0f13c4db authored by Ahmet Öner's avatar Ahmet Öner
Separate Faster Whisper and OpenAI Whisper with env variable

This commit moves each engine into its own module, faster_whisper/core.py and openai_whisper/core.py, and selects one at import time via the ASR_ENGINE environment variable. The endpoint handlers are renamed (transcribe → asr, language_detection → detect_language) so they no longer shadow the engine functions imported from the chosen core module.

parent 8959328d
faster_whisper/core.py (new file):

import os
from typing import BinaryIO, Union
from io import StringIO
from threading import Lock

import torch
import whisper
from .utils import model_converter, ResultWriter, WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON
from faster_whisper import WhisperModel

model_name = os.getenv("ASR_MODEL", "base")
model_path = os.path.join("/root/.cache/faster_whisper", model_name)

# Convert the requested Whisper checkpoint into the CTranslate2 format
# used by faster-whisper, cached under the local model path.
model_converter(model_name, model_path)

if torch.cuda.is_available():
    model = WhisperModel(model_path, device="cuda", compute_type="float16")
else:
    model = WhisperModel(model_path, device="cpu", compute_type="int8")

# A single lock serializes access to the shared model instance.
model_lock = Lock()
def transcribe(
    audio,
    task: Union[str, None],
    language: Union[str, None],
    initial_prompt: Union[str, None],
    output,
):
    options_dict = {"task": task}
    if language:
        options_dict["language"] = language
    if initial_prompt:
        options_dict["initial_prompt"] = initial_prompt

    with model_lock:
        segments = []
        text = ""
        # faster-whisper returns a lazy generator; iterate it to collect segments.
        segment_generator, info = model.transcribe(audio, beam_size=5, **options_dict)
        for segment in segment_generator:
            segments.append(segment)
            text = text + segment.text
        result = {
            "language": options_dict.get("language", info.language),
            "segments": segments,
            "text": text,
        }

    outputFile = StringIO()
    write_result(result, outputFile, output)
    outputFile.seek(0)

    return outputFile
def language_detection(audio):
    # load audio and pad/trim it to fit 30 seconds
    audio = whisper.pad_or_trim(audio)

    # detect the spoken language; faster-whisper reports it on the info object
    with model_lock:
        segments, info = model.transcribe(audio, beam_size=5)
        detected_lang_code = info.language

    return detected_lang_code
def write_result(
    result: dict, file: BinaryIO, output: Union[str, None]
):
    if output == "srt":
        WriteSRT(ResultWriter).write_result(result, file=file)
    elif output == "vtt":
        WriteVTT(ResultWriter).write_result(result, file=file)
    elif output == "tsv":
        WriteTSV(ResultWriter).write_result(result, file=file)
    elif output == "json":
        WriteJSON(ResultWriter).write_result(result, file=file)
    elif output == "txt":
        WriteTXT(ResultWriter).write_result(result, file=file)
    else:
        return 'Please select an output method!'
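
For orientation, a minimal sketch of calling this module's transcribe helper directly; the silent NumPy buffer and the 16 kHz mono float32 format are assumptions based on Whisper's usual input, not part of this commit:

import numpy as np

# Hypothetical caller: one second of silence at Whisper's 16 kHz sample rate.
audio = np.zeros(16000, dtype=np.float32)
srt_output = transcribe(audio, task="transcribe", language="en",
                        initial_prompt=None, output="srt")
print(srt_output.read())  # SRT-formatted text from the returned StringIO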
openai_whisper/core.py (new file):

import os
from typing import BinaryIO, Union
from io import StringIO
from threading import Lock

import torch
import whisper
from whisper.utils import ResultWriter, WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON

model_name = os.getenv("ASR_MODEL", "base")
if torch.cuda.is_available():
    model = whisper.load_model(model_name).cuda()
else:
    model = whisper.load_model(model_name)

# A single lock serializes access to the shared model instance.
model_lock = Lock()
def transcribe(
    audio,
    task: Union[str, None],
    language: Union[str, None],
    initial_prompt: Union[str, None],
    output,
):
    options_dict = {"task": task}
    if language:
        options_dict["language"] = language
    if initial_prompt:
        options_dict["initial_prompt"] = initial_prompt
    with model_lock:
        result = model.transcribe(audio, **options_dict)

    outputFile = StringIO()
    write_result(result, outputFile, output)
    outputFile.seek(0)

    return outputFile
def language_detection(audio):
    # load audio and pad/trim it to fit 30 seconds
    audio = whisper.pad_or_trim(audio)

    # make log-Mel spectrogram and move to the same device as the model
    mel = whisper.log_mel_spectrogram(audio).to(model.device)

    # detect the spoken language and keep the most probable code
    with model_lock:
        _, probs = model.detect_language(mel)
        detected_lang_code = max(probs, key=probs.get)

    return detected_lang_code
def write_result(
    result: dict, file: BinaryIO, output: Union[str, None]
):
    if output == "srt":
        WriteSRT(ResultWriter).write_result(result, file=file)
    elif output == "vtt":
        WriteVTT(ResultWriter).write_result(result, file=file)
    elif output == "tsv":
        WriteTSV(ResultWriter).write_result(result, file=file)
    elif output == "json":
        WriteJSON(ResultWriter).write_result(result, file=file)
    elif output == "txt":
        WriteTXT(ResultWriter).write_result(result, file=file)
    else:
        return 'Please select an output method!'
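
Again a minimal usage sketch, under the same assumed audio format as above:

import numpy as np

audio = np.zeros(16000, dtype=np.float32)  # hypothetical one-second buffer
lang_code = language_detection(audio)      # argmax over detect_language() probabilities
print(lang_code)                           # e.g. "en"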
Webservice module (modified). In the hunks below, removed lines are prefixed with - and added lines with +; unprefixed lines are unchanged context:

-import os
-from os import path
-import importlib.metadata
-from typing import BinaryIO, Union
-import numpy as np
-import ffmpeg
 from fastapi import FastAPI, File, UploadFile, Query, applications
 from fastapi.responses import StreamingResponse, RedirectResponse
 from fastapi.staticfiles import StaticFiles
 from fastapi.openapi.docs import get_swagger_ui_html
-import whisper
-from whisper.utils import ResultWriter, WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON
 from whisper import tokenizer
-from faster_whisper import WhisperModel
-from .faster_whisper.utils import (
-    model_converter as faster_whisper_model_converter,
-    ResultWriter as faster_whisper_ResultWriter,
-    WriteTXT as faster_whisper_WriteTXT,
-    WriteSRT as faster_whisper_WriteSRT,
-    WriteVTT as faster_whisper_WriteVTT,
-    WriteTSV as faster_whisper_WriteTSV,
-    WriteJSON as faster_whisper_WriteJSON,
-)
+import os
+from os import path
+from pathlib import Path
+import ffmpeg
+from typing import BinaryIO, Union
+import numpy as np
+from io import StringIO
+from threading import Lock
+import torch
+import importlib.metadata
+
+ASR_ENGINE = os.getenv("ASR_ENGINE", "openai_whisper")
+if ASR_ENGINE == "faster_whisper":
+    from .faster_whisper.core import transcribe, language_detection
+else:
+    from .openai_whisper.core import transcribe, language_detection

 SAMPLE_RATE = 16000
 LANGUAGE_CODES = sorted(list(tokenizer.LANGUAGES.keys()))
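
The if/else above binds transcribe and language_detection to one engine for the lifetime of the process, so the engine must be chosen before the module is imported. A hedged sketch of a launcher; the package path app.webservice is a hypothetical placeholder:

import os

os.environ["ASR_ENGINE"] = "faster_whisper"  # or "openai_whisper" (the default)
os.environ["ASR_MODEL"] = "base"             # any Whisper model name

import app.webservice  # hypothetical module path; importing triggers engine selection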
@@ -57,30 +48,12 @@ if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/
 )
 applications.get_swagger_ui_html = swagger_monkey_patch
-whisper_model_name = os.getenv("ASR_MODEL", "base")
-faster_whisper_model_path = os.path.join("/root/.cache/faster_whisper", whisper_model_name)
-faster_whisper_model_converter(whisper_model_name, faster_whisper_model_path)
-if torch.cuda.is_available():
-    whisper_model = whisper.load_model(whisper_model_name).cuda()
-    faster_whisper_model = WhisperModel(faster_whisper_model_path, device="cuda", compute_type="float16")
-else:
-    whisper_model = whisper.load_model(whisper_model_name)
-    faster_whisper_model = WhisperModel(faster_whisper_model_path)
-model_lock = Lock()
-def get_model(method: str = "openai-whisper"):
-    if method == "faster-whisper":
-        return faster_whisper_model
-    return whisper_model
@app.get("/", response_class=RedirectResponse, include_in_schema=False)
async def index():
return "/docs"
@app.post("/asr", tags=["Endpoints"])
def transcribe(
method: Union[str, None] = Query(default="openai-whisper", enum=["openai-whisper", "faster-whisper"]),
def asr(
task : Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]),
language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES),
initial_prompt: Union[str, None] = Query(default=None),
......@@ -88,103 +61,22 @@ def transcribe(
encode : bool = Query(default=True, description="Encode audio first through ffmpeg"),
output : Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"])
):
-    result = run_asr(audio_file.file, task, language, initial_prompt, method, encode)
-    filename = audio_file.filename.split('.')[0]
-    myFile = StringIO()
-    write_result(result, myFile, output, method)
-    myFile.seek(0)
-    return StreamingResponse(myFile, media_type="text/plain", headers={'Content-Disposition': f'attachment; filename="{filename}.{output}"'})
+    result = transcribe(load_audio(audio_file.file, encode), task, language, initial_prompt, output)
+    return StreamingResponse(
+        result,
+        media_type="text/plain",
+        headers={
+            'Asr-Engine': ASR_ENGINE,
+            'Content-Disposition': f'attachment; filename="{audio_file.filename}.{output}"'
+        })
@app.post("/detect-language", tags=["Endpoints"])
def language_detection(
def detect_language(
audio_file: UploadFile = File(...),
method: Union[str, None] = Query(default="openai-whisper", enum=["openai-whisper", "faster-whisper"]),
encode : bool = Query(default=True, description="Encode audio first through ffmpeg")
):
-    # load audio and pad/trim it to fit 30 seconds
-    audio = load_audio(audio_file.file, encode)
-    audio = whisper.pad_or_trim(audio)
-    # detect the spoken language
-    with model_lock:
-        model = get_model(method)
-        if method == "faster-whisper":
-            segments, info = model.transcribe(audio, beam_size=5)
-            detected_lang_code = info.language
-        else:
-            # make log-Mel spectrogram and move to the same device as the model
-            mel = whisper.log_mel_spectrogram(audio).to(model.device)
-            _, probs = model.detect_language(mel)
-            detected_lang_code = max(probs, key=probs.get)
-    result = { "detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code": detected_lang_code }
-    return result
-def run_asr(
-    file: BinaryIO,
-    task: Union[str, None],
-    language: Union[str, None],
-    initial_prompt: Union[str, None],
-    method: Union[str, None],
-    encode=True
-):
-    audio = load_audio(file, encode)
-    options_dict = {"task": task}
-    if language:
-        options_dict["language"] = language
-    if initial_prompt:
-        options_dict["initial_prompt"] = initial_prompt
-    with model_lock:
-        model = get_model(method)
-        if method == "faster-whisper":
-            segments = []
-            text = ""
-            i = 0
-            segment_generator, info = model.transcribe(audio, beam_size=5, **options_dict)
-            for segment in segment_generator:
-                segments.append(segment)
-                text = text + segment.text
-            result = {
-                "language": options_dict.get("language", info.language),
-                "segments": segments,
-                "text": text,
-            }
-        else:
-            result = model.transcribe(audio, **options_dict)
-    return result
-def write_result(
-    result: dict, file: BinaryIO, output: Union[str, None], method: Union[str, None]
-):
-    if method == "faster-whisper":
-        if output == "srt":
-            faster_whisper_WriteSRT(ResultWriter).write_result(result, file=file)
-        elif output == "vtt":
-            faster_whisper_WriteVTT(ResultWriter).write_result(result, file=file)
-        elif output == "tsv":
-            faster_whisper_WriteTSV(ResultWriter).write_result(result, file=file)
-        elif output == "json":
-            faster_whisper_WriteJSON(ResultWriter).write_result(result, file=file)
-        elif output == "txt":
-            faster_whisper_WriteTXT(ResultWriter).write_result(result, file=file)
-        else:
-            return 'Please select an output method!'
-    else:
-        if output == "srt":
-            WriteSRT(ResultWriter).write_result(result, file=file)
-        elif output == "vtt":
-            WriteVTT(ResultWriter).write_result(result, file=file)
-        elif output == "tsv":
-            WriteTSV(ResultWriter).write_result(result, file=file)
-        elif output == "json":
-            WriteJSON(ResultWriter).write_result(result, file=file)
-        elif output == "txt":
-            WriteTXT(ResultWriter).write_result(result, file=file)
-        else:
-            return 'Please select an output method!'
+    detected_lang_code = language_detection(load_audio(audio_file.file, encode))
+    return { "detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code": detected_lang_code }

 def load_audio(file: BinaryIO, encode=True, sr: int = SAMPLE_RATE):
     """
......
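
Finally, a hypothetical client for the reshaped endpoints; the host and port (localhost:9000) and the sample file name are assumptions, while the query and form field names come from the handlers above:

import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:9000/asr",  # assumed host/port
        params={"task": "transcribe", "output": "srt", "encode": True},
        files={"audio_file": f},
    )
print(resp.headers.get("Asr-Engine"))  # reports which ASR_ENGINE served the request
print(resp.text)                       # transcript in the requested output format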