NeuroScope / app.py
Alogotron's picture
Upload app.py with huggingface_hub
d7c3cd7 verified
#!/usr/bin/env python3
"""
NeuroScope — Neural Network Activation Visualizer
Interactive Gradio dashboard for visualizing LLM hidden states, attention
patterns, and activation maps during inference on Qwen3-4B.
Run locally (demo mode — no GPU required):
python app.py
Run with real model:
python app.py --model
Tabs:
- Analyze: Single-prompt analysis with 4 core views + fingerprinting
- Compare: Side-by-side comparison of two prompts
- Generate: Streaming token-by-token generation with live activations
Part of the Alogotron project: https://huggingface.co/Alogotron
"""
import sys
import os
import argparse
import time
# Ensure local imports work regardless of cwd
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import gradio as gr
from extraction import ActivationExtractor, ExtractionResult
from viz_attention import create_attention_heatmap, get_head_choices
from viz_magnitude import create_magnitude_chart
from viz_token_layer import create_token_layer_grid
from viz_scatter import create_scatter_plot
from viz_fingerprint import create_fingerprint_strip, create_fingerprint_comparison
from viz_comparison import (
create_attention_comparison,
create_magnitude_comparison,
create_token_layer_comparison,
create_scatter_comparison,
)
# ---------------------------------------------------------------------------
# Theme & styling
# ---------------------------------------------------------------------------
# Gold accent colour. The same hex value appears literally in CUSTOM_CSS and
# in the theme overrides; this constant is not referenced in the visible code.
ACCENT = "#e6b800"
# Background colour used for the placeholder figure built by _empty_plot().
BG_DARK = "#1a1a2e"
# Text colour used for the placeholder figure's centred message.
TEXT = "#e0e0e0"
# Stylesheet passed to app.launch(css=...) in main(). The selectors
# .neuroscope-header, .status-bar and .gen-text-display match the HTML header
# and the elem_classes assigned in build_app(); .control-panel and .new-token
# are not assigned anywhere in the visible code.
CUSTOM_CSS = """
/* Global dark background overrides */
.gradio-container { background-color: #0f0f23 !important; }
footer { display: none !important; }
/* Header branding */
.neuroscope-header {
text-align: center;
padding: 12px 0 4px;
}
.neuroscope-header h1 {
color: #e6b800;
font-size: 2em;
margin: 0;
letter-spacing: 2px;
}
.neuroscope-header p {
color: #e0e0e0;
opacity: 0.7;
margin: 4px 0 0;
font-size: 0.9em;
}
/* Status badge styling */
.status-bar {
font-family: monospace;
font-size: 0.85em;
padding: 6px 12px;
border-radius: 6px;
background: #16162b;
border: 1px solid #2a2a4e;
}
/* Plot containers — remove extra padding */
.plot-container .js-plotly-plot { margin: 0 !important; }
/* Control panel styling */
.control-panel {
border: 1px solid #2a2a4e;
border-radius: 8px;
padding: 8px;
background: #16162b;
}
/* Generated text display */
.gen-text-display {
font-family: 'Courier New', monospace;
font-size: 1.1em;
line-height: 1.6;
padding: 12px;
background: #16162b;
border: 1px solid #2a2a4e;
border-radius: 8px;
color: #e0e0e0;
min-height: 60px;
}
.gen-text-display .new-token {
color: #e6b800;
font-weight: bold;
}
"""
# ---------------------------------------------------------------------------
# Global state
# ---------------------------------------------------------------------------
# Single extractor instance used by every callback; main() loads a model into
# it via load_model() when --model is given, otherwise it stays unloaded.
extractor = ActivationExtractor()
# Latest Analyze-tab result: written by run_inference()/run_demo(), read by the
# update_* callbacks. None until the first run.
# NOTE(review): module-level, so presumably shared by all browser sessions
# served by this process — confirm whether per-session state is needed.
current_result: ExtractionResult | None = None
# Latest Compare-tab results for prompt A and prompt B: written by
# run_compare()/run_compare_demo(), read by the update_compare_* callbacks.
compare_result_a: ExtractionResult | None = None
compare_result_b: ExtractionResult | None = None
def get_status_text(result: ExtractionResult | None, model_loaded: bool) -> str:
    """Compose the Markdown line shown in the Analyze tab's status bar.

    Args:
        result: The extraction to summarise, or ``None`` when nothing has
            been run yet.
        model_loaded: Whether a real model is available; only consulted when
            ``result`` is ``None``.

    Returns:
        A Markdown string. With a result it reports the data source, the
        inference time, the token count and the layer/head/hidden sizes;
        without one it reports model availability and prompts the user.
    """
    if result is not None:
        source_label = "🧪 Demo Data" if result.is_demo else "🧠 Real Inference"
        segments = [
            f"**Status:** {source_label}",
            f"⏱ {result.inference_time:.3f}s",
            f"📝 {len(result.tokens)} tokens",
            f"📊 {result.num_layers} layers × {result.num_heads} heads × {result.hidden_dim}d",
        ]
        return " | ".join(segments)

    if model_loaded:
        readiness = "✅ Model loaded"
    else:
        readiness = "💤 Demo mode (no GPU)"
    return f"**Status:** {readiness} — Enter a prompt and click Run"
