# Speechbrain_SPSoft / handler.py
# Commit dac79cf (verified), admin-spsoft:
# "Set custom endpoint handler sample rate to 16 kHz"
import base64
import io
from typing import Any
import numpy as np
import soundfile as sf
import torch
import torchaudio
# SpeechBrain 1.0.x still expects this legacy torchaudio helper.
if not hasattr(torchaudio, "list_audio_backends"):
    # Report the soundfile backend, which matches how this handler decodes audio.
    torchaudio.list_audio_backends = lambda: ["soundfile"]
# NOTE(review): this import is placed after the shim above, presumably because
# speechbrain reads the helper at import time — keep the ordering.
from speechbrain.inference.separation import SepformerSeparation
# All request audio is resampled to this rate before separation and encoding.
TARGET_SAMPLE_RATE = 16000
class EndpointHandler:
    """Inference Endpoints handler running SepFormer speech source separation.

    Accepts raw audio bytes, a base64 string, or a JSON body carrying base64
    audio, and returns each separated source as a base64-encoded WAV file at
    ``TARGET_SAMPLE_RATE`` (16 kHz).
    """

    def __init__(self, path: str = ""):
        """Load the SepFormer model from *path* (or the working directory),
        placing it on CUDA when available."""
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SepformerSeparation.from_hparams(
            source=path or ".",
            savedir=path or ".",
            run_opts={"device": device},
        )

    def __call__(self, data: Any) -> dict:
        """Separate the request audio into its constituent sources.

        Returns:
            dict with ``num_speakers`` and a ``sources`` list; each entry
            holds the speaker index, base64 WAV audio, sample rate, and
            MIME type.
        """
        audio_bytes = self._extract_audio_bytes(data)
        waveform, sample_rate = self._load_audio(audio_bytes)
        with torch.no_grad():
            # separate_batch expects (batch, time); add the batch dim here.
            est_sources = self.model.separate_batch(waveform.unsqueeze(0))
        est_sources = est_sources.squeeze(0).detach().cpu()
        if est_sources.ndim == 1:
            # Single-source edge case: normalize to (time, n_sources).
            est_sources = est_sources.unsqueeze(-1)
        outputs = []
        for idx in range(est_sources.shape[-1]):
            source = est_sources[:, idx].numpy()
            buffer = io.BytesIO()
            sf.write(buffer, source, TARGET_SAMPLE_RATE, format="WAV")
            outputs.append(
                {
                    "speaker": idx,
                    "audio_base64": base64.b64encode(buffer.getvalue()).decode("utf-8"),
                    "sample_rate": TARGET_SAMPLE_RATE,
                    "mime_type": "audio/wav",
                }
            )
        return {
            "num_speakers": len(outputs),
            "sources": outputs,
        }

    def _extract_audio_bytes(self, data: Any) -> bytes:
        """Pull raw audio bytes out of the request shapes we accept.

        Supported: raw bytes/bytearray; a base64 string; a dict whose
        ``inputs`` entry (or the dict itself) is bytes, a base64 string, or a
        nested dict keyed by ``audio`` / ``audio_base64`` / ``data`` holding
        bytes or a base64 string.

        Raises:
            ValueError: when no audio payload can be found or the base64 is
                invalid (``binascii.Error`` subclasses ``ValueError``).
        """
        if isinstance(data, (bytes, bytearray)):
            return bytes(data)
        # Generalization: accept a bare base64 string body as well.
        if isinstance(data, str):
            return self._decode_base64_audio(data)
        if isinstance(data, dict):
            payload = data.get("inputs", data)
            if isinstance(payload, (bytes, bytearray)):
                return bytes(payload)
            if isinstance(payload, str):
                return self._decode_base64_audio(payload)
            if isinstance(payload, dict):
                for key in ("audio", "audio_base64", "data"):
                    value = payload.get(key)
                    # Robustness fix: clients may send raw bytes under these
                    # keys too, not only base64 strings.
                    if isinstance(value, (bytes, bytearray)):
                        return bytes(value)
                    if isinstance(value, str):
                        return self._decode_base64_audio(value)
        raise ValueError("Unsupported request format. Send raw audio bytes or a JSON body with base64 audio.")

    def _decode_base64_audio(self, value: str) -> bytes:
        """Decode a base64 payload, tolerating a ``data:`` URI prefix."""
        if "," in value and value.startswith("data:"):
            value = value.split(",", 1)[1]
        return base64.b64decode(value)

    def _load_audio(self, audio_bytes: bytes) -> tuple[torch.Tensor, int]:
        """Decode bytes into a mono float32 waveform at TARGET_SAMPLE_RATE.

        Returns:
            (waveform, sample_rate): a 1-D tensor and ``TARGET_SAMPLE_RATE``.
        """
        waveform, sample_rate = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=True)
        waveform = torch.from_numpy(waveform.T)  # -> (channels, time)
        if waveform.shape[0] > 1:
            # Downmix multi-channel input to mono by averaging channels.
            waveform = waveform.mean(dim=0, keepdim=True)
        if sample_rate != TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(sample_rate, TARGET_SAMPLE_RATE)
            waveform = resampler(waveform)
        return waveform.squeeze(0), TARGET_SAMPLE_RATE