"""Gradio ASR demo: transcription with optional preprocessing and diarization."""

import csv
import datetime
import gc
import os
import re
import shutil
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import gradio as gr
import gradio.themes as gr_themes
import numpy as np
import spaces
import torch
from huggingface_hub import hf_hub_download
from nemo.collections.asr.models import ASRModel
from pydub import AudioSegment

# Optional dependencies: the app degrades gracefully when they are missing.
try:
    from nemo.collections.asr.models import SortformerEncLabelModel
except Exception:
    SortformerEncLabelModel = None

try:
    import librosa
except Exception:
    librosa = None

device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "bekzod123/nemo_asr_2"
DIAR_MODEL_NAME = "nvidia/diar_sortformer_4spk-v1"

local_nemo_path = hf_hub_download(
    repo_id=MODEL_NAME, filename="nemo_asr_2.nemo", repo_type="model"
)
model = ASRModel.restore_from(restore_path=local_nemo_path, map_location=device)
model.eval()

# Loaded lazily by get_diar_model() the first time diarization is requested.
diar_model = None


def get_diar_model():
    """Return the shared diarization model, loading it on first use.

    Raises:
        RuntimeError: If ``SortformerEncLabelModel`` could not be imported.
    """
    global diar_model
    if diar_model is not None:
        return diar_model
    if SortformerEncLabelModel is None:
        raise RuntimeError(
            "SortformerEncLabelModel not available. Install/upgrade nemo_toolkit[asr]."
        )
    diar_model = SortformerEncLabelModel.from_pretrained(DIAR_MODEL_NAME)
    diar_model.eval()
    return diar_model


def start_session(request: gr.Request):
    """Create the per-session scratch directory and return its POSIX path."""
    session_hash = request.session_hash
    session_dir = Path(f"/tmp/{session_hash}")
    session_dir.mkdir(parents=True, exist_ok=True)
    print(f"Session with hash {session_hash} started.")
    return session_dir.as_posix()


def end_session(request: gr.Request):
    """Delete the per-session scratch directory if it exists."""
    session_hash = request.session_hash
    session_dir = Path(f"/tmp/{session_hash}")
    if session_dir.exists():
        shutil.rmtree(session_dir)
    print(f"Session with hash {session_hash} ended.")


def _try_float(v):
    """Return ``float(v)``, or ``None`` when the conversion fails."""
    try:
        return float(v)
    except Exception:
        return None


def get_audio_segment(audio_path, start_second, end_second):
    """Clip ``audio_path`` to ``[start_second, end_second]``.

    Returns:
        ``(frame_rate, samples)`` where ``samples`` is a mono float32 array
        scaled to [-1, 1], or ``None`` if the file is missing, the clip is
        empty, or decoding fails.
    """
    if not audio_path or not Path(audio_path).exists():
        print(f"Warning: Audio path '{audio_path}' not found or invalid for clipping.")
        return None
    try:
        start_ms = max(0, int(start_second * 1000))
        end_ms = int(end_second * 1000)
        if end_ms <= start_ms:
            # Degenerate range: fall back to a 100 ms window.
            end_ms = start_ms + 100

        audio = AudioSegment.from_file(audio_path)
        clipped_audio = audio[start_ms:end_ms]
        if len(clipped_audio) <= 0:
            return None

        # Always return float32 [-1, 1] to avoid Gradio int8 overflow path.
        samples = np.array(clipped_audio.get_array_of_samples(), dtype=np.float32)
        channels = max(1, int(clipped_audio.channels))
        if channels > 1:
            # Samples are interleaved per channel; average them down to mono.
            samples = samples.reshape((-1, channels)).mean(axis=1)

        max_abs = float(1 << (8 * clipped_audio.sample_width - 1))
        if max_abs <= 0:
            max_abs = 32768.0
        samples = np.clip(samples / max_abs, -1.0, 1.0).astype(np.float32, copy=False)

        frame_rate = int(clipped_audio.frame_rate or audio.frame_rate or 16000)
        if samples.size == 0:
            return None
        return frame_rate, samples
    except Exception as e:
        print(
            f"Error clipping audio {audio_path} from {start_second}s to {end_second}s: {e}"
        )
        return None


def format_srt_time(seconds: float) -> str:
    """Format a duration in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    Negative inputs are clamped to zero.
    """
    sanitized_total_seconds = max(0.0, seconds)
    delta = datetime.timedelta(seconds=sanitized_total_seconds)
    total_int_seconds = int(delta.total_seconds())
    hours, remainder = divmod(total_int_seconds, 3600)
    minutes, seconds_part = divmod(remainder, 60)
    milliseconds = delta.microseconds // 1000
    return f"{hours:02d}:{minutes:02d}:{seconds_part:02d},{milliseconds:03d}"


def generate_srt_content(segment_timestamps: list) -> str:
    """Render segment dicts (``start``, ``end``, ``segment``, ``speaker``) as SRT text.

    Rows whose speaker is not ``"N/A"`` get a ``[speaker]`` prefix.
    """
    srt_content = []
    for index, ts in enumerate(segment_timestamps, start=1):
        start_time = format_srt_time(ts["start"])
        end_time = format_srt_time(ts["end"])
        text = ts.get("segment", "")
        speaker = ts.get("speaker", "N/A")
        if speaker != "N/A":
            text = f"[{speaker}] {text}"
        srt_content.extend(
            [str(index), f"{start_time} --> {end_time}", text, ""]
        )
    return "\n".join(srt_content)


def _gaussian_kernel(radius: int, sigma: float) -> np.ndarray:
    """Return a normalized 1-D Gaussian kernel of length ``2 * radius + 1``."""
    if radius <= 0:
        return np.array([1.0], dtype=np.float32)
    x = np.arange(-radius, radius + 1, dtype=np.float32)
    sigma = max(float(sigma), 1e-6)
    kernel = np.exp(-0.5 * (x / sigma) ** 2)
    kernel /= np.sum(kernel)
    return kernel.astype(np.float32)


def remove_dc_offset(samples: np.ndarray) -> np.ndarray:
    """Subtract the mean from ``samples`` (returned as float32)."""
    samples = np.asarray(samples, dtype=np.float32)
    if samples.size == 0:
        return samples
    return samples - np.mean(samples, dtype=np.float32)


def fft_bandpass(
    samples: np.ndarray, sr: int, low_hz: float, high_hz: float
) -> np.ndarray:
    """Zero all FFT bins outside ``[low_hz, high_hz]`` and return the float32 signal.

    The band is clamped to ``[0, sr / 2]``; a band covering the full range
    returns the input unchanged.
    """
    samples = np.asarray(samples, dtype=np.float32)
    if samples.size == 0:
        return samples
    low_hz = max(0.0, float(low_hz))
    high_hz = min(float(high_hz), sr / 2.0)
    if low_hz <= 0 and high_hz >= sr / 2.0:
        return samples
    spectrum = np.fft.rfft(samples)
    freqs = np.fft.rfftfreq(samples.shape[0], d=1.0 / sr)
    keep = (freqs >= low_hz) & (freqs <= high_hz)
    spectrum[~keep] = 0.0
    filtered = np.fft.irfft(spectrum, n=samples.shape[0])
    return filtered.astype(np.float32, copy=False)


def spectral_denoise(
    samples: np.ndarray,
    strength: float = 1.2,
    noise_percentile: float = 15.0,
    min_mask: float = 0.06,
) -> np.ndarray:
    """Attenuate stationary noise with a per-frequency STFT gain mask.

    The noise magnitude of each frequency bin is estimated as the
    ``noise_percentile`` percentile over time; the mask is clipped to
    ``[min_mask, 1]``. Returns the input unchanged when ``librosa`` is
    unavailable or the input is empty.
    """
    samples = np.asarray(samples, dtype=np.float32)
    if samples.size == 0:
        return samples
    if librosa is None:
        return samples

    n_fft = 512
    hop = 128
    stft = librosa.stft(samples, n_fft=n_fft, hop_length=hop, win_length=n_fft)
    magnitude = np.abs(stft)
    phase = np.angle(stft)

    noise_mag = np.percentile(magnitude, noise_percentile, axis=1, keepdims=True)
    noise_power = noise_mag * noise_mag
    signal_power = magnitude * magnitude
    residual_power = np.maximum(signal_power - strength * noise_power, 0.0)
    mask = residual_power / (residual_power + strength * noise_power + 1e-8)
    mask = np.clip(mask, min_mask, 1.0)

    cleaned_stft = magnitude * mask * np.exp(1j * phase)
    cleaned = librosa.istft(
        cleaned_stft, hop_length=hop, win_length=n_fft, length=len(samples)
    )
    return cleaned.astype(np.float32, copy=False)


def dynamic_rms_normalize(
    samples: np.ndarray,
    sample_rate: int,
    frame_ms: int = 500,
    target_rms_db: float = -20.0,
    smoothing_sigma_frames: float = 1.0,
    min_gain: float = 0.2,
    max_gain: float = 8.0,
) -> np.ndarray:
    """Apply a smoothed, time-varying gain that pulls frame RMS toward a target.

    Per-frame gains (non-overlapping frames of ``frame_ms``) are clipped to
    ``[min_gain, max_gain]``, Gaussian-smoothed across frames, interpolated
    to sample resolution, and the result is clipped to [-1, 1].
    """
    samples = np.asarray(samples, dtype=np.float32)
    if samples.size == 0:
        return samples

    frame_len = max(1, int(sample_rate * frame_ms / 1000))
    hop_len = frame_len
    target_rms = 10.0 ** (target_rms_db / 20.0)
    n = samples.shape[0]
    num_frames = max(1, int(np.ceil(max(0, n - frame_len) / hop_len)) + 1)

    rms_values = np.zeros(num_frames, dtype=np.float32)
    for i in range(num_frames):
        start = i * hop_len
        end = min(start + frame_len, n)
        frame = samples[start:end]
        rms_values[i] = (
            np.sqrt(np.mean(frame * frame) + 1e-12) if frame.size else 1e-6
        )

    gains = target_rms / np.maximum(rms_values, 1e-6)
    gains = np.clip(gains, min_gain, max_gain)

    # Edge-pad so the "valid" convolution yields exactly one gain per frame.
    radius = int(max(1, round(3 * smoothing_sigma_frames)))
    kernel = _gaussian_kernel(radius, smoothing_sigma_frames)
    padded = np.pad(gains, (radius, radius), mode="edge")
    gains_smooth = np.convolve(padded, kernel, mode="valid")

    if num_frames == 1:
        gain_curve = np.full(n, gains_smooth[0], dtype=np.float32)
    else:
        centers = np.minimum(np.arange(num_frames) * hop_len + (frame_len // 2), n - 1)
        gain_curve = np.interp(
            np.arange(n),
            centers,
            gains_smooth,
            left=gains_smooth[0],
            right=gains_smooth[-1],
        ).astype(np.float32)

    out = samples * gain_curve
    return np.clip(out, -1.0, 1.0).astype(np.float32, copy=False)


def soft_limiter(samples: np.ndarray, drive: float = 1.15) -> np.ndarray:
    """Apply ``tanh(samples * drive)`` and return float32."""
    samples = np.asarray(samples, dtype=np.float32)
    if samples.size == 0:
        return samples
    return np.tanh(samples * drive).astype(np.float32, copy=False)


def preprocess_audio_for_transcription(
    audio: AudioSegment,
    target_sr: int = 16000,
    frame_ms: int = 500,
    target_rms_db: float = -20.0,
) -> AudioSegment:
    """Return a mono 16-bit copy of ``audio`` cleaned up for ASR.

    Pipeline: mono + resample to ``target_sr``, DC removal, spectral
    denoise, 120-3600 Hz band-pass, dynamic RMS normalization, soft limiter.

    Raises:
        ValueError: If the audio contains no samples.
    """
    if audio.channels != 1:
        audio = audio.set_channels(1)
    if audio.frame_rate != target_sr:
        audio = audio.set_frame_rate(target_sr)

    raw = np.array(audio.get_array_of_samples(), dtype=np.float32)
    if raw.size == 0:
        raise ValueError("Empty audio data after loading.")

    max_abs = float(1 << (8 * audio.sample_width - 1))
    if max_abs <= 0:
        max_abs = 32768.0

    samples = np.clip(raw / max_abs, -1.0, 1.0)
    samples = remove_dc_offset(samples)
    samples = spectral_denoise(
        samples, strength=1.25, noise_percentile=15.0, min_mask=0.06
    )
    samples = fft_bandpass(samples, sr=target_sr, low_hz=120.0, high_hz=3600.0)
    samples = dynamic_rms_normalize(
        samples=samples,
        sample_rate=target_sr,
        frame_ms=frame_ms,
        target_rms_db=target_rms_db,
        smoothing_sigma_frames=1.0,
        min_gain=0.2,
        max_gain=8.0,
    )
    samples = soft_limiter(samples, drive=1.10)

    pcm16 = (np.clip(samples, -1.0, 1.0) * 32767.0).astype(np.int16)
    return AudioSegment(
        data=pcm16.tobytes(),
        sample_width=2,
        frame_rate=target_sr,
        channels=1,
    )


def normalize_speaker_label(label) -> str:
    """Canonicalize a speaker label.

    Blank labels become ``"SPEAKER_0"``, all-digit labels become
    ``"SPEAKER_<digits>"``, anything else is upper-cased with spaces
    replaced by underscores.
    """
    txt = str(label).strip()
    if not txt:
        return "SPEAKER_0"
    if txt.isdigit():
        return f"SPEAKER_{txt}"
    return txt.upper().replace(" ", "_")


def _parse_rttm_line(line: str):
    """Parse one RTTM ``SPEAKER`` line into a segment dict, or return ``None``."""
    parts = line.strip().split()
    if len(parts) < 8 or parts[0].upper() != "SPEAKER":
        return None
    start = _try_float(parts[3])
    dur = _try_float(parts[4])
    speaker = parts[7]
    if start is None or dur is None or dur <= 0:
        return None
    return {
        "start": start,
        "end": start + dur,
        "speaker": normalize_speaker_label(speaker),
    }


def _parse_simple_segment_line(line: str):
    """Parse a ``start end speaker`` line into a segment dict, or return ``None``."""
    # Handles: "0.00, 1.24, 2" or "0.00 1.24 SPEAKER_2"
    cleaned = line.strip().replace(",", " ")
    parts = [p for p in cleaned.split() if p]
    if len(parts) < 3:
        return None
    start = _try_float(parts[0])
    end = _try_float(parts[1])
    speaker = parts[2]
    if start is None or end is None or end <= start:
        return None
    return {"start": start, "end": end, "speaker": normalize_speaker_label(speaker)}


def _existing_rttm_path(candidate):
    """Return ``Path(candidate)`` if it names an existing ``.rttm`` file, else ``None``."""
    path = Path(candidate)
    # Check the suffix first so arbitrary text never reaches the filesystem.
    if path.suffix.lower() != ".rttm":
        return None
    try:
        return path if path.exists() else None
    except (OSError, ValueError):
        # e.g. "File name too long" when the string is really RTTM content.
        return None


def parse_diarization_output(raw_output, audio_duration_sec=None) -> list:
    """Flatten a diarization result of unknown shape into sorted segment dicts.

    Accepts (recursively) RTTM file paths, RTTM / ``start end speaker`` text,
    dicts with start/end/duration keys, ``(start, end, speaker, ...)``
    sequences, and objects exposing ``start``/``end`` attributes.

    Args:
        raw_output: The diarizer's return value.
        audio_duration_sec: When given, timestamps whose maximum end exceeds
            20x this duration are assumed to be milliseconds and rescaled.

    Returns:
        De-duplicated list of ``{"start", "end", "speaker"}`` dicts sorted by
        ``(start, end)``.
    """
    parsed = []

    def append_seg(start, end, speaker):
        s = _try_float(start)
        e = _try_float(end)
        if s is None or e is None or e <= s:
            return
        parsed.append(
            {"start": s, "end": e, "speaker": normalize_speaker_label(speaker)}
        )

    def append_line(line):
        seg = _parse_rttm_line(line) or _parse_simple_segment_line(line)
        if seg:
            parsed.append(seg)

    def walk(obj):
        if obj is None:
            return

        if isinstance(obj, Path):
            rttm = _existing_rttm_path(obj)
            if rttm is not None:
                with open(rttm, "r", encoding="utf-8") as f:
                    for line in f:
                        append_line(line)
            return

        if isinstance(obj, str):
            rttm = _existing_rttm_path(obj)
            if rttm is not None:
                walk(rttm)
                return
            # splitlines() on a single-line string yields that one line.
            if "\n" in obj:
                for line in obj.splitlines():
                    append_line(line)
                return
            append_line(obj)
            return

        if isinstance(obj, dict):
            start = obj.get("start", obj.get("start_time", obj.get("begin")))
            end = obj.get("end", obj.get("end_time", obj.get("stop")))
            dur = obj.get("duration")
            speaker = obj.get("speaker", obj.get("speaker_id", obj.get("label", "0")))
            if end is None and start is not None and dur is not None:
                s = _try_float(start)
                d = _try_float(dur)
                if s is not None and d is not None:
                    end = s + d
            if start is not None and end is not None:
                append_seg(start, end, speaker)
            # Nested containers may hold further segments.
            for v in obj.values():
                walk(v)
            return

        if isinstance(obj, (list, tuple)):
            looks_like_segment = (
                len(obj) >= 3
                and _try_float(obj[0]) is not None
                and _try_float(obj[1]) is not None
            )
            if looks_like_segment:
                append_seg(obj[0], obj[1], obj[2])
                return
            for item in obj:
                walk(item)
            return

        if hasattr(obj, "start") and hasattr(obj, "end"):
            append_seg(
                getattr(obj, "start"), getattr(obj, "end"), getattr(obj, "speaker", "0")
            )

    walk(raw_output)

    if parsed and audio_duration_sec:
        max_end = max(seg["end"] for seg in parsed)
        # Guard for millisecond outputs
        if max_end > audio_duration_sec * 20:
            for seg in parsed:
                seg["start"] /= 1000.0
                seg["end"] /= 1000.0

    parsed.sort(key=lambda x: (x["start"], x["end"]))

    # De-duplicate exact repeats
    deduped = []
    seen = set()
    for seg in parsed:
        key = (round(seg["start"], 3), round(seg["end"], 3), seg["speaker"])
        if key not in seen:
            seen.add(key)
            deduped.append(seg)
    return deduped


def merge_adjacent_speaker_segments(segments: list, max_gap_sec: float = 0.15) -> list:
    """Merge consecutive same-speaker segments separated by at most ``max_gap_sec``.

    Input dicts are copied, not mutated.
    """
    if not segments:
        return []
    merged = [segments[0].copy()]
    for seg in segments[1:]:
        last = merged[-1]
        same_speaker = seg["speaker"] == last["speaker"]
        if same_speaker and seg["start"] - last["end"] <= max_gap_sec:
            last["end"] = max(last["end"], seg["end"])
        else:
            merged.append(seg.copy())
    return merged


def merge_consecutive_transcript_rows(rows: list) -> list:
    """Merge consecutive transcript rows spoken by the same speaker.

    The merged row keeps the first start, the latest end, and the texts
    joined by a single space.
    """
    if not rows:
        return []
    merged = [rows[0].copy()]
    for row in rows[1:]:
        last = merged[-1]
        if row.get("speaker") == last.get("speaker"):
            last["end"] = max(float(last["end"]), float(row["end"]))
            prev_text = (last.get("segment") or "").strip()
            cur_text = (row.get("segment") or "").strip()
            if prev_text and cur_text:
                last["segment"] = f"{prev_text} {cur_text}"
            elif cur_text:
                last["segment"] = cur_text
        else:
            merged.append(
                {
                    "start": float(row["start"]),
                    "end": float(row["end"]),
                    "speaker": row.get("speaker", "N/A"),
                    "segment": (row.get("segment") or "").strip(),
                }
            )
    return merged


def transcribe_with_segments_and_words(transcribe_path: str):
    """Transcribe one file and return ``(segments, words)`` with timestamps.

    Returns:
        ``segments``: list of ``{"start", "end", "speaker": "N/A", "segment"}``.
        ``words``: list of ``{"start", "end", "token"}``.
        Entries with missing or non-increasing timestamps are dropped.

    Raises:
        RuntimeError: If the model output lacks segment timestamps.
    """
    output = model.transcribe([transcribe_path], timestamps=True)
    if (
        not output
        or not isinstance(output, list)
        or not output[0]
        or not hasattr(output[0], "timestamp")
        or not output[0].timestamp
        or "segment" not in output[0].timestamp
    ):
        raise RuntimeError("Transcription failed or unexpected output format.")

    timestamp_payload = output[0].timestamp

    segments = []
    for ts in timestamp_payload.get("segment", []):
        start = _try_float(ts.get("start"))
        end = _try_float(ts.get("end"))
        text = str(ts.get("segment", ts.get("text", ""))).strip()
        if start is None or end is None or end <= start:
            continue
        segments.append(
            {
                "start": float(start),
                "end": float(end),
                "speaker": "N/A",
                "segment": text,
            }
        )

    words = []
    for w in timestamp_payload.get("word", []):
        # Word entries are accepted as dicts, sequences, or attribute objects.
        if isinstance(w, dict):
            start = _try_float(w.get("start", w.get("start_time", w.get("begin"))))
            end = _try_float(w.get("end", w.get("end_time", w.get("stop"))))
            token = str(w.get("word", w.get("token", w.get("text", "")))).strip()
        elif isinstance(w, (list, tuple)) and len(w) >= 3:
            start = _try_float(w[0])
            end = _try_float(w[1])
            token = str(w[2]).strip()
        else:
            start = _try_float(getattr(w, "start", None))
            end = _try_float(getattr(w, "end", None))
            token = str(getattr(w, "word", getattr(w, "text", ""))).strip()
        if start is None or end is None or end <= start:
            continue
        words.append({"start": float(start), "end": float(end), "token": token})

    return segments, words


def transcribe_default_with_timestamps(transcribe_path: str):
    """Transcribe one file and return only the segment-level rows."""
    segments, _ = transcribe_with_segments_and_words(transcribe_path)
    return segments


def _overlap_seconds(
    a_start: float, a_end: float, b_start: float, b_end: float
) -> float:
    """Return the length of the intersection of two intervals (0 if disjoint)."""
    return max(0.0, min(a_end, b_end) - max(a_start, b_start))


def _join_tokens(tokens: list) -> str:
    """Join the non-empty tokens with single spaces."""
    return " ".join(t for t in tokens if t).strip()


def split_asr_by_diarization_segments(
    asr_segments: list, diar_segments: list, asr_words: list = None
) -> list:
    """Attribute ASR text to diarization segments.

    Word timestamps are preferred; if they are missing or yield no text, whole
    ASR segments are used instead. In both modes each unit goes to the single
    diarization segment it overlaps most (earliest segment wins ties), so a
    word that straddles a boundary is not emitted twice.

    Returns:
        One ``{"start", "end", "speaker", "segment"}`` row per diarization
        segment (sorted by start), or ``[]`` when ``diar_segments`` is empty.
    """
    if not diar_segments:
        return []

    diar_segments = sorted(diar_segments, key=lambda x: (x["start"], x["end"]))
    bounds = [(float(d["start"]), float(d["end"])) for d in diar_segments]

    def best_segment_index(start: float, end: float) -> int:
        best_i = -1
        best_ov = 0.0
        for i, (d_start, d_end) in enumerate(bounds):
            if d_start >= end:
                break  # sorted by start: nothing later can overlap
            ov = _overlap_seconds(start, end, d_start, d_end)
            if ov > best_ov:
                best_ov = ov
                best_i = i
        return best_i

    def build_rows(buckets: list) -> list:
        return [
            {
                "start": d_start,
                "end": d_end,
                "speaker": d["speaker"],
                "segment": _join_tokens(bucket),
            }
            for d, (d_start, d_end), bucket in zip(diar_segments, bounds, buckets)
        ]

    # Preferred: word-level mapping to each diar segment.
    if asr_words:
        word_buckets = [[] for _ in bounds]
        for w in sorted(asr_words, key=lambda x: (x["start"], x["end"])):
            idx = best_segment_index(float(w["start"]), float(w["end"]))
            if idx >= 0:
                word_buckets[idx].append(w["token"])
        rows_word = build_rows(word_buckets)
        if any(r["segment"] for r in rows_word):
            return rows_word

    # Fallback: segment-level overlap assignment.
    buckets = [[] for _ in bounds]
    for s in asr_segments:
        txt = (s.get("segment") or "").strip()
        if not txt:
            continue
        idx = best_segment_index(float(s["start"]), float(s["end"]))
        if idx >= 0:
            buckets[idx].append(txt)
    return build_rows(buckets)


def _clean_token_spacing(text: str) -> str:
    """Remove spaces before ``.,!?;:`` and collapse whitespace runs."""
    text = re.sub(r"\s+([.,!?;:])", r"\1", text)
    return re.sub(r"\s+", " ", text).strip()


def _capitalize_first_alpha(text: str) -> str:
    """Upper-case the first ASCII letter of ``text`` if it is lower-case."""
    return re.sub(
        r"^([^A-Za-z]*)([a-z])", lambda m: m.group(1) + m.group(2).upper(), text
    )


def _capitalize_after_full_stop(text: str) -> str:
    """Upper-case a lower-case ASCII letter that follows a full stop and one whitespace."""
    # Capitalize first latin letter after ". "
    return re.sub(r"(?<=\.\s)([a-z])", lambda m: m.group(1).upper(), text)


# Uzbek cardinal number words (with and without apostrophes) -> value.
UZ_CARDINAL = {
    "nol": 0,
    "bir": 1,
    "ikki": 2,
    "uch": 3,
    "to'rt": 4,
    "tort": 4,
    "besh": 5,
    "olti": 6,
    "yetti": 7,
    "sakkiz": 8,
    "to'qqiz": 9,
    "toqqiz": 9,
    "o'n": 10,
    "on": 10,
    "yigirma": 20,
    "o'ttiz": 30,
    "ottiz": 30,
    "qirq": 40,
    "ellik": 50,
    "oltmish": 60,
    "yetmish": 70,
    "sakson": 80,
    "to'qson": 90,
    "toqson": 90,
    "o'nbir": 11,
    "onbir": 11,
    "o'nikki": 12,
    "onikki": 12,
    "o'nuch": 13,
    "onuch": 13,
    "o'nto'rt": 14,
    "ontort": 14,
    "o'nbesh": 15,
    "onbesh": 15,
    "o'nolti": 16,
    "onolti": 16,
    "o'nyetti": 17,
    "onyetti": 17,
    "o'nsakkiz": 18,
    "onsakkiz": 18,
    "o'nto'qqiz": 19,
    "ontoqqiz": 19,
}

# Scale words; "yuz" (100) is handled separately by the cardinal parser.
UZ_SCALES = {"ming": 1000, "million": 1_000_000, "milliard": 1_000_000_000}

# Irregular ordinal forms -> their cardinal word.
UZ_ORDINAL_TO_CARDINAL = {
    "birinchi": "bir",
    "ikkinchi": "ikki",
    "uchinchi": "uch",
    "to'rtinchi": "to'rt",
    "tortinchi": "to'rt",
    "beshinchi": "besh",
    "oltinchi": "olti",
    "yettinchi": "yetti",
    "sakkizinchi": "sakkiz",
    "to'qqizinchi": "to'qqiz",
    "toqqizinchi": "to'qqiz",
    "o'ninchi": "o'n",
    "oninchi": "o'n",
}
UZ_MONTHS_PATTERN = ( r"yanvar|fevral|mart|aprel|may|iyun|iyul|avgust|sentabr|oktabr|noyabr|dekabr" ) _TOKEN_CORE_RE = re.compile( r"^([^A-Za-z0-9'`ʻʼ’‘]*)([A-Za-z0-9'`ʻʼ’‘]+)([^A-Za-z0-9'`ʻʼ’‘]*)$" ) def _normalize_uz_word(word: str) -> str: w = str(word).lower() w = ( w.replace("’", "'") .replace("‘", "'") .replace("`", "'") .replace("ʻ", "'") .replace("ʼ", "'") ) repl = { "tort": "to'rt", "toqqiz": "to'qqiz", "on": "o'n", "ottiz": "o'ttiz", "toqson": "to'qson", } return repl.get(w, w) def _is_uz_number_like(word: str) -> bool: if not word: return False if ( word in UZ_CARDINAL or word in UZ_SCALES or word == "yuz" or word in UZ_ORDINAL_TO_CARDINAL ): return True return re.match(r"^.+(?:inchi|nchi)$", word) is not None def _split_token(token: str): m = _TOKEN_CORE_RE.match(token) if not m: return "", "", token prefix, core, suffix = m.group(1), m.group(2), m.group(3) core_norm = _normalize_uz_word(core) # Handle attached clitics on number words, e.g. "yuzmi" -> "yuz" + "mi". clitic = "" for c in ("mi",): if core_norm.endswith(c): stem = core_norm[: -len(c)] if _is_uz_number_like(stem): core_norm = stem clitic = c break return prefix, core_norm, f"{clitic}{suffix}" def _parse_uz_cardinal(words): total = 0 current = 0 seen = False for raw in words: w = _normalize_uz_word(raw) if w == "va": continue if w in UZ_CARDINAL: current += UZ_CARDINAL[w] seen = True elif w == "yuz": current = (current or 1) * 100 seen = True elif w in UZ_SCALES: scale = UZ_SCALES[w] if current == 0: current = 1 total += current * scale current = 0 seen = True else: return None return (total + current) if seen else None def _parse_uz_ordinal(words): if not words: return None normalized = [_normalize_uz_word(w) for w in words] last = normalized[-1] if last in UZ_ORDINAL_TO_CARDINAL: base = normalized[:-1] + [UZ_ORDINAL_TO_CARDINAL[last]] return _parse_uz_cardinal(base) m = re.match(r"^(.+?)(?:inchi|nchi)$", last) if m: stem = _normalize_uz_word(m.group(1)) base = normalized[:-1] + [stem] 
return _parse_uz_cardinal(base) return None def normalize_uzbek_numbers_in_text(text: str) -> str: if not text: return text tokens = text.split(" ") out = [] i = 0 n = len(tokens) while i < n: p, core, s = _split_token(tokens[i]) if not core: out.append(tokens[i]) i += 1 continue best_kind = None best_val = None best_end = -1 words = [] j = i while j < n: pj, cj, sj = _split_token(tokens[j]) if not cj: break if j > i: _, _, prev_suffix = _split_token(tokens[j - 1]) if prev_suffix: break words.append(cj) card = _parse_uz_cardinal(words) if card is not None: best_kind = "card" best_val = card best_end = j ordv = _parse_uz_ordinal(words) if ordv is not None: best_kind = "ord" best_val = ordv best_end = j if sj: break j += 1 if best_end < i: out.append(tokens[i]) i += 1 continue first_prefix, _, _ = _split_token(tokens[i]) _, _, last_suffix = _split_token(tokens[best_end]) repl = str(best_val) if best_kind == "card" else f"{best_val}-chi" out.append(f"{first_prefix}{repl}{last_suffix}") i = best_end + 1 return " ".join(out) def normalize_uzbek_date_forms(text: str) -> str: # 5 may -> 5-may, 5-chi may -> 5-may text = re.sub( rf"\b(\d+)(?:-chi)?\s+({UZ_MONTHS_PATTERN})\b", lambda m: f"{m.group(1)}-{m.group(2)}", text, flags=re.IGNORECASE, ) # 2024 chi yil / 2024-chi yil / 2024 yil -> 2024-yil text = re.sub(r"\b(\d+)\s*-\s*chi\s+yil\b", r"\1-yil", text, flags=re.IGNORECASE) text = re.sub(r"\b(\d+)\s+chi\s+yil\b", r"\1-yil", text, flags=re.IGNORECASE) text = re.sub(r"\b(\d{3,4})\s+yil\b", r"\1-yil", text, flags=re.IGNORECASE) return text def postprocess_segment_texts( segment_timestamps: list, diarization_enabled: bool ) -> list: for ts in segment_timestamps: txt = str(ts.get("segment", "") or "") txt = _clean_token_spacing(txt) txt = normalize_uzbek_numbers_in_text(txt) txt = normalize_uzbek_date_forms(txt) if diarization_enabled: txt = _capitalize_first_alpha(txt) # each speaker row starts capitalized txt = _capitalize_after_full_stop(txt) # after ". 
" next letter is uppercase ts["segment"] = txt return segment_timestamps def resolve_player_audio_path(prepared_path, fallback_path: str) -> str: try: if prepared_path and Path(prepared_path).exists(): return Path(prepared_path).as_posix() except Exception: pass return fallback_path @spaces.GPU def get_transcripts_and_raw_times( audio_path, session_dir, use_preprocessing=True, use_diarization=False ): if not audio_path: gr.Error("No audio file path provided for transcription.", duration=None) return ( [], [], None, gr.DownloadButton(label="Download Transcript (CSV)", visible=False), gr.DownloadButton(label="Download Transcript (SRT)", visible=False), ) vis_data = [["N/A", "N/A", "N/A", "Processing failed"]] raw_times_data = [[0.0, 0.0]] processed_audio_path = None diar_audio_path = None playback_audio_path = None original_path_name = Path(audio_path).name audio_name = Path(audio_path).stem csv_button_update = gr.DownloadButton( label="Download Transcript (CSV)", visible=False ) srt_button_update = gr.DownloadButton( label="Download Transcript (SRT)", visible=False ) transcribe_path = audio_path info_path_name = original_path_name try: gr.Info(f"Loading audio: {original_path_name}", duration=2) audio = AudioSegment.from_file(audio_path) duration_sec = audio.duration_seconds # Keep stable playback source in session dir. 
try: playback_audio_path = Path(session_dir, f"{audio_name}_playback.wav") audio.export(playback_audio_path, format="wav") except Exception as playback_e: playback_audio_path = None gr.Warning(f"Could not prepare playback audio: {playback_e}", duration=5) if use_preprocessing: try: gr.Info( "Preprocessing enabled: mono + denoise + phone-band + dynamic RMS + 16kHz...", duration=3, ) processed_audio = preprocess_audio_for_transcription( audio=audio, target_sr=16000, frame_ms=500, target_rms_db=-20.0 ) processed_audio_path = Path( session_dir, f"{audio_name}_asr_preprocessed.wav" ) processed_audio.export(processed_audio_path, format="wav") transcribe_path = processed_audio_path.as_posix() info_path_name = f"{original_path_name} (preprocessed)" except Exception as preprocess_e: gr.Warning( f"Preprocessing failed ({preprocess_e}). Falling back to original audio.", duration=6, ) transcribe_path = audio_path info_path_name = original_path_name else: gr.Info("Preprocessing disabled. Using original audio.", duration=2) long_audio_settings_applied = False try: model.to(device) model.to(torch.float32) gr.Info(f"Transcribing {info_path_name} on {device}...", duration=2) if duration_sec > 480: try: gr.Info( "Audio longer than 8 minutes. 
Applying long audio settings.", duration=3, ) model.change_attention_model("rel_pos_local_attn", [256, 256]) model.change_subsampling_conv_chunking_factor(1) long_audio_settings_applied = True except Exception as setting_e: gr.Warning( f"Could not apply long audio settings: {setting_e}", duration=5 ) if device == "cuda": model.to(torch.bfloat16) segment_timestamps = [] if use_diarization: try: gr.Info("Running ASR and diarization in parallel...", duration=3) diar_input_path = audio_path dmodel = get_diar_model() dmodel.to(device) dmodel.to(torch.float32) def _run_asr(): return transcribe_with_segments_and_words(transcribe_path) def _run_diar(): try: diar_output_local = dmodel.diarize( audio=diar_input_path, batch_size=1 ) except TypeError: diar_output_local = dmodel.diarize( audio=[diar_input_path], batch_size=1 ) diar_segments_local = parse_diarization_output( diar_output_local, audio_duration_sec=duration_sec, ) diar_segments_local = merge_adjacent_speaker_segments( diar_segments_local, max_gap_sec=0.15 ) return diar_segments_local, diar_output_local with ThreadPoolExecutor(max_workers=2) as pool: asr_future = pool.submit(_run_asr) diar_future = pool.submit(_run_diar) asr_segments, asr_words = asr_future.result() diar_segments, diar_output = diar_future.result() if not diar_segments: gr.Warning( f"Diarization parsed no segments. Using ASR segmentation. raw_type={type(diar_output)}", duration=7, ) segment_timestamps = asr_segments else: # Split text by diarization segments, then merge consecutive same-speaker rows. segment_timestamps = split_asr_by_diarization_segments( asr_segments=asr_segments, diar_segments=diar_segments, asr_words=asr_words, ) segment_timestamps = merge_consecutive_transcript_rows( segment_timestamps ) if not segment_timestamps: gr.Warning( "No aligned diarized rows. 
Using ASR segmentation.", duration=7, ) segment_timestamps = asr_segments gr.Info("Diarization + ASR complete.", duration=2) except Exception as diar_e: gr.Warning( f"Diarization failed: {diar_e}. Using standard ASR segmentation.", duration=7, ) segment_timestamps = transcribe_default_with_timestamps( transcribe_path ) else: segment_timestamps = transcribe_default_with_timestamps(transcribe_path) segment_timestamps = postprocess_segment_texts( segment_timestamps, diarization_enabled=use_diarization, ) vis_data = [ [ round(float(ts["start"]), 2), round(float(ts["end"]), 2), ts.get("speaker", "N/A"), ts.get("segment", ""), ] for ts in segment_timestamps ] raw_times_data = [ [float(ts["start"]), float(ts["end"])] for ts in segment_timestamps ] try: csv_file_path = Path(session_dir, f"transcription_{audio_name}.csv") with open(csv_file_path, "w", newline="", encoding="utf-8") as f: writer = csv.writer(f) writer.writerow(["Start (s)", "End (s)", "Speaker", "Segment"]) writer.writerows(vis_data) csv_button_update = gr.DownloadButton( value=csv_file_path, visible=True, label="Download Transcript (CSV)" ) except Exception as csv_e: gr.Error( f"Failed to create transcript CSV file: {csv_e}", duration=None ) if segment_timestamps: try: srt_content = generate_srt_content(segment_timestamps) srt_file_path = Path(session_dir, f"transcription_{audio_name}.srt") with open(srt_file_path, "w", encoding="utf-8") as f: f.write(srt_content) srt_button_update = gr.DownloadButton( value=srt_file_path, visible=True, label="Download Transcript (SRT)", ) except Exception as srt_e: gr.Warning( f"Failed to create transcript SRT file: {srt_e}", duration=5 ) gr.Info("Transcription complete.", duration=2) return ( vis_data, raw_times_data, resolve_player_audio_path(playback_audio_path, audio_path), csv_button_update, srt_button_update, ) except torch.cuda.OutOfMemoryError: error_msg = "CUDA out of memory. Try shorter audio or reduce GPU load." 
gr.Error(error_msg, duration=None) return ( [["OOM", "OOM", "N/A", error_msg]], [[0.0, 0.0]], resolve_player_audio_path(playback_audio_path, audio_path), csv_button_update, srt_button_update, ) except FileNotFoundError: gr.Error( f"Audio file not found for transcription: {Path(transcribe_path).name}", duration=None, ) return ( [["Error", "Error", "N/A", "File not found for transcription"]], [[0.0, 0.0]], resolve_player_audio_path(playback_audio_path, audio_path), csv_button_update, srt_button_update, ) except Exception as e: gr.Error(f"Transcription failed: {e}", duration=None) return ( [["Error", "Error", "N/A", f"Transcription failed: {e}"]], [[0.0, 0.0]], resolve_player_audio_path(playback_audio_path, audio_path), csv_button_update, srt_button_update, ) finally: try: if long_audio_settings_applied: try: model.change_attention_model("rel_pos") model.change_subsampling_conv_chunking_factor(-1) except Exception as revert_e: gr.Warning( f"Issue reverting model settings: {revert_e}", duration=5 ) if device == "cuda": model.cpu() if diar_model is not None: diar_model.cpu() gc.collect() if device == "cuda": torch.cuda.empty_cache() except Exception as cleanup_e: gr.Warning(f"Issue during model cleanup: {cleanup_e}", duration=5) finally: for tmp_path in [processed_audio_path, diar_audio_path]: if tmp_path and os.path.exists(tmp_path): try: os.remove(tmp_path) except Exception as e: print(f"Error removing temporary audio file {tmp_path}: {e}") def play_segment(raw_ts_list, current_audio_path, evt: gr.SelectData): if not isinstance(raw_ts_list, list) or not current_audio_path: return gr.update(value=None, label="Selected Segment") if evt is None or evt.index is None: return gr.update(value=None, label="Selected Segment") if isinstance(evt.index, (list, tuple)): if not evt.index: return gr.update(value=None, label="Selected Segment") selected_index = int(evt.index[0]) else: selected_index = int(evt.index) if selected_index < 0 or selected_index >= len(raw_ts_list): return 
gr.update(value=None, label="Selected Segment") selected_row = raw_ts_list[selected_index] if not isinstance(selected_row, (list, tuple)) or len(selected_row) != 2: return gr.update(value=None, label="Selected Segment") start_time_s = _try_float(selected_row[0]) end_time_s = _try_float(selected_row[1]) if start_time_s is None or end_time_s is None or end_time_s <= start_time_s: return gr.update(value=None, label="Selected Segment") segment_data = get_audio_segment(current_audio_path, start_time_s, end_time_s) if segment_data: return gr.update( value=segment_data, autoplay=True, label=f"Segment: {start_time_s:.2f}s - {end_time_s:.2f}s", ) return gr.update(value=None, label="Selected Segment") article = ( "
Optional preprocessing and optional speaker diarization are supported.
" "