# audio / app.py
# Uploaded by bekzod123 (commit 2adef92: "fixes")
import csv
import datetime
import gc
import os
import re
import shutil
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import gradio as gr
import gradio.themes as gr_themes
import numpy as np
import spaces
import torch
from huggingface_hub import hf_hub_download
from nemo.collections.asr.models import ASRModel
from pydub import AudioSegment
try:
from nemo.collections.asr.models import SortformerEncLabelModel
except Exception:
SortformerEncLabelModel = None
try:
import librosa
except Exception:
librosa = None
# Run inference on the GPU when one is visible to torch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face model repo that holds the fine-tuned ASR checkpoint.
MODEL_NAME = "bekzod123/nemo_asr_2"
# Pretrained diarization model; only loaded on demand by get_diar_model().
DIAR_MODEL_NAME = "nvidia/diar_sortformer_4spk-v1"
# Download (or reuse the locally cached copy of) the .nemo archive at import time.
local_nemo_path = hf_hub_download(
    repo_id=MODEL_NAME, filename="nemo_asr_2.nemo", repo_type="model"
)
# Restore the ASR model straight onto `device` and switch it to inference mode.
model = ASRModel.restore_from(restore_path=local_nemo_path, map_location=device)
model.eval()
# Lazily populated cache for the diarization model (see get_diar_model()).
diar_model = None
def get_diar_model():
    """Return the shared Sortformer diarization model, loading it on first use.

    The model is fetched with ``from_pretrained(DIAR_MODEL_NAME)``, stored in
    the module-level ``diar_model`` cache and put in eval mode. Subsequent
    calls return the cached instance.

    Raises:
        RuntimeError: if ``SortformerEncLabelModel`` could not be imported.
    """
    global diar_model
    if diar_model is None:
        if SortformerEncLabelModel is None:
            raise RuntimeError(
                "SortformerEncLabelModel not available. Install/upgrade nemo_toolkit[asr]."
            )
        diar_model = SortformerEncLabelModel.from_pretrained(DIAR_MODEL_NAME)
        diar_model.eval()
    return diar_model
def _session_dir_for(session_hash) -> Path:
    """Map a Gradio session hash to its private scratch directory under /tmp.

    The hash arrives with the request, so treat it as untrusted. It is reduced
    to ``[A-Za-z0-9_-]`` before use as a directory name, which keeps values
    like ``"../home"`` from escaping /tmp. This matters because
    ``end_session`` deletes the directory recursively. An empty result falls
    back to ``"anonymous"``.
    """
    safe_name = re.sub(r"[^A-Za-z0-9_-]", "", str(session_hash))
    return Path("/tmp") / (safe_name or "anonymous")


def start_session(request: gr.Request):
    """Create this session's scratch directory and return it as a POSIX path."""
    session_hash = request.session_hash
    session_dir = _session_dir_for(session_hash)
    session_dir.mkdir(parents=True, exist_ok=True)
    print(f"Session with hash {session_hash} started.")
    return session_dir.as_posix()


def end_session(request: gr.Request):
    """Remove the scratch directory created by ``start_session`` (best effort)."""
    session_hash = request.session_hash
    session_dir = _session_dir_for(session_hash)
    if session_dir.is_dir():
        try:
            shutil.rmtree(session_dir)
        except OSError as rm_e:
            # Unload handlers have nobody to report to; log instead of raising.
            print(f"Error removing session directory {session_dir}: {rm_e}")
    print(f"Session with hash {session_hash} ended.")
def _try_float(v):
try:
return float(v)
except Exception:
return None
def get_audio_segment(audio_path, start_second, end_second):
    """Cut ``[start_second, end_second)`` out of an audio file for playback.

    Returns ``(frame_rate, samples)``, where ``samples`` is a mono float32
    array scaled to ``[-1, 1]``. Returns ``None`` if the path is missing, the
    clip is empty, or decoding fails. A non-positive span is widened to 100 ms.
    """
    if not audio_path or not Path(audio_path).exists():
        print(f"Warning: Audio path '{audio_path}' not found or invalid for clipping.")
        return None
    try:
        begin_ms = max(0, int(start_second * 1000))
        finish_ms = int(end_second * 1000)
        if finish_ms <= begin_ms:
            finish_ms = begin_ms + 100
        full_audio = AudioSegment.from_file(audio_path)
        clip = full_audio[begin_ms:finish_ms]
        if len(clip) <= 0:
            return None
        # Always hand Gradio float32 in [-1, 1] to avoid its int8 overflow path.
        pcm = np.array(clip.get_array_of_samples(), dtype=np.float32)
        n_channels = max(1, int(clip.channels))
        if n_channels > 1:
            # Samples are interleaved per frame; average channels to mono.
            pcm = pcm.reshape((-1, n_channels)).mean(axis=1)
        full_scale = float(1 << (8 * clip.sample_width - 1))
        if full_scale <= 0:
            full_scale = 32768.0
        pcm = np.clip(pcm / full_scale, -1.0, 1.0).astype(np.float32, copy=False)
        rate = int(clip.frame_rate or full_audio.frame_rate or 16000)
        return (rate, pcm) if pcm.size else None
    except Exception as e:
        print(
            f"Error clipping audio {audio_path} from {start_second}s to {end_second}s: {e}"
        )
        return None
def format_srt_time(seconds: float) -> str:
    """Format ``seconds`` as an SRT timestamp ``HH:MM:SS,mmm``.

    Negative values clamp to zero. Hours are not wrapped at 24, and
    milliseconds are truncated rather than rounded.
    """
    delta = datetime.timedelta(seconds=max(0.0, seconds))
    hours, rest = divmod(int(delta.total_seconds()), 3600)
    minutes, secs = divmod(rest, 60)
    millis = delta.microseconds // 1000
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
def generate_srt_content(segment_timestamps: list) -> str:
    """Render segment dicts as SRT text.

    Each dict needs ``start`` and ``end`` and may carry ``segment`` and
    ``speaker``. Cues are numbered from 1. When the speaker is anything other
    than ``"N/A"``, the cue text is prefixed with ``[speaker]``. Every cue is
    followed by a blank line.
    """
    lines = []
    for number, ts in enumerate(segment_timestamps, start=1):
        timing = f"{format_srt_time(ts['start'])} --> {format_srt_time(ts['end'])}"
        caption = ts.get("segment", "")
        who = ts.get("speaker", "N/A")
        if who != "N/A":
            caption = f"[{who}] {caption}"
        lines.extend((str(number), timing, caption, ""))
    return "\n".join(lines)
def _gaussian_kernel(radius: int, sigma: float) -> np.ndarray:
if radius <= 0:
return np.array([1.0], dtype=np.float32)
x = np.arange(-radius, radius + 1, dtype=np.float32)
sigma = max(float(sigma), 1e-6)
kernel = np.exp(-0.5 * (x / sigma) ** 2)
kernel /= np.sum(kernel)
return kernel.astype(np.float32)
def remove_dc_offset(samples: np.ndarray) -> np.ndarray:
    """Subtract the mean so the float32 signal is centred on zero (empty input passes through)."""
    signal = np.asarray(samples, dtype=np.float32)
    return signal - np.mean(signal, dtype=np.float32) if signal.size else signal
