File size: 1,509 Bytes
aedd6ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""CWM trace special tokens + tokenizer/embedding setup for a non-CWM base."""

# Trace-format tokens (mirrors data/trace_format.py) + latent delimiters.
TRACE_TOKENS = [
    "<|trace_context_start|>",
    "<|call_sep|>", "<|line_sep|>", "<|return_sep|>", "<|exception_sep|>",
    "<|action_sep|>", "<|arg_sep|>", "<|frame_sep|>", "<|end_of_text|>",
    "<|latent_start|>", "<|latent_end|>",
]


def add_trace_tokens(tokenizer) -> int:
    """Add the trace tokens as special tokens. Returns the count newly added."""
    return tokenizer.add_tokens(TRACE_TOKENS, special_tokens=True)


def resize_and_init(model, tokenizer, n_added: int) -> None:
    """Resize embeddings to the tokenizer; init new rows to the existing mean."""
    old = model.get_input_embeddings().weight.shape[0]
    model.resize_token_embeddings(len(tokenizer))
    if n_added <= 0:
        return
    seen = set()
    for emb in (model.get_input_embeddings(), model.get_output_embeddings()):
        if emb is None or id(emb) in seen:  # tied embeddings: resize once
            continue
        seen.add(id(emb))
        w = emb.weight.data
        w[old:] = w[:old].mean(dim=0, keepdim=True)


def token_ids(tokenizer) -> dict[str, int]:
    """Map each trace token to its single id (asserts single-token encoding)."""
    ids = {}
    for t in TRACE_TOKENS:
        enc = tokenizer.encode(t, add_special_tokens=False)
        assert len(enc) == 1, f"{t!r} did not encode to a single id: {enc}"
        ids[t] = enc[0]
    return ids