Automatic Speech Recognition
Transformers
Safetensors
English
multilingual
whisper
audio
captioning
audio-captioning
speech
voice
timbre
emotion
BUD-E-Whisper_V1.21 / inference.py
ChristophSchuhmann's picture
Upload BUD-E-Whisper V1.21 — emotion-balanced fine-tune
54bec88 verified
"""
BUD-E-Whisper V1.2 — Audio Captioning Inference

Generates detailed temporal captions from audio files. Input audio is
converted to 16 kHz mono and truncated to its first 30 seconds.

Usage:
    python inference.py audio.wav
    python inference.py audio.mp3 --device cuda
    python inference.py audio.flac --max_length 448
"""
import argparse

import numpy as np
import torch
import torchaudio
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# Model identifier passed to ``from_pretrained`` for both processor and model.
MODEL_ID = "laion/BUD-E-Whisper_V1.21"
# Sample rate (Hz) audio is resampled to and reported to the feature extractor.
TARGET_SR = 16_000
# Longest clip, in seconds, that ``load_audio`` keeps.
MAX_AUDIO_SECONDS = 30
def load_audio(path: str) -> np.ndarray:
    """Load an audio file as a mono waveform at TARGET_SR, capped in length.

    Args:
        path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        A NumPy array with the channel axis squeezed away, holding at most
        ``TARGET_SR * MAX_AUDIO_SECONDS`` samples at ``TARGET_SR`` Hz.
    """
    wav, sr = torchaudio.load(path)

    # Drop audio far beyond the cap *before* the expensive steps, so very long
    # recordings are not down-mixed and resampled in full. One extra second is
    # kept as margin so the resampler still sees real audio around the final
    # cut point; the exact cap is applied after resampling.
    src_limit = sr * (MAX_AUDIO_SECONDS + 1)
    if wav.shape[1] > src_limit:
        wav = wav[:, :src_limit]

    # Mono: average the channels, keeping the (1, samples) layout.
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)

    # Resample only when the file is not already at the target rate.
    if sr != TARGET_SR:
        wav = torchaudio.functional.resample(wav, sr, TARGET_SR)

    # Cap duration (slicing is a no-op for clips already short enough).
    max_samples = TARGET_SR * MAX_AUDIO_SECONDS
    wav = wav[:, :max_samples]

    return wav.squeeze(0).numpy()
def caption(audio_path: str, device: str = "cuda", max_length: int = 448) -> str:
    """Produce a detailed text caption for one audio file.

    Args:
        audio_path: Path of the audio file, handed to ``load_audio``.
        device: Torch device the model and the input features are moved to.
        max_length: Forwarded to ``model.generate`` as ``max_length``.

    Returns:
        The first decoded sequence, with special tokens stripped.
    """
    # Processor and model are loaded from MODEL_ID on every call.
    processor = WhisperProcessor.from_pretrained(MODEL_ID)
    model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID)
    # Clear any forced decoder ids stored in the generation config.
    model.generation_config.forced_decoder_ids = None
    model.eval()
    model.to(device)

    waveform = load_audio(audio_path)
    features = processor.feature_extractor(
        waveform, sampling_rate=TARGET_SR, return_tensors="pt"
    )
    features = features.to(device)

    # Inference only: no gradients are needed for generation.
    with torch.no_grad():
        token_ids = model.generate(**features, max_length=max_length)

    decoded = processor.tokenizer.batch_decode(
        token_ids, skip_special_tokens=True
    )
    return decoded[0]
if __name__ == "__main__":
    # Default to the GPU when CUDA is available, otherwise fall back to CPU.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    cli = argparse.ArgumentParser(description="BUD-E-Whisper V1.2 Audio Captioning")
    cli.add_argument("audio", help="Path to audio file (wav, mp3, flac, etc.)")
    cli.add_argument("--device", default=default_device)
    cli.add_argument("--max_length", type=int, default=448)
    opts = cli.parse_args()

    # Caption the requested file and write the text to stdout.
    print(caption(opts.audio, device=opts.device, max_length=opts.max_length))