Skip to content
Snippets Groups Projects
Select Git revision
  • e38082e9781dd97b133388863f21809aaed19810
  • main default protected
2 results

asr_model.py

Blame
  • user avatar
    Subliminal Guy authored
    e38082e9
    History
    asr_model.py 2.35 KiB
    import gc
    import time
    from abc import ABC, abstractmethod
    from threading import Lock
    from typing import Union
    
    import torch
    
    from app.config import CONFIG
    
    
    class ASRModel(ABC):
        """
        Abstract base class for ASR (Automatic Speech Recognition) models.

        Subclasses implement `load_model`, `transcribe` and `language_detection`.
        This base class provides idle-timeout unloading of the model and two
        read-only status properties.
        """

        # Currently loaded model, or None when nothing is loaded. May also be a dict
        # (see `is_model_loaded`, which looks at its 'whisperx' entry).
        model = None
        # Defined on the class, so the same Lock object is shared by every instance
        # and subclass that does not override it.
        model_lock = Lock()
        # Class-level default evaluated once, when the class body is executed.
        # NOTE(review): presumably refreshed by subclasses on each request - confirm.
        last_activity_time = time.time()

        def __init__(self):
            # Flag indicating if a transcription is currently running
            self.transcription_active = False

        @abstractmethod
        def load_model(self):
            """
            Loads the model from the specified path.
            """
            pass

        @abstractmethod
        def transcribe(
            self,
            audio,
            task: Union[str, None],
            language: Union[str, None],
            initial_prompt: Union[str, None],
            vad_filter: Union[bool, None],
            word_timestamps: Union[bool, None],
            options: Union[dict, None],
            output,
        ):
            """
            Perform transcription on the given audio file.
            """
            pass

        @abstractmethod
        def language_detection(self, audio):
            """
            Perform language detection on the given audio file.
            """
            pass

        def monitor_idleness(self):
            """
            Monitors the idleness of the ASR model and releases the model if it has been idle for too long.

            Returns immediately when `CONFIG.MODEL_IDLE_TIMEOUT` is zero or negative.
            Otherwise this call blocks: it polls every 15 seconds and, once the time
            since `last_activity_time` exceeds the timeout, unloads the model while
            holding `model_lock` and then returns.
            """
            if CONFIG.MODEL_IDLE_TIMEOUT <= 0:
                return
            while True:
                time.sleep(15)
                if time.time() - self.last_activity_time > CONFIG.MODEL_IDLE_TIMEOUT:
                    # Hold the lock so the model is not unloaded while another
                    # holder of `model_lock` is still using it.
                    with self.model_lock:
                        self.release_model()
                        break

        def release_model(self):
            """
            Unloads the model from memory and clears any cached GPU memory.

            Safe to call when no model has been assigned on this instance yet.
            The caller is responsible for holding `model_lock` if required.
            """
            # Rebind instead of `del self.model`: `del` raises AttributeError when only
            # the class-level default exists and no instance attribute was ever set.
            self.model = None
            # Collect first so memory owned by the dropped model (including objects
            # kept alive only by reference cycles) is returned to the allocator
            # before its cache is emptied.
            gc.collect()
            torch.cuda.empty_cache()
            print("Model unloaded due to timeout")

        @property
        def is_transcribing(self) -> bool:
            """
            Returns True if a transcription is currently running.

            This is derived from `model_lock` being held, so it is also True while
            `monitor_idleness` holds the lock to unload the model.
            """
            return self.model_lock.locked()

        @property
        def is_model_loaded(self) -> bool:
            """
            Returns True if the model is loaded in memory.

            When `model` is a dict, only its 'whisperx' entry is considered.
            """
            model_attr = self.model
            if isinstance(model_attr, dict):
                return model_attr.get('whisperx') is not None
            return model_attr is not None