Skip to content
Snippets Groups Projects
Select Git revision
  • 986a18af7931205d920ac5da1f2ccb9d8a8b8a2b
  • main default protected
2 results

core.py

Blame
  • core.py 2.17 KiB
    import os
    from io import StringIO
    from threading import Lock
    from typing import BinaryIO, TextIO, Union

    import torch
    import whisper
    from whisper.utils import ResultWriter, WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON
    
    # Whisper checkpoint to load; overridable via the ASR_MODEL environment variable.
    model_name = os.getenv("ASR_MODEL", "base")

    # Load the checkpoint once at import time, then move it to the GPU when one is present.
    model = whisper.load_model(model_name)
    if torch.cuda.is_available():
        model = model.cuda()

    # Guards every call into the shared `model` (transcription and language detection),
    # so only one thread uses it at a time.
    model_lock = Lock()
    
    
    def transcribe(
            audio,
            task: Union[str, None],
            language: Union[str, None],
            initial_prompt: Union[str, None],
            word_timestamps: Union[bool, None],
            output
    ):
        """Run the shared Whisper model on *audio* and return the formatted transcript.

        Args:
            audio: Input passed straight to ``model.transcribe`` (presumably a file
                path or a decoded waveform; TODO confirm against callers).
            task: Whisper task name, always forwarded (may be ``None``).
            language: Optional language code; only forwarded when truthy.
            initial_prompt: Optional prompt text; only forwarded when truthy.
            word_timestamps: When truthy, request word-level timestamps from Whisper.
            output: Output format name handed to ``write_result``
                (``"txt"``, ``"srt"``, ``"vtt"``, ``"tsv"`` or ``"json"``).

        Returns:
            A ``StringIO`` rewound to offset 0 holding the formatted result. For an
            unsupported *output* value nothing is written, so the stream is empty.
        """
        options_dict = {"task": task}
        if language:
            options_dict["language"] = language
        if initial_prompt:
            options_dict["initial_prompt"] = initial_prompt
        if word_timestamps:
            # Only sent when requested, so the model's own default applies otherwise.
            options_dict["word_timestamps"] = word_timestamps

        # The model is shared module-wide; serialize access to it.
        with model_lock:
            result = model.transcribe(audio, **options_dict)

        output_file = StringIO()
        write_result(result, output_file, output)
        # Rewind so the caller can read the stream from the beginning.
        output_file.seek(0)

        return output_file
    
    
    def language_detection(audio):
        """Return the language code Whisper rates most probable for *audio*.

        *audio* is expected to be already-decoded audio accepted by
        ``whisper.pad_or_trim`` (presumably a float32 waveform; TODO confirm
        against callers). Only a 30-second window is analysed: the input is padded
        or trimmed to that length before the spectrogram is built.
        """
        # Fit the clip to the 30-second window the detector works on.
        clip = whisper.pad_or_trim(audio)

        # NOTE(review): the spectrogram uses log_mel_spectrogram's default mel-bin
        # count; a checkpoint expecting a different count may reject it. Confirm
        # against the installed whisper version and ASR_MODEL values in use.
        spectrogram = whisper.log_mel_spectrogram(clip).to(model.device)

        # Only the model call itself needs the shared-model lock.
        with model_lock:
            _tokens, language_probs = model.detect_language(spectrogram)

        return max(language_probs, key=language_probs.get)
    
    
    def write_result(
            result: dict, file: TextIO, output: Union[str, None]
    ):
        """Serialize a Whisper transcription *result* into the text stream *file*.

        Args:
            result: The dict returned by ``model.transcribe``.
            file: Text stream (e.g. ``StringIO``) that receives the formatted output.
            output: Format name: ``"srt"``, ``"vtt"``, ``"tsv"``, ``"json"`` or ``"txt"``.

        Returns:
            ``None`` after writing, or the string ``'Please select an output method!'``
            when *output* is not one of the supported formats (nothing is written then).
        """
        writers = {
            "srt": WriteSRT,
            "vtt": WriteVTT,
            "tsv": WriteTSV,
            "json": WriteJSON,
            "txt": WriteTXT,
        }
        # dict.get also covers output=None.
        writer_cls = writers.get(output)
        if writer_cls is None:
            return 'Please select an output method!'

        options = {
            'max_line_width': 1000,
            'max_line_count': 10,
            'highlight_words': False
        }
        # The writer constructor takes an output directory path. Here we only stream
        # into *file* via write_result, so a placeholder directory is supplied.
        # The directory is presumably used only when the writer saves files itself;
        # TODO confirm for the pinned whisper version.
        writer_cls(".").write_result(result, file=file, options=options)