Select Git revision
utils.py 4.40 KiB
import json
import os
from dataclasses import asdict
from typing import TextIO, BinaryIO
import ffmpeg
import numpy as np
from faster_whisper.utils import format_timestamp
from app.config import CONFIG
class ResultWriter:
extension: str
def __init__(self, output_dir: str):
self.output_dir = output_dir
def __call__(self, result: dict, audio_path: str):
audio_basename = os.path.basename(audio_path)
output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)
with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f)
def write_result(self, result: dict, file: TextIO):
raise NotImplementedError
class WriteTXT(ResultWriter):
extension: str = "txt"
def write_result(self, result: dict, file: TextIO):
for segment in result["segments"]:
print(segment.text.strip(), file=file, flush=True)
class WriteVTT(ResultWriter):
extension: str = "vtt"
def write_result(self, result: dict, file: TextIO):
print("WEBVTT\n", file=file)
for segment in result["segments"]:
print(
f"{format_timestamp(segment.start)} --> {format_timestamp(segment.end)}\n"
f"{segment.text.strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
class WriteSRT(ResultWriter):
extension: str = "srt"
def write_result(self, result: dict, file: TextIO):
for i, segment in enumerate(result["segments"], start=1):
# write srt lines
print(
f"{i}\n"
f"{format_timestamp(segment.start, always_include_hours=True, decimal_marker=',')} --> "
f"{format_timestamp(segment.end, always_include_hours=True, decimal_marker=',')}\n"
f"{segment.text.strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
class WriteTSV(ResultWriter):
"""
Write a transcript to a file in TSV (tab-separated values) format containing lines like:
<start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
Using integer milliseconds as start and end times means there's no chance of interference from
an environment setting a language encoding that causes the decimal in a floating point number
to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
"""
extension: str = "tsv"
def write_result(self, result: dict, file: TextIO):
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment.start), file=file, end="\t")
print(round(1000 * segment.end), file=file, end="\t")
print(segment.text.strip().replace("\t", " "), file=file, flush=True)
class WriteJSON(ResultWriter):
extension: str = "json"
def write_result(self, result: dict, file: TextIO):
if "segments" in result:
result["segments"] = [asdict(segment) for segment in result["segments"]]
json.dump(result, file)
def load_audio(file: BinaryIO, encode=True, sr: int = CONFIG.SAMPLE_RATE):
"""
Open an audio file object and read as mono waveform, resampling as necessary.
Modified from https://github.com/openai/whisper/blob/main/whisper/audio.py to accept a file object
Parameters
----------
file: BinaryIO
The audio file like object
encode: Boolean
If true, encode audio stream to WAV before sending to whisper
sr: int
The sample rate to resample the audio if necessary
Returns
-------
A NumPy array containing the audio waveform, in float32 dtype.
"""
if encode:
try:
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
out, _ = (
ffmpeg.input("pipe:", threads=0)
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
.run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True, input=file.read())
)
except ffmpeg.Error as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
else:
out = file.read()
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0