webservice.py

    from fastapi import FastAPI, File, UploadFile, Query, applications
    from fastapi.responses import StreamingResponse, RedirectResponse
    from fastapi.staticfiles import StaticFiles
    from fastapi.openapi.docs import get_swagger_ui_html
    import whisper
    from whisper.utils import ResultWriter, WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON
    from whisper import tokenizer
    import os
    from os import path
    from pathlib import Path
    import ffmpeg
    from typing import BinaryIO, Union
    import numpy as np
    from io import StringIO
    from threading import Lock
    import torch
    import fastapi_offline_swagger_ui
    import importlib.metadata
    
    # Whisper expects 16 kHz mono audio; LANGUAGE_CODES drives the `language`
    # enum on the /asr endpoint below.
    SAMPLE_RATE = 16000
    LANGUAGE_CODES = sorted(tokenizer.LANGUAGES.keys())
    
    projectMetadata = importlib.metadata.metadata('whisper-asr-webservice')
    app = FastAPI(
        title=projectMetadata['Name'].title().replace('-', ' '),
        description=projectMetadata['Summary'],
        version=projectMetadata['Version'],
        contact={
            "url": projectMetadata['Home-page']
        },
        swagger_ui_parameters={"defaultModelsExpandDepth": -1},
        license_info={
            "name": "MIT License",
            "url": projectMetadata['License']
        }
    )
    
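    # Serve the Swagger UI assets from the locally installed offline bundle
    # (when present) so the /docs page works without internet access.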
    assets_path = fastapi_offline_swagger_ui.__path__[0]
    if path.exists(assets_path + "/swagger-ui.css") and path.exists(assets_path + "/swagger-ui-bundle.js"):
        app.mount("/assets", StaticFiles(directory=assets_path), name="static")
        def swagger_monkey_patch(*args, **kwargs):
            return get_swagger_ui_html(
                *args,
                **kwargs,
                swagger_favicon_url="",
                swagger_css_url="/assets/swagger-ui.css",
                swagger_js_url="/assets/swagger-ui-bundle.js",
            )
        applications.get_swagger_ui_html = swagger_monkey_patch
    
    # Load the Whisper checkpoint named by the ASR_MODEL environment variable
    # (defaults to "base"), moving it to the GPU when one is available.
    model_name = os.getenv("ASR_MODEL", "base")
    if torch.cuda.is_available():
        model = whisper.load_model(model_name).cuda()
    else:
        model = whisper.load_model(model_name)
    # Serialize access to the shared model across concurrent requests.
    model_lock = Lock()
    
    @app.get("/", response_class=RedirectResponse, include_in_schema=False)
    async def index():
        return "/docs"
    
    @app.post("/asr", tags=["Endpoints"])
    def transcribe(
        audio_file: UploadFile = File(...),
        task: Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]),
        language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES),
        initial_prompt: Union[str, None] = Query(default=None),
        output: Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"]),
    ):
    
        result = run_asr(audio_file.file, task, language, initial_prompt)
        filename = audio_file.filename.split('.')[0]
        myFile = StringIO()
        # The writers expect an output directory, but write_result only uses the
        # file handle, so passing the ResultWriter class as a placeholder works.
        if output == "srt":
            WriteSRT(ResultWriter).write_result(result, file=myFile)
        elif output == "vtt":
            WriteVTT(ResultWriter).write_result(result, file=myFile)
        elif output == "tsv":
            WriteTSV(ResultWriter).write_result(result, file=myFile)
        elif output == "json":
            WriteJSON(ResultWriter).write_result(result, file=myFile)
        elif output == "txt":
            WriteTXT(ResultWriter).write_result(result, file=myFile)
        else:
            return 'Please select an output method!'
        myFile.seek(0)
        return StreamingResponse(
            myFile,
            media_type="text/plain",
            headers={'Content-Disposition': f'attachment; filename="{filename}.{output}"'},
        )
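
    # Example request against a running instance (host and port are an assumption
    # here, matching the project's usual Docker default of 9000):
    #
    #   curl -F "audio_file=@sample.mp3" \
    #        "http://localhost:9000/asr?task=transcribe&output=srt" -o sample.srt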
    
    
    @app.post("/detect-language", tags=["Endpoints"])
    def language_detection(
        audio_file: UploadFile = File(...),
    ):
    
        # load audio and pad/trim it to fit 30 seconds
        audio = load_audio(audio_file.file)
        audio = whisper.pad_or_trim(audio)
    
        # make log-Mel spectrogram and move to the same device as the model
        mel = whisper.log_mel_spectrogram(audio).to(model.device)
    
        # detect the spoken language
        with model_lock:
            _, probs = model.detect_language(mel)
        detected_lang_code = max(probs, key=probs.get)
        
        result = {
            "detected_language": tokenizer.LANGUAGES[detected_lang_code],
            "language_code": detected_lang_code,
        }
    
        return result
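
    # Example request (same assumed host/port as the /asr example above):
    #
    #   curl -F "audio_file=@sample.mp3" http://localhost:9000/detect-language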
    
    
    def run_asr(file: BinaryIO, task: Union[str, None], language: Union[str, None], initial_prompt: Union[str, None]):
        audio = load_audio(file)
        options_dict = {"task": task}
        if language:
            options_dict["language"] = language
        if initial_prompt:
            options_dict["initial_prompt"] = initial_prompt
        with model_lock:
            result = model.transcribe(audio, **options_dict)

        return result
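
    # Note: model.transcribe accepts further decoding options (e.g. temperature,
    # beam_size); only task, language and initial_prompt are exposed via /asr here.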
    
    
    def load_audio(file: BinaryIO, sr: int = SAMPLE_RATE):
        """
        Open an audio file object and read it as a mono waveform, resampling as necessary.
        Modified from https://github.com/openai/whisper/blob/main/whisper/audio.py to accept a file object.

        Parameters
        ----------
        file: BinaryIO
            The audio file-like object
        sr: int
            The sample rate to resample the audio to, if necessary

        Returns
        -------
        A NumPy array containing the audio waveform, in float32 dtype.
        """
        try:
            # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
            # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
            out, _ = (
                ffmpeg.input("pipe:", threads=0)
                .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
                .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True, input=file.read())
            )
        except ffmpeg.Error as e:
            raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
    
        return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
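

    # A minimal way to run the service directly (a sketch: the project normally
    # starts the app through its packaged entry point, and the port below is an
    # assumption matching the project's Docker default).
    if __name__ == "__main__":
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=9000)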