def fft_bandpass(
    samples: np.ndarray, sr: int, low_hz: float, high_hz: float
) -> np.ndarray:
    """Brick-wall band-pass: zero every FFT bin outside ``[low_hz, high_hz]``.

    ``low_hz`` is floored at 0 and ``high_hz`` is capped at Nyquist (``sr / 2``).
    If the band already spans ``0..sr/2``, the float32 input is returned as is.
    """
    signal = np.asarray(samples, dtype=np.float32)
    if signal.size == 0:
        return signal
    nyquist = sr / 2.0
    lo = max(0.0, float(low_hz))
    hi = min(float(high_hz), nyquist)
    if lo <= 0 and hi >= nyquist:
        return signal
    length = signal.shape[0]
    spectrum = np.fft.rfft(signal)
    bins = np.fft.rfftfreq(length, d=1.0 / sr)
    in_band = (bins >= lo) & (bins <= hi)
    spectrum[~in_band] = 0.0
    return np.fft.irfft(spectrum, n=length).astype(np.float32, copy=False)
def spectral_denoise(
    samples: np.ndarray,
    strength: float = 1.2,
    noise_percentile: float = 15.0,
    min_mask: float = 0.06,
) -> np.ndarray:
    """Attenuate stationary background noise with an STFT gain mask.

    For each frequency bin, the noise floor is taken as the
    ``noise_percentile``-th percentile of that bin's magnitude over time.
    With signal power ``P`` and noise power ``N``, the per-cell gain is
    ``max(P - strength*N, 0) / (max(P - strength*N, 0) + strength*N)``,
    clipped to ``[min_mask, 1]``. The original phase is kept.

    Returns the float32 input unchanged if it is empty or librosa is not
    installed.
    """
    signal = np.asarray(samples, dtype=np.float32)
    if signal.size == 0 or librosa is None:
        return signal
    frame_size, hop = 512, 128
    spec = librosa.stft(signal, n_fft=frame_size, hop_length=hop, win_length=frame_size)
    mag = np.abs(spec)
    phase = np.angle(spec)
    floor_mag = np.percentile(mag, noise_percentile, axis=1, keepdims=True)
    noise_pow = floor_mag * floor_mag
    sig_pow = mag * mag
    residual = np.maximum(sig_pow - strength * noise_pow, 0.0)
    gain = np.clip(residual / (residual + strength * noise_pow + 1e-8), min_mask, 1.0)
    denoised = librosa.istft(
        mag * gain * np.exp(1j * phase),
        hop_length=hop,
        win_length=frame_size,
        length=len(signal),
    )
    return denoised.astype(np.float32, copy=False)