# ---------------------------------------------------------------------------
# Tab 1: Analyze — callbacks
# ---------------------------------------------------------------------------
def run_inference(prompt: str):
    """Run the Analyze pipeline on ``prompt`` using the loaded model.

    A blank prompt is replaced by a stock sentence. When no model is loaded
    the call is delegated to ``run_demo`` after a UI warning; when extraction
    raises, a warning is shown and demo data is generated instead.

    Returns:
        The tuple produced by ``_build_all_outputs`` (five figures followed
        by the status Markdown).
    """
    global current_result

    text = prompt if prompt.strip() else "The quick brown fox jumps over the lazy dog"

    if not extractor.model_loaded:
        gr.Warning("Model not loaded — using demo data instead.")
        return run_demo(text)

    try:
        extracted = extractor.extract(text)
    except Exception as err:
        # Best-effort: keep the dashboard usable even if the model call fails.
        gr.Warning(f"Inference failed: {err}. Falling back to demo data.")
        extracted = ActivationExtractor.generate_demo_data(text)

    current_result = extracted
    return _build_all_outputs(extracted)
def run_demo(prompt: str):
    """Populate the Analyze tab from generated demo data.

    A blank prompt is replaced by a stock sentence. The generated result is
    stored in ``current_result`` so the per-control callbacks can reuse it.

    Returns:
        The tuple produced by ``_build_all_outputs``.
    """
    global current_result

    if prompt.strip():
        text = prompt
    else:
        text = "The quick brown fox jumps over the lazy dog"

    current_result = ActivationExtractor.generate_demo_data(text)
    return _build_all_outputs(current_result)
def update_attention(layer: int, head: str):
    """Redraw the attention heatmap for the chosen layer and head.

    Returns a placeholder figure when no Analyze result exists yet.
    """
    result = current_result
    if result is not None:
        # Slider values may arrive as floats, hence the explicit int().
        return create_attention_heatmap(result, layer=int(layer), head=head)
    return _empty_plot("Run inference first")
def update_magnitude(metric: str):
    """Redraw the magnitude chart for the chosen metric.

    Returns a placeholder figure when no Analyze result exists yet.
    """
    result = current_result
    if result is not None:
        return create_magnitude_chart(result, metric=metric)
    return _empty_plot("Run inference first")
def update_token_grid(normalize: str):
    """Redraw the token-by-layer grid with the chosen normalization.

    Returns a placeholder figure when no Analyze result exists yet.
    """
    result = current_result
    if result is not None:
        return create_token_layer_grid(result, normalize=normalize)
    return _empty_plot("Run inference first")
def update_scatter(layer: int, method: str, overlay: str):
    """Redraw the scatter plot for the chosen layer, method and overlay text.

    ``overlay`` is forwarded unchanged as ``overlay_layers``. Returns a
    placeholder figure when no Analyze result exists yet.
    """
    result = current_result
    if result is None:
        return _empty_plot("Run inference first")

    scatter_kwargs = {
        "layer": int(layer),
        "method": method,
        "overlay_layers": overlay,
    }
    return create_scatter_plot(result, **scatter_kwargs)
