# Cognica-PoE-v1.0-1.3B-base
Paper: Product of Experts as Scalable Local Learning: Modular Construction at 1.3B Parameters (Jeong, 2026)
A 1.384B-parameter causal language model pretrained from scratch with Product of Experts (PoE) local learning at Chinchilla-optimal compute (27.7B tokens, ratio 20). The model is trained with per-stage detached CE losses through a shared `lm_head`, which gives every intermediate stage a valid self-contained predictor, not just the final layer.
This property unlocks a family of inference-time capabilities a standard backprop-trained model cannot access without either retraining or accuracy loss: stage prefix pruning, WAND-style adaptive depth, speculative decoding with zero added parameters, parallel stage branching, and post-hoc specialists that attach to a frozen base.
## TL;DR
- 1.384B params, d24, 1536-dim, 12 heads, 4 clustered stages × 6 layers
- ~27.7B tokens, Chinchilla ratio 20, 26,430 steps, MuonAdamW, bf16
- PoE flat-mode, α=0.0, `poe_every=6` → four per-stage detached CE losses into a shared `lm_head`
- Final val BPB: 0.720935 at step 26,430 (minimum, monotonic warmdown improvement)
- Standard HF `AutoModelForCausalLM` + `AutoTokenizer` with `trust_remote_code=True`
- KV cache, BOS-prepend protocol, SDPA attention (runs on any modern GPU, MPS, CPU)
## Why PoE → architectural capabilities
Every stage boundary (layers 5, 11, 17, 23) produces a complete valid predictor through the same shared `lm_head`, because that's what each stage was trained to do. This is not an approximation or an auxiliary structure; the final `lm_head` geometry is shared by all stages by construction.
Measured on the matching d24 PoE checkpoint (Jeong 2026; r=10 1.3B). Capabilities are structural and transfer to this r=20 release:
| Capability | Speedup | Quality retained | Requires retraining? |
|---|---|---|---|
| Stage 1 prefix pruning (25% compute) | ~4× | 87.5% factual accuracy = full model (8-prompt probe) | No |
| WAND adaptive stage pruning | 1.82× wall-clock | 100% top-1 agreement (18/18 prompts) | No |
| Speculative decoding (Stage 1 natural drafter) | 1.87× | 88% acceptance @ K=3, 13/13 match | No (no separate drafter needed) |
| Parallel stage composition (Log-OP / PoE algebra) | +2.4 logit margin (paper §6.5.5) or ~5.8× margin on this release probe | quality-positive, sharper joint distribution | No |
| Multi-device parallel dispatch (Stages 2–4 run after shared Stage 1) | T_S1 + max(T_S2, T_S3, T_S4) vs sequential | identical | No (requires multi-GPU) |
| Post-hoc specialist stages (dual-head SFT) | n/a | base preserved bit-identically (Δlogit = 0.0000 across 12 checkpoints) | No (freeze base) |
| Apple Silicon (MLX, M1 Ultra) Stage 1 only | 2.9× vs full stack (8.5 ms vs 24.7 ms) | 87.5% factual accuracy | No |
Why a BP-trained model can't do this: in standard backprop, intermediate hidden states aren't trained to be valid outputs; applying the final `lm_head` to them produces meaningless logits. PoE's per-stage supervision makes each stage a first-class predictor from day one. Early-exit methods for BP-trained models (DeeBERT, FastBERT, BranchyNet) require training modifications, additional head parameters, or accept accuracy degradation.
## Task-polarized capability profile
On the 22-task CORE benchmark vs the matched BP baseline (Jeong 2026, r=10 comparison):
- PoE wins (≥0.03): CommonsenseQA +5.8pp, PIQA +5.0pp, BigBench CS Algorithms +11.4pp, BigBench Operators +2.7pp
- BP wins (≥0.06): Jeopardy (rare-fact retrieval) −16.2pp, SQuAD −18.4pp, LAMBADA −15.0pp
PoE is not uniformly weaker: it trades rare-fact retrieval for gains in commonsense reasoning and algorithmic pattern recognition. The 6.52% val BPB gap (matched-baseline, final) is a bounded architectural trade-off, not a failure: you exchange 6.52% lm-loss for the capability matrix above. Deployment positioning: datacenter inference favors BP's factual coverage; on-device favors PoE's prefix pruning, WAND, speculative decoding, and retrieval augmentation (which recovers rare-fact retrieval at 712× on disambiguated queries).
For theoretical background (Log-OP / WAND / stage-level MoE), measurement protocols, and multi-model SFT experiments, see the companion paper: Product of Experts as Scalable Local Learning (Jeong, 2026).
## Usage
Important: every document in the training corpus starts with `<|bos|>`, so base-model prompts must prepend the BOS token. Without it the model treats the prompt as mid-document text and generation collapses into repetition. This matches nanochat's `base_train` sampling protocol (`tokenizer(prompt, prepend="<|bos|>")`).
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "cognica/Cognica-PoE-v1.0-1.3B-base", trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    "cognica/Cognica-PoE-v1.0-1.3B-base",
    trust_remote_code=True,
    torch_dtype="bfloat16",
).cuda().eval()

def encode_with_bos(prompt: str) -> torch.Tensor:
    ids = tokenizer(prompt, return_tensors="pt").input_ids
    bos = torch.tensor([[tokenizer.bos_token_id]])
    return torch.cat([bos, ids], dim=1).cuda()

# Greedy
input_ids = encode_with_bos("The capital of France is")
out = model.generate(input_ids, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))
# => "The capital of France is Paris, the capital of France. Paris is the capital of France..."

# Nucleus sampling
input_ids = encode_with_bos("The planets of the solar system are:")
out = model.generate(
    input_ids, max_new_tokens=64,
    do_sample=True, temperature=0.8, top_p=0.9,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
# => "The planets of the solar system are: Mercury, Venus, Earth, Mars, Jupiter, ..."
```
`trust_remote_code=True` is required because the model uses a custom architecture (ResFormer-style value embeddings on alternating layers, ReLU^2 MLPs, per-layer residual scalars, smear/backout mechanisms, sliding-window attention pattern, logit softcap). These are implemented in `modeling_cognica_poe.py`, included in this repo. The released wrapper routes attention through PyTorch SDPA, so the model runs on Ampere / Hopper / Blackwell / MPS / CPU without Flash Attention 3 installed.
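Because the SDPA path is device-agnostic, non-CUDA loading is just a device swap. A minimal sketch (the device probing below is standard PyTorch, not part of this repo; it replaces the hard-coded `.cuda()` calls above):

```python
import torch

# Pick the best available backend: CUDA, Apple-Silicon MPS, or CPU.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
model = model.to(device)

def encode_with_bos_on(prompt: str) -> torch.Tensor:
    # Same BOS-prepend protocol as above, but targeting the chosen device.
    ids = tokenizer(prompt, return_tensors="pt").input_ids
    bos = torch.tensor([[tokenizer.bos_token_id]])
    return torch.cat([bos, ids], dim=1).to(device)
```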
## PoE stage-level inference (five modes shipped)
The released `modeling_cognica_poe.py` exposes PoE stage-level inference directly on `CognicaPoEForCausalLM`; no forward hooks or extra code required:
```python
# Inspect the PoE structure
print(model.poe_n_stages)               # 4
print(model.poe_stage_boundary_layers)  # [5, 11, 17, 23]

# --- 1. Stage prefix pruning: generate using only Stage 0 (25% compute) ---
ids = encode_with_bos("The capital of France is")
out = model.generate_stage(ids, stage=0, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))
# => "The capital of France is Paris, the capital of France. ..."

# --- 2. WAND adaptive stage pruning ---
# Paper §5.3: 1.82× wall-clock at 100% top-1 agreement (safety=1.0, d24 r=10 calib).
out, stages_used = model.generate_wand(
    ids, max_new_tokens=32, safety=1.0, return_stages_used=True,
)
print(f"avg stage used: {sum(stages_used) / len(stages_used):.2f} / {model.poe_n_stages - 1}")
# easy tokens exit at stage 1 or 2; hard tokens go full depth

# --- 3. Speculative decoding (Stage 0 as natural drafter) ---
# Paper §5.4: 1.87× speedup at 88% acceptance (K=3, d24 r=10). Zero added params.
out, acc = model.generate_speculative(
    ids, max_new_tokens=32, draft_stage=0, k_draft=3, return_acceptance=True,
)
print(f"draft acceptance: {acc*100:.1f}%")

# --- 4. Parallel stage composition (quality-positive, Log-OP / PoE algebra) ---
# Paper §6.5.5: combining multiple stages in log-space produces a sharper joint
# distribution. Measured on this release: combining all 4 stages widens the
# top-1 vs top-2 margin by ~5.8× (0.49 -> 2.84 logit) on a factual probe.
out = model.generate_parallel_composition(
    ids, stages=[0, 1, 2, 3], max_new_tokens=32, do_sample=False,
)
# Use stage_weights to tune the balance (paper §6.5.6 inference-time knob):
out = model.generate_parallel_composition(
    ids, stages=[2, 3], stage_weights=[0.5, 1.0], max_new_tokens=32,
)

# --- 5. Raw stage logits (for custom decoding / analysis) ---
logits = model.forward_stage(ids, stage=0)  # (B, T, vocab_size) at layer 5
# or all stages at once (single forward pass):
stage_logits = model.gpt.forward_all_stages(
    ids, stage_boundaries=model.poe_stage_boundary_layers,
)  # list of 4 logit tensors, ordered by the given boundaries
```
**Notes on KV caching.** Standard HF `generate()` uses the full-depth KV cache implementation in `CognicaKVCache` for O(T)-per-step decode. The PoE stage methods above re-forward the full prefix each decode step (no KV cache). This keeps the reference implementation small and paper-faithful; a stage-aware KV cache (maintaining per-stage depth windows) is left as an integration detail for production serving stacks. The compute savings from fewer layers per forward are already realized.
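For clarity, a minimal sketch of what those stage methods do per step, written against the documented `forward_stage` API (greedy only; it re-forwards the whole prefix each iteration, exactly as the note above describes):

```python
import torch

@torch.no_grad()
def greedy_decode_stage0(model, input_ids: torch.Tensor, max_new_tokens: int = 32):
    # input_ids: (B, T) with BOS already prepended
    ids = input_ids
    for _ in range(max_new_tokens):
        logits = model.forward_stage(ids, stage=0)   # (B, T, vocab), layers 0-5 only
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)       # full prefix re-forwarded next step
    return ids
```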
**Calibration defaults.** `generate_wand` uses paper §5.3 p99 bounds (7.09, 3.03, 2.15) as defaults; these were calibrated on the matching d24 PoE architecture. Override via the `p99_bounds=` argument if you recalibrate on your own probe set.
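For example (assuming the argument accepts the three per-boundary bounds as a sequence; the values below are just the documented defaults):

```python
out = model.generate_wand(
    ids, max_new_tokens=32, safety=1.0,
    p99_bounds=(7.09, 3.03, 2.15),  # replace with bounds from your own probe set
)
```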
## Architecture
Rewritten GPT (port of nanochat's `gpt.py`) with:
- Rotary position embeddings (no learned positions, `rope_theta=100000`)
- QK norm (RMSNorm before attention scoring, 1.2× scale on both Q and K)
- Untied word embeddings (separate `lm_head`) with logit softcap `15 * tanh(x / 15)`
- ReLU² MLP activation, `4 × n_embd` intermediate
- RMSNorm with no learnable affine params; no bias anywhere
- Sliding-window attention pattern `SSSL` (short-short-short-long, tiled)
- ResFormer value embeddings on alternating layers (gated per-head into V)
- Per-layer learnable scalars: `resid_lambdas` (residual-stream scaling), `x0_lambdas` (initial-embedding blend-in)
- Smear: bigram-style previous-token embedding mix
- Backout: subtracts mid-depth residual before the final norm
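As a concrete example, the logit softcap listed above is a one-line transform on the final projection (a sketch of the stated formula, not the repo's exact code):

```python
import torch

def softcap(logits: torch.Tensor, cap: float = 15.0) -> torch.Tensor:
    # Smoothly bounds logits into (-cap, cap): 15 * tanh(x / 15)
    return cap * torch.tanh(logits / cap)
```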
## Hyperparameters
| Param | Value |
|---|---|
| `n_layer` | 24 |
| `n_head` | 12 |
| `n_kv_head` | 12 (no GQA) |
| `n_embd` | 1536 |
| `head_dim` | 128 |
| `sequence_len` | 2048 |
| `vocab_size` | 32768 (already a multiple of 64, no padding applied) |
| `window_pattern` | SSSL |
| `rope_theta` | 100000 |
| `poe_mode` | flat (per-stage detachment) |
| `poe_every` | 6 (→ 4 stages) |
| `poe_alpha` | 0.0 (uniform stage average) |
## Clustered PoE details
Layers are grouped into $K = \lceil L / S \rceil = 4$ stages of $S = 6$ layers each. During training:
- Intra-stage: full backprop.
- Inter-stage: `detach()` at each boundary, so Stage $k$'s gradient depends only on its own CE.
- All stages project through the same shared `lm_head`, with `alpha=0.0` giving the uniform average

$$\mathcal{L}_{\mathrm{PoE}} = \frac{1}{K} \sum_{k=1}^{K} \mathrm{CE}_k.$$

The shared $W_{\mathrm{head}}$ receives gradient contributions from every stage, enforcing a common output geometry without an auxiliary coordinator network.
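A minimal PyTorch sketch of this flat-mode objective (the names `stages` and `lm_head` are illustrative; this is not the repo's training code):

```python
import torch
import torch.nn.functional as F

def poe_flat_loss(x, targets, stages, lm_head):
    # x: (B, T, d) embedded inputs; targets: (B, T) next-token ids
    # stages: K = 4 callables of S = 6 transformer layers each
    losses, h = [], x
    for stage in stages:
        h = stage(h)                         # intra-stage: full backprop
        logits = lm_head(h)                  # shared head: every stage pushes grads into W_head
        losses.append(F.cross_entropy(
            logits.view(-1, logits.size(-1)), targets.view(-1)))
        h = h.detach()                       # boundary detach: stage k+1's CE can't reach stage k
    return sum(losses) / len(losses)         # alpha = 0.0 -> uniform stage average
```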
Inference consequence: because every stage was trained to produce valid next-token logits through the same $W_{\mathrm{head}}$, each intermediate hidden state $h_5, h_{11}, h_{17}, h_{23}$ is a complete predictor, enabling the prefix-pruning / WAND / speculative / parallel properties in Why PoE. At the final layer the model behaves as a standard causal LM; the auxiliary stages are additional inference paths, not replacements.
## Training
- Dataset: NVIDIA ClimbMix (via `karpathy/climbmix-400b-shuffle` mirror), 700 shards (~35 GB)
- Tokens: 27,713,863,680 = 26,430 × 1,048,576 (Chinchilla-optimal for 1.384B params, ratio 20)
- Iterations: 26,430
- Batch size: 1,048,576 tokens/step
- Device batch: 16 sequences per GPU
- Optimizer: `MuonAdamW` (hybrid Muon + AdamW); `embedding_lr = 0.3`, `unembedding_lr = 0.008`, `matrix_lr = 0.02`, `weight_decay = 0.28`
- Schedule: linear warmup (40 steps), stable, linear warmdown (0.65 ratio) to 0.05 × peak
- Precision: bf16 compute, fp32 accumulation
- Infrastructure: 8 × NVIDIA A100 80GB across two nodes (intra-zone gVNIC DDP, GCP asia-southeast1-c)
- Wall-clock: 65.87 hours
## Evaluation snapshots
BPB (bits per byte, tokenization-invariant) on a held-out ClimbMix validation set (40M tokens), logged every 1,000 steps. The BP baseline below is a separately trained matched run with identical architecture and hyperparameters (only `--poe-mode=none` differs).
| Step | BP baseline (no PoE) | PoE alpha=0.0 (this model) | Gap |
|---|---|---|---|
| 1,000 | 0.847 | 0.883 | +4.32% |
| 6,000 | 0.785 | 0.819 | +4.34% |
| 16,000 | 0.734 | 0.772 | +5.22% |
| 18,000 | 0.722 | 0.762 | +5.47% |
| 20,000 | 0.710 | 0.751 | +5.83% |
| 25,000 | 0.683 | 0.726 | +6.41% |
| 26,430 (final) | 0.676788 | 0.720935 | +6.52% |
- Final PoE val BPB 0.720935 at step 26,430 (minimum; monotonic improvement through warmdown).
- The matched BP baseline (identical architecture, only `--poe-mode=none` differs) completed training at step 26,430 with BPB 0.676788. Final matched gap: 6.52%, a bounded architectural trade-off for the capability matrix in Why PoE.
- Gap widens convexly through the run (+2.20 pp over 26K steps), with ~31% of the widening concentrated in the final 6K warmdown steps: the structural signature described in paper §5.6. BP's global gradient coordination produces its largest advantage during fine-grained optimization, but the cost is bounded.
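For reference, a minimal sketch of how a tokenization-invariant BPB can be computed from per-token NLL and the per-token byte lengths shipped in `token_bytes.pt` (illustrative; the exact training-side protocol may differ):

```python
import math
import torch

def bits_per_byte(nll_nats: torch.Tensor, token_ids: torch.Tensor,
                  token_bytes: torch.Tensor) -> float:
    # nll_nats: per-token cross-entropy in nats; token_ids: the evaluated tokens;
    # token_bytes: byte length of each vocab entry (cf. token_bytes.pt)
    total_bits = nll_nats.sum().item() / math.log(2)   # nats -> bits
    total_bytes = token_bytes[token_ids].sum().item()  # corpus size in bytes
    return total_bits / total_bytes                    # invariant to tokenization
```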
## Intended use
- On-device / edge inference: the architecture's natural habitat (prefix pruning, WAND, speculative decoding, post-hoc specialists, low activation memory `O(B·T·d)` per device).
- PoE / local-learning research: cross-stage gradient independence, WAND pruning, parallel stage branching, dual-head specialist SFT.
- Base for post-hoc specialists: attach Stage 5 (chat, factual, code, etc.) via dual-head construction; base Stages 1–4 preserved bit-identically.
- Distillation baseline: BP teacher → PoE student for on-device compression.
- Benchmarks / ablations: reproducible Chinchilla-ratio-20 PoE @ 1.3B.
Not intended for: production chat deployment, factual QA, safety-critical applications. This is a base model with no RLHF, no SFT, no safety tuning. For domain-tuned inference, see the sibling specialist repos listed below.
## Related models: specialist stages
The `modeling_cognica_poe.py` in this repo ships a cascade loader: any stage repo whose `config.json` sets `base_model_name_or_path` to this base (or to another stage) will automatically compose at load time. Each stage contributes 2–4 new transformer layers and an additive `lm_head_stage`; ancestors' stage heads are folded into the effective base head, so specialists compose additively into the final projection (logits = lm_head_base + Σ lm_head_stage_k). All four specialists currently published are siblings; each branches directly from this base:
| Stage repo | Depth (new_layers) | Trainable | Training data | Best val bpb |
|---|---|---|---|---|
| `-stage-chat` | d26 (2) | 107 M | SmolTalk + MMLU + GSM8K (500k convs) | 2.0610 @ step 2,600 |
| `-stage-math` | d28 (4) | 164 M | GSM8K ×20 + MathInstruct ×4 (1.20M convs) | 2.4118 @ step 2,200 |
| `-stage-code` | d28 (4) | 164 M | CodeAlpaca ×3 + Magicoder-Evol-Instruct ×2 (282k convs) | 2.3610 @ step 2,000 |
| `-stage-tool` | d28 (4) | 164 M | glaive-function-calling-v2 ×4 + xlam-function-calling-60k ×3 (632k convs) | 1.4288 @ step 2,600 |
The cascade loader supports arbitrary-depth chaining, so any of these stages can in turn serve as a parent for further specialist stages. Note that at 1.3B the base's capacity bounds emergent reasoning/arithmetic regardless of specialist depth; these artifacts are released as empirical validation of paper §6.5 (dual-head construction + base preservation), not as production chat/math/code/tool models. Stage-tool in particular is format-brittle: function-call emission requires multi-line indented JSON in the system prompt, matching the glaive training distribution. See each stage's README for honest per-domain evaluation and prompt-format guidance.
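For orientation, a minimal sketch of the additive projection (one stage shown), under the assumption that the specialist's new layers run on top of the frozen base trunk while the base head keeps projecting the base hidden state; the actual wiring lives in `modeling_cognica_poe.py`, and all names here are illustrative:

```python
import torch

def cascade_logits(h_base, new_layers, lm_head_base, lm_head_stage):
    # h_base: (B, T, d) final hidden state of the frozen base trunk
    h = h_base
    for layer in new_layers:          # the specialist's 2-4 appended layers
        h = layer(h)
    # logits = lm_head_base + lm_head_stage: the frozen base term is untouched,
    # consistent with the bit-identical base-preservation claim above
    return lm_head_base(h_base) + lm_head_stage(h)
```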
## Files
- `config.json` – HF-compatible model config (includes `auto_map` for `trust_remote_code`)
- `configuration_cognica_poe.py` – `CognicaPoEConfig` (`PretrainedConfig` subclass)
- `modeling_cognica_poe.py` – `CognicaPoEForCausalLM` wrapper + inlined nanochat GPT port with KV cache
- `tokenization_cognica_poe.py` – `CognicaPoETokenizer` (`PreTrainedTokenizer`) wrapping the tiktoken BPE
- `tokenizer_config.json`, `special_tokens_map.json` – HF tokenizer metadata (class, specials, max length)
- `tokenizer.pkl` – pickled `tiktoken.Encoding` loaded by `CognicaPoETokenizer`
- `token_bytes.pt` – per-token byte-length tensor (training-side BPB metric; not required for inference)
- `convert_checkpoint.py` – converts nanochat `.pt` + `meta.json` into `model.safetensors`
- `LICENSE`, `NOTICE` – Apache 2.0 + upstream MIT attribution
- `model.safetensors`, `generation_config.json` – pretrained weights (bf16, 2.6 GB)
## Citation
If you use this model or the PoE local-learning methodology, please cite:
```bibtex
@article{jeong2026poe,
  title       = {Product of Experts as Scalable Local Learning: Modular Construction at 1.3B Parameters},
  author      = {Jeong, Jaepil},
  year        = {2026},
  institution = {Cognica, Inc.},
  doi         = {10.5281/zenodo.19547653},
  url         = {https://doi.org/10.5281/zenodo.19547653}
}

@misc{cognica-poe-v1-2026,
  title        = {Cognica-PoE-v1.0: A 1.3B Causal LM with Product-of-Experts Local Learning},
  author       = {{Cognica, Inc.}},
  year         = {2026},
  howpublished = {\url{https://huggingface.co/cognica/Cognica-PoE-v1.0-1.3B-base}}
}
```
## Acknowledgments
- nanochat by Andrej Karpathy – training framework (Muon, ReLU², RoPE, RMSNorm, ResFormer value embeddings, `x_0` blending)
- NVIDIA ClimbMix (via the karpathy/climbmix-400b-shuffle mirror) – pretraining dataset
- Muon optimizer (Keller Jordan et al.) – used via the `MuonAdamW` hybrid
- Hinton (2002), Training Products of Experts by Minimizing Contrastive Divergence – theoretical foundation
## License
Apache 2.0; see `LICENSE`.