openai_whisper_engine.py
    import time
    from io import StringIO
    from threading import Thread
    from typing import TextIO, Union
    
    import torch
    import whisper
    from whisper.utils import WriteJSON, WriteSRT, WriteTSV, WriteTXT, WriteVTT
    
    from app.asr_models.asr_model import ASRModel
    from app.config import CONFIG
    
    
    class OpenAIWhisperASR(ASRModel):
    
        def load_model(self):
    
            # Load the checkpoint once; move it to the GPU when CUDA is available.
            model = whisper.load_model(name=CONFIG.MODEL_NAME, download_root=CONFIG.MODEL_PATH)
            self.model = model.cuda() if torch.cuda.is_available() else model
    
            # Monitor idleness in the background (daemon thread exits with the process).
            Thread(target=self.monitor_idleness, daemon=True).start()
    
        def transcribe(
            self,
            audio,
            task: Union[str, None],
            language: Union[str, None],
            initial_prompt: Union[str, None],
            vad_filter: Union[bool, None],
            word_timestamps: Union[bool, None],
            options: Union[dict, None],
            output: Union[str, None],
        ):
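            # vad_filter and options are accepted for interface parity with the
            # other ASR backends, but this engine does not use them.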
            # Mark transcription as active and refresh the idle timer.
            self.transcription_active = True
            self.last_activity_time = time.time()

            try:
                # Load the model lazily on first use.
                with self.model_lock:
                    if self.model is None:
                        self.load_model()

                # Forward only the options the caller actually supplied.
                options_dict = {"task": task}
                if language:
                    options_dict["language"] = language
                if initial_prompt:
                    options_dict["initial_prompt"] = initial_prompt
                if word_timestamps:
                    options_dict["word_timestamps"] = word_timestamps
                with self.model_lock:
                    result = self.model.transcribe(audio, **options_dict)

                # Render the result in the requested format to an in-memory buffer.
                output_file = StringIO()
                self.write_result(result, output_file, output)
                output_file.seek(0)
            finally:
                # Clear the flag even if transcription raises, so the idle
                # monitor is not blocked by a failed request.
                self.transcription_active = False

            return output_file
    
        def language_detection(self, audio):
    
            # Language detection also counts as activity for the idle monitor.
            self.last_activity_time = time.time()
    
            with self.model_lock:
                if self.model is None:
                    self.load_model()
    
            # load audio and pad/trim it to fit 30 seconds
            audio = whisper.pad_or_trim(audio)
    
            # make log-Mel spectrogram and move to the same device as the model
            mel = whisper.log_mel_spectrogram(audio, self.model.dims.n_mels).to(self.model.device)
    
            # detect the spoken language
            with self.model_lock:
                _, probs = self.model.detect_language(mel)
            detected_lang_code = max(probs, key=probs.get)
    
            return detected_lang_code, probs[detected_lang_code]
    
        def write_result(self, result: dict, file: TextIO, output: Union[str, None]):
            options = {"max_line_width": 1000, "max_line_count": 10, "highlight_words": False}
            # Map each requested format to its whisper writer; plain text is the default.
            writers = {"srt": WriteSRT, "vtt": WriteVTT, "tsv": WriteTSV, "json": WriteJSON}
            writer = writers.get(output, WriteTXT)
            # The writers only consult output_dir when opening files themselves;
            # writing to an already-open handle makes a placeholder sufficient.
            writer(".").write_result(result, file=file, options=options)
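
For reference, a minimal usage sketch (not part of the module), assuming the ASRModel base class provides model_lock and monitor_idleness as the code above implies, and that CONFIG points at a valid model; the input file name and output format here are illustrative only:

    # Hypothetical usage: decode a local file and print an SRT transcript.
    audio = whisper.load_audio("sample.wav")  # illustrative input path
    asr = OpenAIWhisperASR()

    # Detect the language first, then transcribe with it.
    lang, confidence = asr.language_detection(audio)
    srt_buffer = asr.transcribe(
        audio,
        task="transcribe",
        language=lang,
        initial_prompt=None,
        vad_filter=None,
        word_timestamps=False,
        options=None,
        output="srt",
    )
    print(srt_buffer.read())  # SRT-formatted transcript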