def _build_all_outputs(result: ExtractionResult):
    """Render every Analyze-tab output for ``result`` with default settings.

    The figures use fixed defaults (layer 0 / average head for attention,
    ``mean_l2`` magnitude, global grid normalization, PCA at layer 18), not
    the current control values.

    Returns:
        A 6-tuple: attention, magnitude, grid, scatter and fingerprint
        figures, then the status Markdown string.
    """
    figures = (
        create_attention_heatmap(result, layer=0, head="average"),
        create_magnitude_chart(result, metric="mean_l2"),
        create_token_layer_grid(result, normalize="global"),
        create_scatter_plot(result, layer=18, method="pca"),
        create_fingerprint_strip(result),
    )
    return (*figures, get_status_text(result, extractor.model_loaded))
def _empty_plot(message: str):
    """Build a placeholder Plotly figure showing only ``message``.

    The figure is 400 px high, hides both axes, uses the dark background
    colour and centres the message in paper coordinates.
    """
    # Imported lazily, as in the rest of this module's plotting helpers.
    import plotly.graph_objects as go

    placeholder = go.Figure()
    placeholder.add_annotation(
        text=message,
        showarrow=False,
        x=0.5,
        y=0.5,
        xref="paper",
        yref="paper",
        font={"color": TEXT, "size": 16},
    )
    placeholder.update_layout(
        height=400,
        paper_bgcolor=BG_DARK,
        plot_bgcolor=BG_DARK,
        xaxis={"visible": False},
        yaxis={"visible": False},
    )
    return placeholder
# ---------------------------------------------------------------------------
# Tab 2: Compare — callbacks
# ---------------------------------------------------------------------------
def run_compare(prompt_a: str, prompt_b: str):
    """Extract activations for two prompts and build the comparison outputs.

    Blank prompts are replaced by stock sentences. The real model is used
    when it is loaded; otherwise demo data is generated. Each prompt is
    extracted independently, so a failure on one prompt only replaces that
    prompt's result with demo data.

    Unlike a silent fallback, every substitution of demo data is reported to
    the user through ``gr.Warning``, matching ``run_inference``.

    Returns:
        The tuple produced by ``_build_compare_outputs`` (five figures
        followed by the status Markdown).
    """
    global compare_result_a, compare_result_b

    if not prompt_a.strip():
        prompt_a = "The quick brown fox jumps over the lazy dog"
    if not prompt_b.strip():
        prompt_b = "A slow red cat sleeps under the warm sun"

    if extractor.model_loaded:
        extract_fn = extractor.extract
    else:
        gr.Warning("Model not loaded — using demo data instead.")
        extract_fn = ActivationExtractor.generate_demo_data

    try:
        compare_result_a = extract_fn(prompt_a)
    except Exception as err:
        # Best-effort fallback, but tell the user which prompt was affected.
        gr.Warning(f"Inference failed for Prompt A: {err}. Falling back to demo data.")
        compare_result_a = ActivationExtractor.generate_demo_data(prompt_a)

    try:
        compare_result_b = extract_fn(prompt_b)
    except Exception as err:
        gr.Warning(f"Inference failed for Prompt B: {err}. Falling back to demo data.")
        compare_result_b = ActivationExtractor.generate_demo_data(prompt_b)

    return _build_compare_outputs(compare_result_a, compare_result_b)
def run_compare_demo(prompt_a: str, prompt_b: str):
    """Build the comparison outputs from generated demo data for both prompts.

    Blank prompts are replaced by stock sentences. Both results are stored
    in the module-level comparison state for the per-control callbacks.

    Returns:
        The tuple produced by ``_build_compare_outputs``.
    """
    global compare_result_a, compare_result_b

    text_a = prompt_a if prompt_a.strip() else "The quick brown fox jumps over the lazy dog"
    text_b = prompt_b if prompt_b.strip() else "A slow red cat sleeps under the warm sun"

    make_demo = ActivationExtractor.generate_demo_data
    compare_result_a = make_demo(text_a)
    compare_result_b = make_demo(text_b)
    return _build_compare_outputs(compare_result_a, compare_result_b)
def update_compare_attention(layer: int, head: str):
    """Redraw the attention comparison for the chosen layer and head.

    Returns a placeholder figure until both comparison results exist.
    """
    first, second = compare_result_a, compare_result_b
    if first is not None and second is not None:
        return create_attention_comparison(first, second, layer=int(layer), head=head)
    return _empty_plot("Run comparison first")
def update_compare_magnitude(metric: str):
    """Redraw the magnitude comparison for the chosen metric.

    Returns a placeholder figure until both comparison results exist.
    """
    first, second = compare_result_a, compare_result_b
    if first is not None and second is not None:
        return create_magnitude_comparison(first, second, metric=metric)
    return _empty_plot("Run comparison first")
def update_compare_grid(normalize: str):
    """Redraw the token-by-layer comparison with the chosen normalization.

    Returns a placeholder figure until both comparison results exist.
    """
    first, second = compare_result_a, compare_result_b
    if first is not None and second is not None:
        return create_token_layer_comparison(first, second, normalize=normalize)
    return _empty_plot("Run comparison first")
def update_compare_scatter(layer: int, method: str):
    """Redraw the scatter comparison for the chosen layer and method.

    Returns a placeholder figure until both comparison results exist.
    """
    first, second = compare_result_a, compare_result_b
    if first is not None and second is not None:
        return create_scatter_comparison(first, second, layer=int(layer), method=method)
    return _empty_plot("Run comparison first")
def _build_compare_outputs(result_a: ExtractionResult, result_b: ExtractionResult):
    """Render every Compare-tab output for the two results with default settings.

    The figures use fixed defaults (layer 0 / average head for attention,
    ``mean_l2`` magnitude, global grid normalization, PCA at layer 18).

    The status label is derived from BOTH results: ``run_compare`` can fall
    back to demo data for a single prompt, so inspecting only ``result_a``
    would mislabel a mixed comparison.

    Returns:
        A 6-tuple: attention, magnitude, grid, scatter and fingerprint
        comparison figures, then the status Markdown string.
    """
    fig_attn = create_attention_comparison(result_a, result_b, layer=0, head="average")
    fig_mag = create_magnitude_comparison(result_a, result_b, metric="mean_l2")
    fig_grid = create_token_layer_comparison(result_a, result_b, normalize="global")
    fig_scatter = create_scatter_comparison(result_a, result_b, layer=18, method="pca")
    fig_fp = create_fingerprint_comparison(result_a, result_b)

    a_is_demo = bool(result_a.is_demo)
    b_is_demo = bool(result_b.is_demo)
    if a_is_demo and b_is_demo:
        mode = "🧪 Demo"
    elif not a_is_demo and not b_is_demo:
        mode = "🧠 Real"
    else:
        # One prompt fell back to demo data while the other used the model.
        label_a = "demo" if a_is_demo else "real"
        label_b = "demo" if b_is_demo else "real"
        mode = f"⚠️ Mixed (A: {label_a}, B: {label_b})"

    status = (
        f"**Comparison:** {mode} | "
        f"Prompt A: {len(result_a.tokens)} tokens ({result_a.inference_time:.3f}s) | "
        f"Prompt B: {len(result_b.tokens)} tokens ({result_b.inference_time:.3f}s)"
    )
    return fig_attn, fig_mag, fig_grid, fig_scatter, fig_fp, status
# ---------------------------------------------------------------------------
# Tab 3: Generate — streaming callbacks
# ---------------------------------------------------------------------------
def run_generate(prompt: str, max_tokens: int):
    """Yield UI updates while text is generated from ``prompt``.

    A blank prompt is replaced by ``"Once upon a time"``. The real model's
    streaming generator is used when a model is loaded, otherwise the demo
    streaming generator.

    Yields:
        For every intermediate result: the tokens joined by single spaces,
        the magnitude, grid and fingerprint figures, and a status Markdown
        string.
    """
    seed_text = prompt if prompt.strip() else "Once upon a time"
    token_budget = int(max_tokens)

    if extractor.model_loaded:
        stream = extractor.generate_streaming(seed_text, max_new_tokens=token_budget)
    else:
        stream = ActivationExtractor.generate_demo_streaming(seed_text, max_new_tokens=token_budget)

    for snapshot in stream:
        joined_tokens = " ".join(snapshot.tokens)
        magnitude_fig = create_magnitude_chart(snapshot, metric="mean_l2")
        grid_fig = create_token_layer_grid(snapshot, normalize="global")
        fingerprint_fig = create_fingerprint_strip(snapshot)

        source_label = "🧪 Demo" if snapshot.is_demo else "🧠 Real"
        progress = " | ".join(
            [
                f"**Generating:** {len(snapshot.tokens)} tokens",
                f"⏱ {snapshot.inference_time:.2f}s",
                source_label,
            ]
        )
        yield joined_tokens, magnitude_fig, grid_fig, fingerprint_fig, progress
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Bounds shared by the layer sliders and head dropdowns on the Analyze and
# Compare tabs. They match the model shape quoted in the About text
# (36 layers, 32 heads).
_LAST_LAYER_INDEX = 35
_NUM_HEADS = 32

# Body of the "About NeuroScope" accordion on the Analyze tab.
_ABOUT_MARKDOWN = """**NeuroScope** lets you look inside a large language model while it processes text.
**Views:**
- **Attention Heatmap** — Which tokens attend to which? Select any layer and head,
or view the average pattern across all heads.
- **Activation Magnitude** — How strong are the hidden state activations at each layer?
⭐ Gold bars mark layers 9, 18, 27 (used by the Activation Avatars system).
- **Token × Layer Grid** — A heatmap of every token's activation strength at every layer.
Watch how token representations evolve through the network.
- **Token Representation Space** — PCA (or UMAP) projection of token hidden states.
See how tokens cluster and separate. Use the overlay feature to trace token
trajectories across layers.
- **Activation Fingerprint** — Compact visual identity for each token based on its
full processing trajectory through all layers.
**Model:** Qwen3-4B (36 layers, 32 heads, 2560 hidden dim) |
**Built by:** [Alogotron](https://huggingface.co/Alogotron)
"""


def _head_choices() -> list[str]:
    """Return the head dropdown entries: two aggregates, then each head index."""
    return ["average", "max"] + [str(i) for i in range(_NUM_HEADS)]


def _build_theme() -> gr.themes.Base:
    """Create the dark NeuroScope theme, using one colour for light and dark modes."""
    palette = {
        "body_background_fill": "#0f0f23",
        "block_background_fill": "#16162b",
        "block_border_color": "#2a2a4e",
        "block_title_text_color": "#e6b800",
        "block_label_text_color": "#e0e0e0",
        "input_background_fill": "#1a1a2e",
        "input_border_color": "#2a2a4e",
        "button_primary_background_fill": "#e6b800",
        "button_primary_text_color": "#0f0f23",
        "button_secondary_background_fill": "#2a2a4e",
        "button_secondary_text_color": "#e0e0e0",
    }
    overrides = {}
    for key, colour in palette.items():
        # Every override is applied identically to the "_dark" variant so the
        # UI looks the same regardless of the browser's colour scheme.
        overrides[key] = colour
        overrides[f"{key}_dark"] = colour

    return gr.themes.Base(
        primary_hue=gr.themes.colors.yellow,
        secondary_hue=gr.themes.colors.blue,
        neutral_hue=gr.themes.colors.gray,
        font=["Inter", "system-ui", "sans-serif"],
    ).set(**overrides)


def _build_analyze_tab() -> None:
    """Lay out the single-prompt Analyze tab and wire its callbacks."""
    with gr.TabItem("🧠 Analyze", id="analyze"):
        analyze_status = gr.Markdown(
            value=get_status_text(None, extractor.model_loaded),
            elem_classes=["status-bar"],
        )
        with gr.Row():
            prompt_box = gr.Textbox(
                value="The quick brown fox jumps over the lazy dog",
                label="Input Prompt",
                placeholder="Enter text to analyze...",
                scale=5,
                max_lines=3,
            )
            run_btn = gr.Button("🧠 Run Inference", variant="primary", scale=1)
            demo_btn = gr.Button("🧪 Demo Data", variant="secondary", scale=1)

        # 2×2 visualization grid — top row
        with gr.Row(equal_height=True):
            with gr.Column():
                gr.Markdown("### 🔍 Attention Heatmap")
                with gr.Row():
                    attn_layer = gr.Slider(
                        minimum=0, maximum=_LAST_LAYER_INDEX, step=1, value=0,
                        label="Layer", scale=2,
                    )
                    attn_head = gr.Dropdown(
                        choices=_head_choices(),
                        value="average",
                        label="Head", scale=1,
                    )
                plot_attn = gr.Plot(label="Attention")
            with gr.Column():
                gr.Markdown("### 📊 Activation Magnitude")
                mag_metric = gr.Radio(
                    choices=["mean_l2", "max_l2", "mean_abs"],
                    value="mean_l2",
                    label="Metric",
                )
                plot_mag = gr.Plot(label="Magnitude")

        # 2×2 visualization grid — bottom row
        with gr.Row(equal_height=True):
            with gr.Column():
                gr.Markdown("### 🌡️ Token × Layer Grid")
                grid_norm = gr.Radio(
                    choices=["global", "per_layer", "per_token", "none"],
                    value="global",
                    label="Normalization",
                )
                plot_grid = gr.Plot(label="Token-Layer")
            with gr.Column():
                gr.Markdown("### 🎯 Token Representation Space")
                with gr.Row():
                    scatter_layer = gr.Slider(
                        minimum=0, maximum=_LAST_LAYER_INDEX, step=1, value=18,
                        label="Layer", scale=2,
                    )
                    scatter_method = gr.Radio(
                        choices=["pca", "umap"],
                        value="pca",
                        label="Method", scale=1,
                    )
                scatter_overlay = gr.Textbox(
                    value="",
                    label="Overlay layers (comma-separated, e.g. 0,9,18,27,35)",
                    placeholder="Leave empty for single layer",
                )
                plot_scatter = gr.Plot(label="Scatter")

        # Fingerprint section
        with gr.Accordion("🔑 Activation Fingerprint", open=False):
            gr.Markdown(
                "Each token gets a unique color derived from PCA of its activation "
                "trajectory across all 36 layers. Tokens processed similarly share "
                "similar colors. The trajectory heatmap shows raw L2 norms, and the "
                "similarity matrix reveals which tokens the network treated alike."
            )
            plot_fingerprint = gr.Plot(label="Fingerprint")

        # About section
        with gr.Accordion("ℹ️ About NeuroScope", open=False):
            gr.Markdown(_ABOUT_MARKDOWN)

        # Event wiring — full refreshes
        all_outputs = [
            plot_attn, plot_mag, plot_grid, plot_scatter, plot_fingerprint, analyze_status,
        ]
        run_btn.click(fn=run_inference, inputs=[prompt_box], outputs=all_outputs)
        demo_btn.click(fn=run_demo, inputs=[prompt_box], outputs=all_outputs)
        # Pressing Enter goes straight to demo data when no model is loaded,
        # which avoids the "Model not loaded" warning run_inference would show.
        prompt_box.submit(
            fn=run_demo if not extractor.model_loaded else run_inference,
            inputs=[prompt_box],
            outputs=all_outputs,
        )

        # Event wiring — single-plot refreshes
        attn_inputs = [attn_layer, attn_head]
        attn_layer.change(fn=update_attention, inputs=attn_inputs, outputs=[plot_attn])
        attn_head.change(fn=update_attention, inputs=attn_inputs, outputs=[plot_attn])
        mag_metric.change(fn=update_magnitude, inputs=[mag_metric], outputs=[plot_mag])
        grid_norm.change(fn=update_token_grid, inputs=[grid_norm], outputs=[plot_grid])

        scatter_inputs = [scatter_layer, scatter_method, scatter_overlay]
        scatter_layer.change(fn=update_scatter, inputs=scatter_inputs, outputs=[plot_scatter])
        scatter_method.change(fn=update_scatter, inputs=scatter_inputs, outputs=[plot_scatter])
        # The overlay text only triggers on submit (Enter), not on every keystroke.
        scatter_overlay.submit(fn=update_scatter, inputs=scatter_inputs, outputs=[plot_scatter])


def _build_compare_tab() -> None:
    """Lay out the two-prompt Compare tab and wire its callbacks."""
    with gr.TabItem("⚖️ Compare", id="compare"):
        compare_status = gr.Markdown(
            value="**Compare:** Enter two prompts and click Compare to see activation differences",
            elem_classes=["status-bar"],
        )
        with gr.Row():
            with gr.Column(scale=5):
                cmp_prompt_a = gr.Textbox(
                    value="The quick brown fox jumps over the lazy dog",
                    label="Prompt A (gold)",
                    placeholder="First prompt...",
                    max_lines=2,
                )
                cmp_prompt_b = gr.Textbox(
                    value="A slow red cat sleeps under the warm sun",
                    label="Prompt B (blue)",
                    placeholder="Second prompt...",
                    max_lines=2,
                )
            with gr.Column(scale=1):
                cmp_run_btn = gr.Button("⚖️ Compare", variant="primary")
                cmp_demo_btn = gr.Button("🧪 Demo Compare", variant="secondary")

        # Comparison visualizations — top row
        with gr.Row(equal_height=True):
            with gr.Column():
                gr.Markdown("### 🔍 Attention Comparison")
                with gr.Row():
                    cmp_attn_layer = gr.Slider(
                        minimum=0, maximum=_LAST_LAYER_INDEX, step=1, value=0,
                        label="Layer", scale=2,
                    )
                    cmp_attn_head = gr.Dropdown(
                        choices=_head_choices(),
                        value="average",
                        label="Head", scale=1,
                    )
                cmp_plot_attn = gr.Plot(label="Attention Comparison")
            with gr.Column():
                gr.Markdown("### 📊 Magnitude Comparison")
                cmp_mag_metric = gr.Radio(
                    choices=["mean_l2", "max_l2", "mean_abs"],
                    value="mean_l2",
                    label="Metric",
                )
                cmp_plot_mag = gr.Plot(label="Magnitude Comparison")

        # Comparison visualizations — bottom row
        with gr.Row(equal_height=True):
            with gr.Column():
                gr.Markdown("### 🌡️ Token×Layer Comparison")
                cmp_grid_norm = gr.Radio(
                    choices=["global", "raw"],
                    value="global",
                    label="Normalization",
                )
                cmp_plot_grid = gr.Plot(label="Grid Comparison")
            with gr.Column():
                gr.Markdown("### 🎯 Scatter Comparison")
                with gr.Row():
                    cmp_scatter_layer = gr.Slider(
                        minimum=0, maximum=_LAST_LAYER_INDEX, step=1, value=18,
                        label="Layer", scale=2,
                    )
                    cmp_scatter_method = gr.Radio(
                        choices=["pca", "umap"],
                        value="pca",
                        label="Method", scale=1,
                    )
                cmp_plot_scatter = gr.Plot(label="Scatter Comparison")

        # Fingerprint comparison
        with gr.Accordion("🔑 Fingerprint Comparison", open=False):
            gr.Markdown(
                "Side-by-side activation trajectory fingerprints. "
                "Jointly normalized so both prompts are visually comparable."
            )
            cmp_plot_fp = gr.Plot(label="Fingerprint Comparison")

        # Event wiring — full refreshes
        cmp_all_outputs = [
            cmp_plot_attn, cmp_plot_mag, cmp_plot_grid, cmp_plot_scatter, cmp_plot_fp,
            compare_status,
        ]
        prompt_inputs = [cmp_prompt_a, cmp_prompt_b]
        cmp_run_btn.click(fn=run_compare, inputs=prompt_inputs, outputs=cmp_all_outputs)
        cmp_demo_btn.click(fn=run_compare_demo, inputs=prompt_inputs, outputs=cmp_all_outputs)

        # Event wiring — single-plot refreshes
        cmp_attn_inputs = [cmp_attn_layer, cmp_attn_head]
        cmp_attn_layer.change(
            fn=update_compare_attention, inputs=cmp_attn_inputs, outputs=[cmp_plot_attn],
        )
        cmp_attn_head.change(
            fn=update_compare_attention, inputs=cmp_attn_inputs, outputs=[cmp_plot_attn],
        )
        cmp_mag_metric.change(
            fn=update_compare_magnitude, inputs=[cmp_mag_metric], outputs=[cmp_plot_mag],
        )
        cmp_grid_norm.change(
            fn=update_compare_grid, inputs=[cmp_grid_norm], outputs=[cmp_plot_grid],
        )
        cmp_scatter_inputs = [cmp_scatter_layer, cmp_scatter_method]
        cmp_scatter_layer.change(
            fn=update_compare_scatter, inputs=cmp_scatter_inputs, outputs=[cmp_plot_scatter],
        )
        cmp_scatter_method.change(
            fn=update_compare_scatter, inputs=cmp_scatter_inputs, outputs=[cmp_plot_scatter],
        )


