Skip to content
Snippets Groups Projects
Unverified Commit 9e8f8e26 authored by Aidan Crowther's avatar Aidan Crowther Committed by GitHub
Browse files

Merge branch 'main' into add-vram-flush

parents 252e5a7a e5518bf1
No related branches found
No related tags found
No related merge requests found
...@@ -7,6 +7,7 @@ Unreleased ...@@ -7,6 +7,7 @@ Unreleased
### Added ### Added
- Timeout configured to allow model to be unloaded when idle - Timeout configured to allow model to be unloaded when idle
- Added detection confidence to language detection endpoint
- Set mel generation to adjust n_dims automatically to match the loaded model - Set mel generation to adjust n_dims automatically to match the loaded model
[1.6.0] (2024-10-06) [1.6.0] (2024-10-06)
......
...@@ -112,8 +112,9 @@ def language_detection(audio): ...@@ -112,8 +112,9 @@ def language_detection(audio):
with model_lock: with model_lock:
segments, info = model.transcribe(audio, beam_size=5) segments, info = model.transcribe(audio, beam_size=5)
detected_lang_code = info.language detected_lang_code = info.language
detected_language_confidence = info.language_probability
return detected_lang_code return detected_lang_code, detected_language_confidence
def write_result(result: dict, file: BinaryIO, output: Union[str, None]): def write_result(result: dict, file: BinaryIO, output: Union[str, None]):
......
...@@ -97,7 +97,7 @@ def language_detection(audio): ...@@ -97,7 +97,7 @@ def language_detection(audio):
_, probs = model.detect_language(mel) _, probs = model.detect_language(mel)
detected_lang_code = max(probs, key=probs.get) detected_lang_code = max(probs, key=probs.get)
return detected_lang_code return detected_lang_code, probs[max(probs)]
def write_result(result: dict, file: BinaryIO, output: Union[str, None]): def write_result(result: dict, file: BinaryIO, output: Union[str, None]):
......
...@@ -89,8 +89,8 @@ async def detect_language( ...@@ -89,8 +89,8 @@ async def detect_language(
audio_file: UploadFile = File(...), # noqa: B008 audio_file: UploadFile = File(...), # noqa: B008
encode: bool = Query(default=True, description="Encode audio first through FFmpeg"), encode: bool = Query(default=True, description="Encode audio first through FFmpeg"),
): ):
detected_lang_code = language_detection(load_audio(audio_file.file, encode)) detected_lang_code, confidence = language_detection(load_audio(audio_file.file, encode))
return {"detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code": detected_lang_code} return {"detected_language": tokenizer.LANGUAGES[detected_lang_code], "language_code": detected_lang_code, "confidence": confidence}
def load_audio(file: BinaryIO, encode=True, sr: int = SAMPLE_RATE): def load_audio(file: BinaryIO, encode=True, sr: int = SAMPLE_RATE):
......
...@@ -22,7 +22,7 @@ There are 2 endpoints available: ...@@ -22,7 +22,7 @@ There are 2 endpoints available:
| Name | Values | | Name | Values |
|-----------------|------------------------------------------------| |-----------------|------------------------------------------------|
| audio_file | File | | audio_file | File |
| output | `text` (default), `json`, `vtt`, `strt`, `tsv` | | output | `text` (default), `json`, `vtt`, `srt`, `tsv` |
| task | `transcribe`, `translate` | | task | `transcribe`, `translate` |
| language | `en` (default is auto recognition) | | language | `en` (default is auto recognition) |
| word_timestamps | false (default) | | word_timestamps | false (default) |
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment