TemporalMesh Transformer: dynamic graph attention + per-token exit gates = 29.4 PPL at 48% compute (

#225
by vigneshwar234 - opened

New open-source transformer: TemporalMesh Transformer (TMT v3), 29.4 PPL, 48% compute, 5 innovations unified

I've been building a new transformer architecture and wanted to share it with this community.

The single-sentence pitch: TMT is the first transformer to unify dynamic graph attention, semantic temporal decay, and per-token adaptive depth routing in one differentiable forward pass.


Why existing approaches fall short

  • Longformer/BigBird: Sparse attention ✅, but the topology is fixed, so it can't adapt to what tokens mean
  • Mamba/RWKV: Linear time ✅, but no cross-token pairwise attention, which loses expressivity on reasoning tasks
  • MoE: High capacity ✅, but every token still uses the same depth, so punctuation gets the same compute as rare words
  • Depth-adaptive: Variable depth ✅, but uses full O(S²) attention, which doesn't help with the memory bottleneck

TMT fixes all four simultaneously with 5 innovations that reinforce each other (superadditive gains).


Architecture (5 innovations)

1. Mesh Attention: after each layer, recompute a kNN graph from the current token representations (cosine similarity, top-k = 8). Each token attends only to its k neighbours, so attention costs O(S·k) per layer. The graph topology changes as representations evolve: by layer 8, "bank" clusters with "river/shore" in a geographical context, not with "credit/loan".

2. Temporal Decay Encoding: scale the post-softmax attention weights by a distance-dependent factor controlled by a learned per-head scalar w: ã_ij = α_ij × σ(w·|t_i − t_j|), where α_ij is the post-softmax attention weight. Unlike ALiBi (additive to logits), this is multiplicative and applied after the softmax, which gives stronger attenuation, is gradient-stable, and lets each head learn its own decay rate.

3. Adaptive Depth Routing: a per-token exit gate c_i = σ(W_gate·x_i). Tokens with confidence c_i > 0.85 freeze and skip all remaining layers. Result: punctuation exits at layer 2.1 on average, rare words at layer 11.7, and the overall average is 5.76 of 12 layers → 52% compute saved.

4. Dual-Stream FFN โ€” parallel syntax and semantic MLP streams, blended by a sigmoid gate. Interpretable: "the" weights syntax stream ~0.7, "serendipity" weights semantic stream ~0.8.

5. EMA Memory Anchors: 16 persistent key-value pairs updated via EMA (β = 0.99) across training. Fast-weight cross-sequence recall without recurrence. Only 32 KB of extra parameters.
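
To make items 1 to 3 concrete, here is a minimal NumPy sketch of how kNN-restricted attention, the multiplicative decay, and the exit gate could fit together. It is an illustration under stated assumptions, not the TMT implementation: the function names are hypothetical, there is a single head with no Q/K/V projections, a token may be its own neighbour, frozen tokens still serve as neighbours for active ones, and the kNN graph is built from a dense S×S similarity matrix for readability (so only the attention step itself is O(S·k) here).

```python
import numpy as np

def sigmoid(z):
    # Numerically stable logistic function.
    return 0.5 * (1.0 + np.tanh(0.5 * z))

def knn_graph(x, k):
    """Indices of the k most cosine-similar tokens per token, shape (S, k).
    Assumption: self-similarity is not masked, so each token keeps itself."""
    unit = x / np.linalg.norm(x, axis=-1, keepdims=True)
    sim = unit @ unit.T                      # dense (S, S), for clarity only
    return np.argsort(-sim, axis=-1)[:, :k]

def mesh_attention(x, t, w_decay, k=8):
    """Attention over k neighbours with post-softmax decay:
    a_tilde_ij = alpha_ij * sigmoid(w * |t_i - t_j|)."""
    d = x.shape[1]
    nbr = knn_graph(x, k)                    # (S, k)
    keys = x[nbr]                            # (S, k, d), identity projections
    logits = np.einsum("sd,skd->sk", x, keys) / np.sqrt(d)
    logits -= logits.max(axis=-1, keepdims=True)
    alpha = np.exp(logits)
    alpha /= alpha.sum(axis=-1, keepdims=True)   # softmax over neighbours
    dist = np.abs(t[:, None] - t[nbr])           # |t_i - t_j|, (S, k)
    a_tilde = alpha * sigmoid(w_decay * dist)    # not renormalised afterwards
    return np.einsum("sk,skd->sd", a_tilde, keys)

def forward(x, t, w_decay, w_gate, n_layers=12, k=8, threshold=0.85):
    """Run up to n_layers; tokens whose gate c_i exceeds threshold are frozen.
    Returns the final representations and the number of layers each token used."""
    frozen = np.zeros(len(x), dtype=bool)
    depth = np.zeros(len(x), dtype=int)
    for _ in range(n_layers):
        if frozen.all():
            break
        # For simplicity the update is computed for all tokens and then masked;
        # a real implementation would skip frozen tokens to save compute.
        update = mesh_attention(x, t, w_decay, k)
        active = ~frozen
        x = np.where(active[:, None], x + update, x)
        depth += active
        frozen |= sigmoid(x @ w_gate) > threshold   # c_i = sigmoid(W_gate . x_i)
    return x, depth

rng = np.random.default_rng(0)
tokens = rng.normal(size=(16, 32))           # S=16 tokens, d=32 (toy sizes)
positions = np.arange(16, dtype=float)
out, depth = forward(tokens, positions, w_decay=-0.1,
                     w_gate=0.1 * rng.normal(size=32))
print("mean exit depth:", depth.mean())
```

With a negative `w_decay` the factor σ(w·|t_i − t_j|) shrinks as the distance grows, which is the decay behaviour item 2 describes. The mean of `depth` divided by the layer count is the quantity behind the "average 5.76/12 layers" figure.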


Numbers

| Benchmark | Vanilla | Mamba | RWKV | TMT |
|---|---|---|---|---|
| WikiText-2 PPL ↓ | 42.1 | 31.8 | 33.1 | 29.4 |
| WikiText-103 PPL ↓ | 51.3 | 38.4 | 40.9 | 36.1 |
| LongBench ↑ | 41.2 | 51.3 | 48.7 | 53.4 |
| C4 PPL ↓ | 38.4 | 30.1 | 29.3 | 27.4 |
| Throughput, tokens/s (A100, FP16) | 94K | 148K | 160K | 138K |
| VRAM at S=4096 | OOM | 12 GB | 8 GB | 18 GB |

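All models have ~120M parameters. TMT runs at 138K tokens/s, which is slower than Mamba (148K) and RWKV (160K), but it uses 48% of the vanilla transformer's compute and reaches lower perplexity than either Mamba or RWKV on all three PPL benchmarks.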
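
The 48% compute figure follows from the average exit depth quoted in the Adaptive Depth Routing item (5.76 of 12 layers). A quick check, assuming compute scales with the number of layers a token passes through and ignoring the overhead of the gates and the kNN graph:

```python
N_LAYERS = 12
AVG_EXIT_DEPTH = 5.76  # average exit layer reported in the post

# Assumption: per-token compute is proportional to layers executed.
compute_fraction = AVG_EXIT_DEPTH / N_LAYERS
compute_saved = 1.0 - compute_fraction

print(f"compute used: {compute_fraction:.0%}, saved: {compute_saved:.0%}")
# prints "compute used: 48%, saved: 52%"
```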

Superadditive interaction

Sum of individual gains (Mesh + Decay + Exit) = 8.6 PPL. Combined TMT gain = 12.7 PPL, matching the WikiText-2 improvement from 42.1 to 29.4. Interaction effect = +4.1 PPL: the innovations reinforce each other. Mesh produces semantic clusters that make Decay more precise, and cleaner Mesh representations make Exit gates fire with higher confidence.
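
The interaction term is just the combined gain minus the sum of the individual gains. The sketch below redoes that arithmetic, assuming the gains are measured on WikiText-2 (the 12.7 figure matches that row of the table); the per-component breakdown of the 8.6 is not given in the post, so only the total is used.

```python
VANILLA_PPL = 42.1         # WikiText-2, vanilla transformer
TMT_PPL = 29.4             # WikiText-2, TMT
SUM_INDIVIDUAL_GAINS = 8.6 # Mesh + Decay + Exit, each added alone

combined_gain = VANILLA_PPL - TMT_PPL
interaction = combined_gain - SUM_INDIVIDUAL_GAINS

print(f"combined gain: {combined_gain:.1f} PPL")   # prints "combined gain: 12.7 PPL"
print(f"interaction:   {interaction:+.1f} PPL")    # prints "interaction:   +4.1 PPL"
```

A positive interaction term is what "superadditive" means here: the full model improves by more than the separate ablation gains add up to.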


Resources

Happy to discuss architecture choices, the ablations, or the math. All results reproducible with provided seeds (42/1337/2024) on A100 80GB.