def _build_generate_tab() -> None:
    """Lay out the streaming Generate tab and wire its callback."""
    with gr.TabItem("⚡ Generate", id="generate"):
        gen_status = gr.Markdown(
            value="**Generate:** Enter a prompt and watch activations evolve as the model generates text token-by-token",
            elem_classes=["status-bar"],
        )
        with gr.Row():
            gen_prompt = gr.Textbox(
                value="Once upon a time",
                label="Starting Prompt",
                placeholder="Enter text to continue generating from...",
                scale=4,
                max_lines=2,
            )
            gen_max_tokens = gr.Slider(
                minimum=4, maximum=64, step=4, value=16,
                label="Max New Tokens",
                scale=1,
            )
            gen_btn = gr.Button("⚡ Generate", variant="primary", scale=1)

        # Generated text display
        gen_text = gr.Textbox(
            label="Generated Text",
            interactive=False,
            lines=3,
            max_lines=5,
            elem_classes=["gen-text-display"],
        )

        # Live visualizations (subset — most useful for streaming)
        gr.Markdown("### 📊 Live Activation Magnitude")
        gen_plot_mag = gr.Plot(label="Magnitude (live)")
        with gr.Row(equal_height=True):
            with gr.Column():
                gr.Markdown("### 🌡️ Live Token × Layer Grid")
                gen_plot_grid = gr.Plot(label="Token-Layer (live)")
            with gr.Column():
                gr.Markdown("### 🔑 Live Fingerprint")
                gen_plot_fp = gr.Plot(label="Fingerprint (live)")

        # run_generate is a generator, so each yield updates all five outputs.
        gen_btn.click(
            fn=run_generate,
            inputs=[gen_prompt, gen_max_tokens],
            outputs=[gen_text, gen_plot_mag, gen_plot_grid, gen_plot_fp, gen_status],
        )


def build_app() -> tuple[gr.Blocks, gr.themes.Base]:
    """Construct the Gradio Blocks interface.

    The UI consists of a branded header and three tabs (Analyze, Compare,
    Generate), each built by its own private helper. The helpers must be
    called inside the ``gr.Blocks`` context so their components and event
    listeners attach to this app.

    Returns:
        ``(app, theme)``. The theme is returned rather than applied here
        because ``main`` hands it, together with the CSS, to ``app.launch``.
    """
    theme = _build_theme()

    with gr.Blocks(title="NeuroScope") as app:
        # Header
        gr.HTML(
            '<div class="neuroscope-header">'
            '<h1>🧠 NeuroScope</h1>'
            '<p>Neural Network Activation Visualizer — '
            'See inside Qwen3-4B during inference</p>'
            '</div>'
        )
        with gr.Tabs():
            _build_analyze_tab()
            _build_compare_tab()
            _build_generate_tab()

    return app, theme
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main():
    """Parse command-line flags, optionally load the model, and launch the UI.

    Flags: ``--model`` (load the model for real inference), ``--model-name``,
    ``--no-quantize``, ``--port`` and ``--share``. Without ``--model`` the
    app starts in demo mode. The server listens on all interfaces.
    """
    cli = argparse.ArgumentParser(description="NeuroScope — Activation Visualizer")
    cli.add_argument(
        "--model",
        action="store_true",
        help="Load Qwen3-4B for real inference (requires GPU)",
    )
    cli.add_argument(
        "--model-name",
        default="Qwen/Qwen3-4B",
        help="HuggingFace model name or path",
    )
    cli.add_argument(
        "--no-quantize",
        action="store_true",
        help="Load model in fp16 instead of 4-bit quantization",
    )
    cli.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Server port (default: 7860)",
    )
    cli.add_argument(
        "--share",
        action="store_true",
        help="Create a public Gradio share link",
    )
    options = cli.parse_args()

    if not options.model:
        print("Starting in demo mode (no GPU required).")
        print("Use --model to load Qwen3-4B for real inference.")
    else:
        print("Loading model... this may take a minute.")
        load_report = extractor.load_model(
            model_name=options.model_name,
            quantize=not options.no_quantize,
        )
        print(load_report)

    # The model (if any) is loaded before the UI is built because build_app()
    # reads extractor.model_loaded while wiring the Analyze tab.
    app, theme = build_app()
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": options.port,
        "share": options.share,
        "theme": theme,
        "css": CUSTOM_CSS,
    }
    app.launch(**launch_options)


if __name__ == "__main__":
    main()