def dynamic_rms_normalize(
    samples: np.ndarray,
    sample_rate: int,
    frame_ms: int = 500,
    target_rms_db: float = -20.0,
    smoothing_sigma_frames: float = 1.0,
    min_gain: float = 0.2,
    max_gain: float = 8.0,
) -> np.ndarray:
    """Apply a slowly varying gain that pulls each frame's RMS toward a target.

    The signal is cut into non-overlapping frames of ``frame_ms``. Each frame
    gets gain ``target_rms / frame_rms``, clipped to ``[min_gain, max_gain]``.
    The per-frame gains are smoothed with a Gaussian of
    ``smoothing_sigma_frames`` frames, then linearly interpolated to a
    per-sample curve anchored at the frame centres. The curve is applied and
    the result is clipped to ``[-1, 1]``.

    Args:
        samples: audio samples, converted to float32. Presumably mono and
            already scaled to [-1, 1] (the caller does this); a 1-D array is
            needed for the final per-sample multiply.
        sample_rate: samples per second, used to size the frames.
        frame_ms: frame length (and hop) in milliseconds.
        target_rms_db: target frame RMS in dB relative to amplitude 1.0.
        smoothing_sigma_frames: Gaussian sigma, measured in frames.
        min_gain / max_gain: bounds on the per-frame gain.

    Returns:
        float32 array of the same length, clipped to [-1, 1].
    """
    samples = np.asarray(samples, dtype=np.float32)
    if samples.size == 0:
        return samples
    frame_len = max(1, int(sample_rate * frame_ms / 1000))
    # Hop equals the frame length, so frames do not overlap.
    hop_len = frame_len
    # dB -> linear amplitude.
    target_rms = 10.0 ** (target_rms_db / 20.0)
    n = samples.shape[0]
    # Enough frames to cover every sample; the last one may be partial.
    num_frames = max(1, int(np.ceil(max(0, n - frame_len) / hop_len)) + 1)
    rms_values = np.zeros(num_frames, dtype=np.float32)
    for i in range(num_frames):
        start = i * hop_len
        end = min(start + frame_len, n)
        frame = samples[start:end]
        # The tiny epsilon keeps sqrt() away from an exact zero on silent frames.
        rms_values[i] = np.sqrt(np.mean(frame * frame) + 1e-12) if frame.size else 1e-6
    gains = target_rms / np.maximum(rms_values, 1e-6)
    gains = np.clip(gains, min_gain, max_gain)
    # Kernel reaches about +/-3 sigma, and always at least one frame each side.
    radius = int(max(1, round(3 * smoothing_sigma_frames)))
    kernel = _gaussian_kernel(radius, smoothing_sigma_frames)
    # Edge padding + "valid" convolution returns exactly num_frames smoothed gains.
    padded = np.pad(gains, (radius, radius), mode="edge")
    gains_smooth = np.convolve(padded, kernel, mode="valid")
    if num_frames == 1:
        gain_curve = np.full(n, gains_smooth[0], dtype=np.float32)
    else:
        # Anchor each frame's gain at its centre sample; the last (possibly
        # partial) frame's centre is clamped to the final sample.
        centers = np.minimum(np.arange(num_frames) * hop_len + (frame_len // 2), n - 1)
        gain_curve = np.interp(
            np.arange(n),
            centers,
            gains_smooth,
            left=gains_smooth[0],
            right=gains_smooth[-1],
        ).astype(np.float32)
    out = samples * gain_curve
    return np.clip(out, -1.0, 1.0).astype(np.float32, copy=False)
def soft_limiter(samples: np.ndarray, drive: float = 1.15) -> np.ndarray:
    """Soft-clip with ``tanh(drive * x)``; output is float32 bounded to [-1, 1]."""
    signal = np.asarray(samples, dtype=np.float32)
    if not signal.size:
        return signal
    return np.tanh(drive * signal).astype(np.float32, copy=False)
def preprocess_audio_for_transcription(
    audio: AudioSegment,
    target_sr: int = 16000,
    frame_ms: int = 500,
    target_rms_db: float = -20.0,
) -> AudioSegment:
    """Clean up audio before ASR and return it as 16-bit mono PCM at ``target_sr``.

    Pipeline:
    1. downmix to mono and resample;
    2. scale to [-1, 1];
    3. remove DC offset;
    4. spectral denoise;
    5. 120-3600 Hz band-pass;
    6. dynamic RMS normalization toward ``target_rms_db``;
    7. tanh soft limiter.

    Raises:
        ValueError: if the audio contains no samples.
    """
    mono = audio if audio.channels == 1 else audio.set_channels(1)
    resampled = mono if mono.frame_rate == target_sr else mono.set_frame_rate(target_sr)

    pcm = np.array(resampled.get_array_of_samples(), dtype=np.float32)
    if pcm.size == 0:
        raise ValueError("Empty audio data after loading.")
    full_scale = float(1 << (8 * resampled.sample_width - 1))
    if full_scale <= 0:
        full_scale = 32768.0

    x = np.clip(pcm / full_scale, -1.0, 1.0)
    x = remove_dc_offset(x)
    x = spectral_denoise(x, strength=1.25, noise_percentile=15.0, min_mask=0.06)
    x = fft_bandpass(x, sr=target_sr, low_hz=120.0, high_hz=3600.0)
    x = dynamic_rms_normalize(
        samples=x,
        sample_rate=target_sr,
        frame_ms=frame_ms,
        target_rms_db=target_rms_db,
        smoothing_sigma_frames=1.0,
        min_gain=0.2,
        max_gain=8.0,
    )
    x = soft_limiter(x, drive=1.10)

    as_int16 = (np.clip(x, -1.0, 1.0) * 32767.0).astype(np.int16)
    return AudioSegment(
        data=as_int16.tobytes(),
        sample_width=2,
        frame_rate=target_sr,
        channels=1,
    )
def normalize_speaker_label(label) -> str:
    """Normalize a diarizer speaker label to a ``SPEAKER_<id>``-style string.

    - Empty or whitespace-only labels become ``"SPEAKER_0"``.
    - Purely numeric labels (``"2"``, ``2``) become ``"SPEAKER_2"``.
    - Anything else is upper-cased with spaces turned into underscores,
      e.g. ``"speaker_1"`` -> ``"SPEAKER_1"`` and ``"spk 3"`` -> ``"SPK_3"``.
    """
    text = str(label).strip()
    if not text:
        return "SPEAKER_0"
    if text.isdigit():
        return f"SPEAKER_{text}"
    return text.upper().replace(" ", "_")
def _parse_rttm_line(line: str):
    """Parse one RTTM ``SPEAKER`` record into ``{"start", "end", "speaker"}``.

    Uses whitespace-split fields 3 (onset), 4 (duration) and 7 (speaker),
    counted from zero. Returns ``None`` for:
    - non-``SPEAKER`` lines;
    - lines with fewer than 8 fields;
    - unparseable numbers;
    - non-positive durations.
    """
    fields = line.strip().split()
    if len(fields) < 8 or fields[0].upper() != "SPEAKER":
        return None
    onset, duration = _try_float(fields[3]), _try_float(fields[4])
    if onset is None or duration is None or duration <= 0:
        return None
    return {
        "start": onset,
        "end": onset + duration,
        "speaker": normalize_speaker_label(fields[7]),
    }
def _parse_simple_segment_line(line: str):
    """Parse ``"start end speaker"`` text separated by commas and/or whitespace.

    Examples: ``"0.00, 1.24, 2"`` or ``"0.00 1.24 SPEAKER_2"``. Returns
    ``None`` unless there are at least three fields, both times parse, and
    ``end > start``.
    """
    fields = line.strip().replace(",", " ").split()
    if len(fields) < 3:
        return None
    begin, finish = _try_float(fields[0]), _try_float(fields[1])
    if begin is None or finish is None or finish <= begin:
        return None
    return {"start": begin, "end": finish, "speaker": normalize_speaker_label(fields[2])}
def parse_diarization_output(raw_output, audio_duration_sec=None) -> list:
    """Flatten whatever the diarizer returned into sorted speaker segments.

    Accepted shapes for ``raw_output``:
    - nested lists, tuples and dicts;
    - paths to ``.rttm`` files;
    - single- or multi-line RTTM text, or ``"start end speaker"`` text;
    - objects with ``start``/``end`` attributes.

    Each recognised segment becomes
    ``{"start": float, "end": float, "speaker": "SPEAKER_x"}``.

    If ``audio_duration_sec`` is given and the largest end time exceeds 20x
    that duration, all times are assumed to be in milliseconds and divided
    by 1000. The result is sorted by ``(start, end)``. Exact repeats (same
    speaker, same times to the millisecond) are dropped.
    """
    parsed = []

    def append_seg(start, end, speaker):
        s = _try_float(start)
        e = _try_float(end)
        if s is None or e is None or e <= s:
            return
        parsed.append(
            {"start": s, "end": e, "speaker": normalize_speaker_label(speaker)}
        )

    def parse_lines(lines):
        for line in lines:
            seg = _parse_rttm_line(line) or _parse_simple_segment_line(line)
            if seg:
                parsed.append(seg)

    def rttm_file(candidate):
        """Return ``candidate`` as a Path if it is an existing ``.rttm`` file, else None."""
        try:
            path = Path(candidate)
            if path.suffix.lower() == ".rttm" and path.is_file():
                return path
        except (OSError, ValueError):
            # pathlib re-raises errors such as ENAMETOOLONG for long strings,
            # and raises ValueError for embedded NUL bytes; neither is a file.
            pass
        return None

    def walk(obj):
        if obj is None:
            return
        if isinstance(obj, (str, Path)):
            if isinstance(obj, str) and "\n" in obj:
                # Inline multi-line text: parse it without probing the filesystem.
                parse_lines(obj.splitlines())
                return
            rttm_path = rttm_file(obj)
            if rttm_path is not None:
                with open(rttm_path, "r", encoding="utf-8") as f:
                    parse_lines(f)
            elif isinstance(obj, str):
                parse_lines([obj])
            return
        if isinstance(obj, dict):
            start = obj.get("start", obj.get("start_time", obj.get("begin")))
            end = obj.get("end", obj.get("end_time", obj.get("stop")))
            dur = obj.get("duration")
            speaker = obj.get("speaker", obj.get("speaker_id", obj.get("label", "0")))
            if end is None and start is not None and dur is not None:
                s = _try_float(start)
                d = _try_float(dur)
                if s is not None and d is not None:
                    end = s + d
            if start is not None and end is not None:
                append_seg(start, end, speaker)
            for v in obj.values():
                walk(v)
            return
        if isinstance(obj, (list, tuple)):
            # A flat (start, end, speaker, ...) record.
            if (
                len(obj) >= 3
                and _try_float(obj[0]) is not None
                and _try_float(obj[1]) is not None
            ):
                append_seg(obj[0], obj[1], obj[2])
                return
            for item in obj:
                walk(item)
            return
        if hasattr(obj, "start") and hasattr(obj, "end"):
            append_seg(
                getattr(obj, "start"), getattr(obj, "end"), getattr(obj, "speaker", "0")
            )

    walk(raw_output)

    if parsed and audio_duration_sec:
        max_end = max(seg["end"] for seg in parsed)
        # Guard for outputs expressed in milliseconds.
        if max_end > audio_duration_sec * 20:
            for seg in parsed:
                seg["start"] /= 1000.0
                seg["end"] /= 1000.0

    parsed.sort(key=lambda x: (x["start"], x["end"]))

    deduped = []
    seen = set()
    for seg in parsed:
        key = (round(seg["start"], 3), round(seg["end"], 3), seg["speaker"])
        if key not in seen:
            seen.add(key)
            deduped.append(seg)
    return deduped
def merge_adjacent_speaker_segments(segments: list, max_gap_sec: float = 0.15) -> list:
    """Fuse consecutive same-speaker segments separated by at most ``max_gap_sec``.

    Segments are compared in the given order; the function is intended for
    input already sorted by start time, such as the output of
    parse_diarization_output. Kept segments are shallow copies, so the
    input dicts are never modified. Empty or ``None`` input returns ``[]``.
    """
    if not segments:
        return []
    merged = []
    for seg in segments:
        if merged:
            prev = merged[-1]
            same_speaker = seg["speaker"] == prev["speaker"]
            if same_speaker and seg["start"] - prev["end"] <= max_gap_sec:
                prev["end"] = max(prev["end"], seg["end"])
                continue
        merged.append(seg.copy())
    return merged
def merge_consecutive_transcript_rows(rows: list) -> list:
    """Collapse runs of consecutive rows that share a speaker into one row.

    Every output row, including the first, has the same normalized shape:
    float ``start``/``end``, ``speaker`` (default ``"N/A"``) and a stripped
    ``segment``. A merged row keeps the run's first start and the largest
    end. Texts are joined with single spaces, and empty texts are skipped.
    Input rows are not modified. Empty or ``None`` input returns ``[]``.
    """
    merged = []
    for row in rows or []:
        text = (row.get("segment") or "").strip()
        speaker = row.get("speaker", "N/A")
        if merged and speaker == merged[-1]["speaker"]:
            last = merged[-1]
            last["end"] = max(last["end"], float(row["end"]))
            if text:
                last["segment"] = f"{last['segment']} {text}" if last["segment"] else text
        else:
            merged.append(
                {
                    "start": float(row["start"]),
                    "end": float(row["end"]),
                    "speaker": speaker,
                    "segment": text,
                }
            )
    return merged
def transcribe_with_segments_and_words(transcribe_path: str):
    """Run the ASR model on one file and return ``(segments, words)``.

    ``segments`` are ``{"start", "end", "speaker": "N/A", "segment"}`` rows
    built from ``timestamp["segment"]``. ``words`` are
    ``{"start", "end", "token"}`` rows built from ``timestamp["word"]``.
    Word entries may be dicts, ``(start, end, token, ...)`` sequences, or
    objects with attributes. Entries with missing times, or with
    ``end <= start``, are skipped.

    Raises:
        RuntimeError: if the model output has no usable ``timestamp["segment"]``.
    """
    output = model.transcribe([transcribe_path], timestamps=True)
    well_formed = (
        output
        and isinstance(output, list)
        and output[0]
        and hasattr(output[0], "timestamp")
        and output[0].timestamp
        and "segment" in output[0].timestamp
    )
    if not well_formed:
        raise RuntimeError("Transcription failed or unexpected output format.")
    stamps = output[0].timestamp

    def word_fields(w):
        """Extract (start, end, token) from one word entry of any supported shape."""
        if isinstance(w, dict):
            return (
                _try_float(w.get("start", w.get("start_time", w.get("begin")))),
                _try_float(w.get("end", w.get("end_time", w.get("stop")))),
                str(w.get("word", w.get("token", w.get("text", "")))).strip(),
            )
        if isinstance(w, (list, tuple)) and len(w) >= 3:
            return _try_float(w[0]), _try_float(w[1]), str(w[2]).strip()
        return (
            _try_float(getattr(w, "start", None)),
            _try_float(getattr(w, "end", None)),
            str(getattr(w, "word", getattr(w, "text", ""))).strip(),
        )

    segments = []
    for ts in stamps.get("segment", []):
        seg_start = _try_float(ts.get("start"))
        seg_end = _try_float(ts.get("end"))
        text = str(ts.get("segment", ts.get("text", ""))).strip()
        if seg_start is None or seg_end is None or seg_end <= seg_start:
            continue
        segments.append(
            {
                "start": float(seg_start),
                "end": float(seg_end),
                "speaker": "N/A",
                "segment": text,
            }
        )

    words = []
    for w in stamps.get("word", []):
        w_start, w_end, token = word_fields(w)
        if w_start is None or w_end is None or w_end <= w_start:
            continue
        words.append({"start": float(w_start), "end": float(w_end), "token": token})
    return segments, words
def transcribe_default_with_timestamps(transcribe_path: str):
    """Transcribe ``transcribe_path`` and return only the segment-level rows."""
    return transcribe_with_segments_and_words(transcribe_path)[0]
def _overlap_seconds(
a_start: float, a_end: float, b_start: float, b_end: float
) -> float:
return max(0.0, min(a_end, b_end) - max(a_start, b_start))
def _join_tokens(tokens: list) -> str:
return " ".join(t for t in tokens if t).strip()
def split_asr_by_diarization_segments(
    asr_segments: list, diar_segments: list, asr_words: list = None
) -> list:
    """Attach ASR text to diarization turns, producing one row per diar segment.

    Word-level path (preferred; used when ``asr_words`` is non-empty): each
    word goes to the single diarization segment it overlaps most, with ties
    going to the earliest segment. A word that straddles a speaker change
    therefore appears in exactly one row, not in both. Words that overlap
    no segment are dropped.

    If the word path yields no text at all, whole ASR segments are assigned
    with the same maximum-overlap rule instead.

    Returns rows ``{"start", "end", "speaker", "segment"}`` in diarization
    order (sorted by start, then end), or ``[]`` when there are no
    diarization segments.
    """
    if not diar_segments:
        return []
    diar_segments = sorted(diar_segments, key=lambda x: (x["start"], x["end"]))
    bounds = [(float(d["start"]), float(d["end"])) for d in diar_segments]

    def build_rows(buckets):
        """One row per diar segment, carrying the text collected in its bucket."""
        return [
            {
                "start": start,
                "end": end,
                "speaker": d["speaker"],
                "segment": _join_tokens(bucket),
            }
            for (start, end), d, bucket in zip(bounds, diar_segments, buckets)
        ]

    if asr_words:
        words = sorted(asr_words, key=lambda x: (x["start"], x["end"]))
        word_buckets = [[] for _ in diar_segments]
        first_live = 0
        for w in words:
            w_start, w_end = float(w["start"]), float(w["end"])
            # Words and diar segments are both sorted by start, so a segment
            # that ends before this word starts cannot overlap any later word.
            while first_live < len(bounds) and bounds[first_live][1] <= w_start:
                first_live += 1
            best_i, best_ov = -1, 0.0
            k = first_live
            # Segments starting at or after the word's end cannot overlap it.
            while k < len(bounds) and bounds[k][0] < w_end:
                ov = _overlap_seconds(w_start, w_end, bounds[k][0], bounds[k][1])
                if ov > best_ov:
                    best_i, best_ov = k, ov
                k += 1
            if best_i >= 0:
                word_buckets[best_i].append(w["token"])
        rows_word = build_rows(word_buckets)
        if any(r["segment"] for r in rows_word):
            return rows_word

    # Fallback: assign whole ASR segments by maximum overlap.
    seg_buckets = [[] for _ in diar_segments]
    for s in asr_segments or []:
        txt = (s.get("segment") or "").strip()
        if not txt:
            continue
        s_start, s_end = float(s["start"]), float(s["end"])
        best_i, best_ov = -1, 0.0
        for i, (d_start, d_end) in enumerate(bounds):
            ov = _overlap_seconds(s_start, s_end, d_start, d_end)
            if ov > best_ov:
                best_i, best_ov = i, ov
        if best_i >= 0:
            seg_buckets[best_i].append(txt)
    return build_rows(seg_buckets)
def _clean_token_spacing(text: str) -> str:
text = re.sub(r"\s+([.,!?;:])", r"\1", text)
return re.sub(r"\s+", " ", text).strip()
def _capitalize_first_alpha(text: str) -> str:
return re.sub(
r"^([^A-Za-z]*)([a-z])", lambda m: m.group(1) + m.group(2).upper(), text
)
def _capitalize_after_full_stop(text: str) -> str:
# Capitalize first latin letter after ". "
return re.sub(r"(?<=\.\s)([a-z])", lambda m: m.group(1).upper(), text)
# Uzbek cardinal number words -> value. Apostrophe-less spellings ("tort",
# "on", "ottiz", ...) sit alongside the standard ones, and the fused 11-19
# forms ("o'nbir", "onbir", ...) are listed explicitly.
UZ_CARDINAL = {
    "nol": 0,
    "bir": 1,
    "ikki": 2,
    "uch": 3,
    "to'rt": 4,
    "tort": 4,
    "besh": 5,
    "olti": 6,
    "yetti": 7,
    "sakkiz": 8,
    "to'qqiz": 9,
    "toqqiz": 9,
    "o'n": 10,
    "on": 10,
    "yigirma": 20,
    "o'ttiz": 30,
    "ottiz": 30,
    "qirq": 40,
    "ellik": 50,
    "oltmish": 60,
    "yetmish": 70,
    "sakson": 80,
    "to'qson": 90,
    "toqson": 90,
    "o'nbir": 11,
    "onbir": 11,
    "o'nikki": 12,
    "onikki": 12,
    "o'nuch": 13,
    "onuch": 13,
    "o'nto'rt": 14,
    "ontort": 14,
    "o'nbesh": 15,
    "onbesh": 15,
    "o'nolti": 16,
    "onolti": 16,
    "o'nyetti": 17,
    "onyetti": 17,
    "o'nsakkiz": 18,
    "onsakkiz": 18,
    "o'nto'qqiz": 19,
    "ontoqqiz": 19,
}
# Scale words; _parse_uz_cardinal multiplies the running group by these.
# ("yuz" = 100 is handled separately by the parsers.)
UZ_SCALES = {"ming": 1000, "million": 1_000_000, "milliard": 1_000_000_000}
# Explicit ordinal -> cardinal-word map. _parse_uz_ordinal consults it before
# falling back to stripping an "inchi"/"nchi" suffix; that stripping would
# give the wrong stem for forms such as "ikkinchi" -> "ikk".
UZ_ORDINAL_TO_CARDINAL = {
    "birinchi": "bir",
    "ikkinchi": "ikki",
    "uchinchi": "uch",
    "to'rtinchi": "to'rt",
    "tortinchi": "to'rt",
    "beshinchi": "besh",
    "oltinchi": "olti",
    "yettinchi": "yetti",
    "sakkizinchi": "sakkiz",
    "to'qqizinchi": "to'qqiz",
    "toqqizinchi": "to'qqiz",
    "o'ninchi": "o'n",
    "oninchi": "o'n",
}
# Regex alternation of Uzbek month names, interpolated into the date regex in
# normalize_uzbek_date_forms().
UZ_MONTHS_PATTERN = (
    r"yanvar|fevral|mart|aprel|may|iyun|iyul|avgust|sentabr|oktabr|noyabr|dekabr"
)
# Splits a whitespace token into (leading punctuation, core, trailing
# punctuation). The core is ASCII letters/digits plus apostrophe look-alikes
# (' ` ʻ ʼ ’ ‘) used in Uzbek Latin spelling.
_TOKEN_CORE_RE = re.compile(
    r"^([^A-Za-z0-9'`ʻʼ’‘]*)([A-Za-z0-9'`ʻʼ’‘]+)([^A-Za-z0-9'`ʻʼ’‘]*)$"
)
def _normalize_uz_word(word: str) -> str:
w = str(word).lower()
w = (
w.replace("’", "'")
.replace("‘", "'")
.replace("`", "'")
.replace("ʻ", "'")
.replace("ʼ", "'")
)
repl = {
"tort": "to'rt",
"toqqiz": "to'qqiz",
"on": "o'n",
"ottiz": "o'ttiz",
"toqson": "to'qson",
}
return repl.get(w, w)
def _is_uz_number_like(word: str) -> bool:
    """Return True if ``word`` looks like part of a number.

    Matches a known cardinal, a scale word, ``"yuz"``, a listed ordinal, or
    any word ending in "inchi"/"nchi".
    """
    if not word:
        return False
    known = (
        word in UZ_CARDINAL
        or word in UZ_SCALES
        or word == "yuz"
        or word in UZ_ORDINAL_TO_CARDINAL
    )
    return known or bool(re.match(r"^.+(?:inchi|nchi)$", word))
def _split_token(token: str):
    """Split a whitespace token into ``(prefix, normalized_core, suffix)``.

    The core is normalized with ``_normalize_uz_word``. If the core ends in
    the clitic ``"mi"`` and the remainder is number-like, the clitic moves to
    the suffix (``"yuzmi?"`` -> ``("", "yuz", "mi?")``). Tokens that do not
    match ``_TOKEN_CORE_RE`` come back as ``("", "", token)``.
    """
    match = _TOKEN_CORE_RE.match(token)
    if match is None:
        return "", "", token
    prefix, core, suffix = match.groups()
    core = _normalize_uz_word(core)
    clitic = ""
    if core.endswith("mi") and _is_uz_number_like(core[:-2]):
        core, clitic = core[:-2], "mi"
    return prefix, core, clitic + suffix
def _parse_uz_cardinal(words):
    """Parse a sequence of Uzbek number words into an int, or return ``None``.

    Accepted grammar:
    - Within each group below a scale word: ``[unit] ["yuz"] [tens] [unit]``,
      e.g. "ikki yuz ellik besh" -> 255.
    - Groups may be followed by strictly decreasing scale words,
      e.g. "bir million ikki ming" -> 1002000.
    - ``"va"`` ("and") may appear anywhere and is ignored.

    Ill-formed runs return ``None`` instead of being summed, so callers only
    rewrite real numbers. Examples: "bir ikki" (unit after unit),
    "yuz yuz", "ming ming".
    """
    total = 0
    group = 0          # value of the group not yet multiplied by a scale word
    last_scale = None  # most recent scale word's value; later ones must be smaller
    seen = False
    for raw in words:
        w = _normalize_uz_word(raw)
        if w == "va":
            continue
        if w in UZ_CARDINAL:
            value = UZ_CARDINAL[w]
            if value < 10:
                # A unit cannot follow another unit ("bir ikki").
                if group % 10:
                    return None
            elif group % 100:
                # Tens (including fused 11-19) must start a group or follow "yuz".
                return None
            group += value
        elif w == "yuz":
            # "yuz" multiplies at most a single unit ("ikki yuz") or stands alone.
            if group >= 10:
                return None
            group = (group or 1) * 100
        elif w in UZ_SCALES:
            scale = UZ_SCALES[w]
            if last_scale is not None and scale >= last_scale:
                return None
            total += (group or 1) * scale
            group = 0
            last_scale = scale
        else:
            return None
        seen = True
    return (total + group) if seen else None
def _parse_uz_ordinal(words):
    """Parse number words that end in an ordinal, e.g. "yigirma birinchi" -> 21.

    The last word is mapped back to its cardinal, either via
    ``UZ_ORDINAL_TO_CARDINAL`` or by stripping an "inchi"/"nchi" suffix.
    The whole sequence is then parsed with ``_parse_uz_cardinal``. Returns
    ``None`` if the last word is not ordinal-like or the parse fails.
    """
    if not words:
        return None
    *head, tail = [_normalize_uz_word(w) for w in words]
    if tail in UZ_ORDINAL_TO_CARDINAL:
        return _parse_uz_cardinal(head + [UZ_ORDINAL_TO_CARDINAL[tail]])
    suffix_match = re.match(r"^(.+?)(?:inchi|nchi)$", tail)
    if suffix_match is None:
        return None
    return _parse_uz_cardinal(head + [_normalize_uz_word(suffix_match.group(1))])
def normalize_uzbek_numbers_in_text(text: str) -> str:
    """Replace spelled-out Uzbek numbers in ``text`` with digits.

    Tokens are split on single spaces. Starting at each token, the longest
    run of adjacent number words is found that parses as a cardinal
    (-> ``"250"``) or an ordinal (-> ``"21-chi"``), and the run is replaced.
    The first token's leading punctuation and the last token's trailing
    punctuation are kept.

    A run stops at:
    - punctuation between words (on either side of the gap);
    - a word that cannot continue a number.

    A run never starts or ends on the conjunction ``"va"``, so "olma va bir"
    keeps its "va".
    """
    if not text:
        return text
    tokens = text.split(" ")
    # _split_token is pure, so split every token once up front.
    pieces = [_split_token(tok) for tok in tokens]
    count = len(tokens)
    out = []
    i = 0
    while i < count:
        first_prefix, first_core, _ = pieces[i]
        if not first_core or first_core == "va":
            out.append(tokens[i])
            i += 1
            continue
        best_text = None
        best_end = -1
        words = []
        j = i
        while j < count:
            prefix_j, core_j, suffix_j = pieces[j]
            # Leading punctuation on a later word (e.g. "(") ends the run, so
            # that punctuation is not lost in the replacement.
            if not core_j or (j > i and prefix_j):
                break
            words.append(core_j)
            if core_j != "va":
                card = _parse_uz_cardinal(words)
                if card is not None:
                    best_text, best_end = str(card), j
                ordinal = _parse_uz_ordinal(words)
                if ordinal is not None:
                    best_text, best_end = f"{ordinal}-chi", j
            # Only cardinal components (and "va") can be followed by more
            # number words; anything else cannot parse when extended.
            continues_number = (
                core_j in UZ_CARDINAL or core_j in UZ_SCALES or core_j in ("yuz", "va")
            )
            if suffix_j or not continues_number:
                break
            j += 1
        if best_end < i:
            out.append(tokens[i])
            i += 1
            continue
        out.append(f"{first_prefix}{best_text}{pieces[best_end][2]}")
        i = best_end + 1
    return " ".join(out)
def normalize_uzbek_date_forms(text: str) -> str:
    """Attach digits to the month or year word they belong to.

    - "5 may" and "5-chi may" -> "5-may" (the month's original case is kept).
    - "2024-chi yil" and "2024 chi yil" -> "2024-yil".
    - "2024 yil" -> "2024-yil" (this bare form needs 3-4 digits).

    All matching is case-insensitive.
    """
    month_rule = rf"\b(\d+)(?:-chi)?\s+({UZ_MONTHS_PATTERN})\b"
    text = re.sub(
        month_rule, lambda m: "-".join(m.groups()), text, flags=re.IGNORECASE
    )
    year_rules = (
        r"\b(\d+)\s*-\s*chi\s+yil\b",
        r"\b(\d+)\s+chi\s+yil\b",
        r"\b(\d{3,4})\s+yil\b",
    )
    for rule in year_rules:
        text = re.sub(rule, r"\1-yil", text, flags=re.IGNORECASE)
    return text
def postprocess_segment_texts(
    segment_timestamps: list, diarization_enabled: bool
) -> list:
    """Clean each row's ``segment`` text in place and return the same list.

    Steps, in order:
    1. whitespace/punctuation spacing cleanup;
    2. Uzbek number normalization;
    3. date-form normalization;
    4. only with diarization: capitalize each row's first letter and any
       letter after ". ".
    """
    steps = [
        _clean_token_spacing,
        normalize_uzbek_numbers_in_text,
        normalize_uzbek_date_forms,
    ]
    if diarization_enabled:
        steps += [_capitalize_first_alpha, _capitalize_after_full_stop]
    for row in segment_timestamps:
        txt = str(row.get("segment", "") or "")
        for step in steps:
            txt = step(txt)
        row["segment"] = txt
    return segment_timestamps
def resolve_player_audio_path(prepared_path, fallback_path: str) -> str:
    """Pick the audio source for the player.

    Returns ``prepared_path`` as a POSIX string if it exists, otherwise
    ``fallback_path``. Any error while probing ``prepared_path`` also falls
    back.
    """
    chosen = fallback_path
    try:
        if prepared_path and Path(prepared_path).exists():
            chosen = Path(prepared_path).as_posix()
    except Exception:
        chosen = fallback_path
    return chosen
def _download_button(label: str, file_path=None):
    """DownloadButton update: visible and bound to ``file_path`` if given, hidden otherwise."""
    if file_path is None:
        return gr.DownloadButton(label=label, visible=False)
    return gr.DownloadButton(value=file_path, visible=True, label=label)


def _export_wav(segment, path):
    """Export ``segment`` as WAV to ``path`` and close the file handle pydub returns."""
    handle = segment.export(path, format="wav")
    if handle is not None:
        handle.close()


def _export_playback_copy(audio, session_dir, audio_name):
    """Write a WAV copy of ``audio`` to the session dir for playback; return its Path or None."""
    try:
        playback_path = Path(session_dir, f"{audio_name}_playback.wav")
        _export_wav(audio, playback_path)
        return playback_path
    except Exception as playback_e:
        gr.Warning(f"Could not prepare playback audio: {playback_e}", duration=5)
        return None


def _prepare_asr_input(
    audio, audio_path, session_dir, audio_name, original_path_name, use_preprocessing
):
    """Choose the file to transcribe.

    Returns ``(transcribe_path, display_name, temp_path_or_None)``. With
    preprocessing on, this writes ``<audio_name>_asr_preprocessed.wav`` to
    ``session_dir``. If preprocessing fails it warns and falls back to the
    original file; any partially written temp file is still returned so the
    caller can delete it.
    """
    if not use_preprocessing:
        gr.Info("Preprocessing disabled. Using original audio.", duration=2)
        return audio_path, original_path_name, None
    processed_path = None
    try:
        gr.Info(
            "Preprocessing enabled: mono + denoise + phone-band + dynamic RMS + 16kHz...",
            duration=3,
        )
        processed = preprocess_audio_for_transcription(
            audio=audio, target_sr=16000, frame_ms=500, target_rms_db=-20.0
        )
        processed_path = Path(session_dir, f"{audio_name}_asr_preprocessed.wav")
        _export_wav(processed, processed_path)
        return (
            processed_path.as_posix(),
            f"{original_path_name} (preprocessed)",
            processed_path,
        )
    except Exception as preprocess_e:
        gr.Warning(
            f"Preprocessing failed ({preprocess_e}). Falling back to original audio.",
            duration=6,
        )
        return audio_path, original_path_name, processed_path


def _apply_long_audio_settings() -> bool:
    """Switch the ASR model to local attention and subsampling chunking.

    Returns True once the attention model has been changed, even if the
    chunking call afterwards fails, so the caller knows to revert it.
    """
    changed = False
    try:
        gr.Info(
            "Audio longer than 8 minutes. Applying long audio settings.",
            duration=3,
        )
        model.change_attention_model("rel_pos_local_attn", [256, 256])
        changed = True
        model.change_subsampling_conv_chunking_factor(1)
    except Exception as setting_e:
        gr.Warning(f"Could not apply long audio settings: {setting_e}", duration=5)
    return changed


def _transcribe_rows(transcribe_path, diar_input_path, duration_sec, use_diarization):
    """Return segment rows from ASR alone, or ASR text split by diarization turns.

    Diarization runs in a worker thread while ASR runs on the calling thread.
    ASR errors propagate to the caller. Diarization errors only produce a
    warning, and the already computed ASR segmentation is used instead.
    """
    if not use_diarization:
        return transcribe_default_with_timestamps(transcribe_path)

    gr.Info("Running ASR and diarization in parallel...", duration=3)
    try:
        dmodel = get_diar_model()
        dmodel.to(device)
        dmodel.to(torch.float32)
    except Exception as diar_e:
        gr.Warning(
            f"Diarization failed: {diar_e}. Using standard ASR segmentation.",
            duration=7,
        )
        return transcribe_default_with_timestamps(transcribe_path)

    def run_diar():
        try:
            diar_output_local = dmodel.diarize(audio=diar_input_path, batch_size=1)
        except TypeError:
            # Some NeMo versions only accept a list of audio paths.
            diar_output_local = dmodel.diarize(audio=[diar_input_path], batch_size=1)
        diar_segments_local = parse_diarization_output(
            diar_output_local, audio_duration_sec=duration_sec
        )
        diar_segments_local = merge_adjacent_speaker_segments(
            diar_segments_local, max_gap_sec=0.15
        )
        return diar_segments_local, diar_output_local

    with ThreadPoolExecutor(max_workers=1) as pool:
        diar_future = pool.submit(run_diar)
        asr_segments, asr_words = transcribe_with_segments_and_words(transcribe_path)
        try:
            diar_segments, diar_output = diar_future.result()
        except Exception as diar_e:
            gr.Warning(
                f"Diarization failed: {diar_e}. Using standard ASR segmentation.",
                duration=7,
            )
            return asr_segments

    if not diar_segments:
        gr.Warning(
            f"Diarization parsed no segments. Using ASR segmentation. raw_type={type(diar_output)}",
            duration=7,
        )
        rows = asr_segments
    else:
        try:
            # Split text by diarization segments, then merge consecutive same-speaker rows.
            rows = split_asr_by_diarization_segments(
                asr_segments=asr_segments,
                diar_segments=diar_segments,
                asr_words=asr_words,
            )
            rows = merge_consecutive_transcript_rows(rows)
        except Exception as align_e:
            gr.Warning(
                f"Diarization failed: {align_e}. Using standard ASR segmentation.",
                duration=7,
            )
            return asr_segments
        if not rows:
            gr.Warning("No aligned diarized rows. Using ASR segmentation.", duration=7)
            rows = asr_segments
    gr.Info("Diarization + ASR complete.", duration=2)
    return rows


def _write_transcript_files(session_dir, audio_name, vis_data, segment_timestamps):
    """Write the CSV transcript (always) and the SRT transcript (if there are rows).

    Returns ``(csv_button, srt_button)`` updates. A button stays hidden if its
    file could not be written.
    """
    csv_button = _download_button("Download Transcript (CSV)")
    srt_button = _download_button("Download Transcript (SRT)")
    try:
        csv_file_path = Path(session_dir, f"transcription_{audio_name}.csv")
        with open(csv_file_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["Start (s)", "End (s)", "Speaker", "Segment"])
            writer.writerows(vis_data)
        csv_button = _download_button("Download Transcript (CSV)", csv_file_path)
    except Exception as csv_e:
        gr.Warning(f"Failed to create transcript CSV file: {csv_e}", duration=None)
    if segment_timestamps:
        try:
            srt_file_path = Path(session_dir, f"transcription_{audio_name}.srt")
            with open(srt_file_path, "w", encoding="utf-8") as f:
                f.write(generate_srt_content(segment_timestamps))
            srt_button = _download_button("Download Transcript (SRT)", srt_file_path)
        except Exception as srt_e:
            gr.Warning(f"Failed to create transcript SRT file: {srt_e}", duration=5)
    return csv_button, srt_button


def _release_models(long_audio_settings_applied):
    """Undo long-audio settings, move models off the GPU and free cached memory (best effort)."""
    try:
        if long_audio_settings_applied:
            try:
                model.change_attention_model("rel_pos")
                model.change_subsampling_conv_chunking_factor(-1)
            except Exception as revert_e:
                gr.Warning(f"Issue reverting model settings: {revert_e}", duration=5)
        if device == "cuda":
            model.cpu()
            if diar_model is not None:
                diar_model.cpu()
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
    except Exception as cleanup_e:
        gr.Warning(f"Issue during model cleanup: {cleanup_e}", duration=5)


def _remove_temp_file(tmp_path):
    """Delete an intermediate audio file if it exists; failures are only printed."""
    if tmp_path and os.path.exists(tmp_path):
        try:
            os.remove(tmp_path)
        except Exception as e:
            print(f"Error removing temporary audio file {tmp_path}: {e}")


@spaces.GPU
def get_transcripts_and_raw_times(
    audio_path, session_dir, use_preprocessing=True, use_diarization=False
):
    """Transcribe one audio file, with optional preprocessing and diarization, for the UI.

    Returns a 5-tuple matching the Gradio outputs:
    1. table rows ``[start, end, speaker, text]``;
    2. raw ``[start, end]`` pairs used for segment playback;
    3. the audio path for the player;
    4. the CSV download-button update;
    5. the SRT download-button update.

    On failure the table holds a single error row. Problems are reported with
    ``gr.Warning``; ``gr.Error`` would only take effect if raised, and raising
    would discard these output updates.
    """
    csv_label = "Download Transcript (CSV)"
    srt_label = "Download Transcript (SRT)"
    if not audio_path:
        gr.Warning("No audio file path provided for transcription.", duration=None)
        return [], [], None, _download_button(csv_label), _download_button(srt_label)

    original_path_name = Path(audio_path).name
    audio_name = Path(audio_path).stem
    csv_button_update = _download_button(csv_label)
    srt_button_update = _download_button(srt_label)
    playback_audio_path = None
    processed_audio_path = None
    transcribe_path = audio_path
    long_audio_settings_applied = False

    def failure_outputs(marker, message):
        """Outputs for a failed run: one error row plus whatever UI state exists so far."""
        return (
            [[marker, marker, "N/A", message]],
            [[0.0, 0.0]],
            resolve_player_audio_path(playback_audio_path, audio_path),
            csv_button_update,
            srt_button_update,
        )

    try:
        gr.Info(f"Loading audio: {original_path_name}", duration=2)
        audio = AudioSegment.from_file(audio_path)
        duration_sec = audio.duration_seconds
    except Exception as load_e:
        gr.Warning(
            f"Failed to load audio file {original_path_name}: {load_e}", duration=None
        )
        return failure_outputs("Error", f"Failed to load audio: {load_e}")

    try:
        # Keep a stable playback copy in the session dir.
        playback_audio_path = _export_playback_copy(audio, session_dir, audio_name)
        transcribe_path, info_path_name, processed_audio_path = _prepare_asr_input(
            audio,
            audio_path,
            session_dir,
            audio_name,
            original_path_name,
            use_preprocessing,
        )
        try:
            model.to(device)
            model.to(torch.float32)
            gr.Info(f"Transcribing {info_path_name} on {device}...", duration=2)
            if duration_sec > 480:
                long_audio_settings_applied = _apply_long_audio_settings()
            if device == "cuda":
                model.to(torch.bfloat16)

            # Diarization runs on the original (unpreprocessed) audio.
            segment_timestamps = _transcribe_rows(
                transcribe_path, audio_path, duration_sec, use_diarization
            )
            segment_timestamps = postprocess_segment_texts(
                segment_timestamps,
                diarization_enabled=use_diarization,
            )
            vis_data = [
                [
                    round(float(ts["start"]), 2),
                    round(float(ts["end"]), 2),
                    ts.get("speaker", "N/A"),
                    ts.get("segment", ""),
                ]
                for ts in segment_timestamps
            ]
            raw_times_data = [
                [float(ts["start"]), float(ts["end"])] for ts in segment_timestamps
            ]
            csv_button_update, srt_button_update = _write_transcript_files(
                session_dir, audio_name, vis_data, segment_timestamps
            )
            gr.Info("Transcription complete.", duration=2)
            return (
                vis_data,
                raw_times_data,
                resolve_player_audio_path(playback_audio_path, audio_path),
                csv_button_update,
                srt_button_update,
            )
        except torch.cuda.OutOfMemoryError:
            error_msg = "CUDA out of memory. Try shorter audio or reduce GPU load."
            gr.Warning(error_msg, duration=None)
            return failure_outputs("OOM", error_msg)
        except FileNotFoundError:
            gr.Warning(
                f"Audio file not found for transcription: {Path(transcribe_path).name}",
                duration=None,
            )
            return failure_outputs("Error", "File not found for transcription")
        except Exception as e:
            gr.Warning(f"Transcription failed: {e}", duration=None)
            return failure_outputs("Error", f"Transcription failed: {e}")
        finally:
            _release_models(long_audio_settings_applied)
    finally:
        _remove_temp_file(processed_audio_path)
def play_segment(raw_ts_list, current_audio_path, evt: gr.SelectData):
    """Clip and autoplay the audio for the table row the user selected.

    ``raw_ts_list`` holds ``[start, end]`` pairs aligned with the table rows.
    The player is cleared instead in any of these cases:
    - no audio is loaded;
    - the selected index is missing or out of range;
    - the row is malformed or its span is empty;
    - clipping fails.
    """

    def cleared():
        return gr.update(value=None, label="Selected Segment")

    if not isinstance(raw_ts_list, list) or not current_audio_path:
        return cleared()
    if evt is None or evt.index is None:
        return cleared()
    index = evt.index
    if isinstance(index, (list, tuple)):
        if not index:
            return cleared()
        index = index[0]
    row_idx = int(index)
    if not 0 <= row_idx < len(raw_ts_list):
        return cleared()
    row = raw_ts_list[row_idx]
    if not isinstance(row, (list, tuple)) or len(row) != 2:
        return cleared()
    start_s, end_s = (_try_float(v) for v in row)
    if start_s is None or end_s is None or end_s <= start_s:
        return cleared()
    clip = get_audio_segment(current_audio_path, start_s, end_s)
    if not clip:
        return cleared()
    return gr.update(
        value=clip,
        autoplay=True,
        label=f"Segment: {start_s:.2f}s - {end_s:.2f}s",
    )
# HTML blurb shown under the title. Capitalization is described as
# diarization-only because postprocess_segment_texts only applies it then.
article = (
    "<p style='font-size:1.1em;'>Optional preprocessing and optional speaker diarization are supported.</p>"
    "<ul style='font-size:1.1em;'>"
    "<li>Preprocessing (optional): mono, denoise, bandpass, RMS normalize, 16kHz</li>"
    "<li>Diarization (optional): nvidia/diar_sortformer_4spk-v1</li>"
    "<li>ASR and diarization run in parallel when diarization is enabled</li>"
    "<li>Rows are split by diarization segments; consecutive same-speaker rows are merged</li>"
    "<li>Post-processing: Uzbek number/date normalization; sentence capitalization when diarization is enabled</li>"
    "</ul>"
)
# Example inputs offered under the upload widget (paths relative to the app directory).
examples = [["data/example-yt_saTD1u8PorI.mp3"]]
# Primary blue palette, from lightest (c50) to darkest (c950).
_PRIMARY_BLUE_SHADES = {
    "c50": "#E6ECF7",
    "c100": "#CCD9EF",
    "c200": "#99B3DF",
    "c300": "#668DCC",
    "c400": "#3366B3",
    "c500": "#003399",
    "c600": "#002E8A",
    "c700": "#00246D",
    "c800": "#001A51",
    "c900": "#001238",
    "c950": "#000B24",
}
# Default Gradio theme with the blue palette above, gray neutrals and Inter font.
nvidia_theme = gr_themes.Default(
    primary_hue=gr_themes.Color(**_PRIMARY_BLUE_SHADES),
    neutral_hue="gray",
    font=[gr_themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
).set()
# Build the Gradio UI. Component creation order inside the Blocks context
# determines the page layout.
with gr.Blocks(theme=nvidia_theme) as demo:
    # Short model name (text after the last "/") for the page title.
    model_display_name = MODEL_NAME.split("/")[-1] if "/" in MODEL_NAME else MODEL_NAME
    gr.Markdown(
        f"<h1 style='text-align:center;margin:0 auto;'>Speech Transcription with {model_display_name}</h1>"
    )
    gr.HTML(article)
    # Per-session state:
    # - current_audio_path_state: audio path handed back by the transcribe handler;
    # - raw_timestamps_list_state: [start, end] pairs that play_segment reads.
    current_audio_path_state = gr.State(None)
    raw_timestamps_list_state = gr.State([])
    # Scratch directory for this session; filled by start_session on page load.
    session_dir = gr.State()
    demo.load(start_session, outputs=[session_dir])
    with gr.Row():
        use_preprocessing = gr.Checkbox(label="Enable preprocessing", value=True)
        use_diarization = gr.Checkbox(
            label="Enable speaker diarization (nvidia/diar_sortformer_4spk-v1)",
            value=False,
        )
    with gr.Tabs():
        with gr.TabItem("Audio File"):
            file_input = gr.Audio(
                sources=["upload"], type="filepath", label="Upload Audio File"
            )
            gr.Examples(
                examples=examples, inputs=[file_input], label="Example Audio Files"
            )
            file_transcribe_btn = gr.Button(
                "Transcribe Uploaded File", variant="primary"
            )
        with gr.TabItem("Microphone"):
            mic_input = gr.Audio(
                sources=["microphone"], type="filepath", label="Record Audio"
            )
            mic_transcribe_btn = gr.Button(
                "Transcribe Microphone Input", variant="primary"
            )
    gr.Markdown("---")
    with gr.Row():
        # Hidden until the handler returns updates that point at written files.
        download_btn_csv = gr.DownloadButton(
            label="Download Transcript (CSV)", visible=False
        )
        download_btn_srt = gr.DownloadButton(
            label="Download Transcript (SRT)", visible=False
        )
    vis_timestamps_df = gr.DataFrame(
        headers=["Start (s)", "End (s)", "Speaker", "Segment"],
        datatype=["number", "number", "str", "str"],
        wrap=True,
        label="Transcription Segments",
    )
    selected_segment_player = gr.Audio(label="Selected Segment", interactive=False)
    # Both transcribe buttons share one handler; the outputs list maps onto its
    # 5-tuple return value in order.
    mic_transcribe_btn.click(
        fn=get_transcripts_and_raw_times,
        inputs=[mic_input, session_dir, use_preprocessing, use_diarization],
        outputs=[
            vis_timestamps_df,
            raw_timestamps_list_state,
            current_audio_path_state,
            download_btn_csv,
            download_btn_srt,
        ],
        api_name="transcribe_mic",
    )
    file_transcribe_btn.click(
        fn=get_transcripts_and_raw_times,
        inputs=[file_input, session_dir, use_preprocessing, use_diarization],
        outputs=[
            vis_timestamps_df,
            raw_timestamps_list_state,
            current_audio_path_state,
            download_btn_csv,
            download_btn_srt,
        ],
        api_name="transcribe_file",
    )
    # Selecting a table row clips that row's time span and plays it.
    vis_timestamps_df.select(
        fn=play_segment,
        inputs=[raw_timestamps_list_state, current_audio_path_state],
        outputs=[selected_segment_player],
    )
    # Remove the session scratch directory when Gradio unloads the session.
    demo.unload(end_session)
if __name__ == "__main__":
    # Enable request queuing, then start the Gradio server.
    print("Launching Gradio Demo...")
    demo.queue().launch()