Wendy-Fly commited on
Commit
37c2c5b
·
verified ·
1 Parent(s): a7d7258

Upload ruler_tsne.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ruler_tsne.py +215 -0
ruler_tsne.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """把输入文本 + ruler 200 条样本一起送进 Qwen3-Embedding-8B,做 t-SNE 可视化。
3
+
4
+ 用法:
5
+ # 输入文本直接给字符串
6
+ python ruler_tsne.py --input-text "user_0: hi\nuser_1: hello"
7
+
8
+ # 输入文本从文件读
9
+ python ruler_tsne.py --input-text /path/to/conv.txt
10
+
11
+ # 自定义路径 / 输出 / 批大小 / 最大长度
12
+ python ruler_tsne.py \
13
+ --input-text /path/to/conv.txt \
14
+ --ruler /mnt/.../ruler_items.json \
15
+ --model /mnt/.../Qwen3-Embedding-8B \
16
+ --output ruler_tsne.png \
17
+ --max-length 4096 \
18
+ --batch-size 4 \
19
+ --perplexity 20
20
+
21
+ 输出:
22
+ - ruler_tsne.png -- 二维散点图,200 个 ruler 点按 score 着色(红=严重→绿=不严重),
23
+ 每个点上标 rank 编号;输入文本用红色五角星标 INPUT。
24
+ - 控制台同时打印 top-5 最相似的 ruler items(按 cosine 相似度)。
25
+ """
26
+ import argparse
27
+ import json
28
+ import sys
29
+ from pathlib import Path
30
+
31
+ import numpy as np
32
+ import torch
33
+ import torch.nn.functional as F
34
+ from torch import Tensor
35
+ from transformers import AutoTokenizer, AutoModel
36
+
37
+ import matplotlib
38
+ matplotlib.use("Agg")
39
+ import matplotlib.pyplot as plt
40
+ from sklearn.manifold import TSNE
41
+
42
+
43
+ DEFAULT_MODEL = "/mnt/bn/tns-algo-ue-my/biaowu/WorkSpace/Models/Qwen3-Embedding-8B"
44
+ DEFAULT_RULER = "/mnt/bn/tns-algo-ue-my/biaowu/aipf_dm_metric/ranking_moderation/data/dm/youth_sexual_and_physical_abuse_aigt_v009/ranking_bucket/ruler_items.json"
45
+
46
+
47
+ def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
48
+ left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
49
+ if left_padding:
50
+ return last_hidden_states[:, -1]
51
+ sequence_lengths = attention_mask.sum(dim=1) - 1
52
+ bsz = last_hidden_states.shape[0]
53
+ return last_hidden_states[torch.arange(bsz, device=last_hidden_states.device), sequence_lengths]
54
+
55
+
56
+ @torch.no_grad()
57
+ def encode(texts, tokenizer, model, max_length=4096, batch_size=4):
58
+ embs = []
59
+ for i in range(0, len(texts), batch_size):
60
+ batch = texts[i:i + batch_size]
61
+ d = tokenizer(batch, padding=True, truncation=True,
62
+ max_length=max_length, return_tensors="pt").to(model.device)
63
+ out = model(**d)
64
+ e = last_token_pool(out.last_hidden_state, d["attention_mask"])
65
+ e = F.normalize(e, p=2, dim=1)
66
+ embs.append(e.cpu().float())
67
+ del out, d, e
68
+ if torch.cuda.is_available():
69
+ torch.cuda.empty_cache()
70
+ print(f" encoded {min(i + batch_size, len(texts))}/{len(texts)}", flush=True)
71
+ return torch.cat(embs, dim=0).numpy()
72
+
73
+
74
+ def load_ruler_items(path: str):
75
+ with open(path, "r", encoding="utf-8") as f:
76
+ data = json.load(f)
77
+ if isinstance(data, list):
78
+ items = data
79
+ else:
80
+ for k in ("items", "ruler_items", "data"):
81
+ if k in data and isinstance(data[k], list):
82
+ items = data[k]
83
+ break
84
+ else:
85
+ raise ValueError("unexpected ruler json structure")
86
+ out = []
87
+ for it in items:
88
+ inner = it.get("item", {}) if isinstance(it.get("item"), dict) else {}
89
+ conv = inner.get("conv_text") or it.get("conv_text") or ""
90
+ out.append({
91
+ "rank": it.get("rank"),
92
+ "score": it.get("score"),
93
+ "item_id": it.get("item_id"),
94
+ "text": conv,
95
+ })
96
+ return out
97
+
98
+
99
+ def resolve_input(arg: str) -> str:
100
+ if arg == "-":
101
+ return sys.stdin.read().strip()
102
+ p = Path(arg)
103
+ if p.exists() and p.is_file():
104
+ return p.read_text(encoding="utf-8").strip()
105
+ return arg
106
+
107
+
108
+ def parse_args():
109
+ p = argparse.ArgumentParser()
110
+ p.add_argument("--input-text", required=True,
111
+ help="原始文本字符串、文件路径,或 '-' 表示从 stdin 读")
112
+ p.add_argument("--ruler", default=DEFAULT_RULER)
113
+ p.add_argument("--model", default=DEFAULT_MODEL)
114
+ p.add_argument("--output", default="ruler_tsne.png")
115
+ p.add_argument("--max-length", type=int, default=4096)
116
+ p.add_argument("--batch-size", type=int, default=4)
117
+ p.add_argument("--perplexity", type=float, default=20.0)
118
+ p.add_argument("--label-fontsize", type=float, default=5,
119
+ help="rank 编号的字号,太挤就调小")
120
+ p.add_argument("--cpu", action="store_true", help="强制走 CPU(不推荐,巨慢)")
121
+ p.add_argument("--no-flash-attn", action="store_true",
122
+ help="不用 flash-attn-2(环境没装就加这个)")
123
+ return p.parse_args()
124
+
125
+
126
+ def main():
127
+ args = parse_args()
128
+
129
+ # ---- 1) 加载尺子 ----
130
+ print(f"[1/4] 读 ruler: {args.ruler}")
131
+ items = load_ruler_items(args.ruler)
132
+ print(f" -> {len(items)} ruler items")
133
+
134
+ input_text = resolve_input(args.input_text)
135
+ print(f" input text length: {len(input_text)} chars")
136
+ texts = [input_text] + [it["text"] for it in items]
137
+
138
+ # ---- 2) 加载模型 ----
139
+ print(f"[2/4] 加载模型: {args.model}")
140
+ device = "cpu" if args.cpu else ("cuda" if torch.cuda.is_available() else "cpu")
141
+ print(f" device: {device}")
142
+ model_kwargs = {}
143
+ if device == "cuda":
144
+ model_kwargs["torch_dtype"] = torch.float16
145
+ if not args.no_flash_attn:
146
+ try:
147
+ model_kwargs["attn_implementation"] = "flash_attention_2"
148
+ except Exception:
149
+ pass
150
+ tokenizer = AutoTokenizer.from_pretrained(args.model, padding_side="left")
151
+ model = AutoModel.from_pretrained(args.model, **model_kwargs).to(device).eval()
152
+
153
+ # ---- 3) 编码 ----
154
+ print(f"[3/4] 编码 {len(texts)} 条(batch_size={args.batch_size}, max_length={args.max_length})")
155
+ embeddings = encode(texts, tokenizer, model,
156
+ max_length=args.max_length, batch_size=args.batch_size)
157
+ print(f" embeddings: {embeddings.shape}")
158
+
159
+ # ---- 4) t-SNE + 画图 ----
160
+ print(f"[4/4] t-SNE (perplexity={args.perplexity}) + 画图")
161
+ tsne = TSNE(n_components=2, perplexity=args.perplexity,
162
+ init="pca", random_state=42, metric="cosine")
163
+ xy = tsne.fit_transform(embeddings)
164
+
165
+ input_xy = xy[0]
166
+ ruler_xy = xy[1:]
167
+ ranks = np.array([it["rank"] for it in items])
168
+ scores = np.array([it["score"] for it in items], dtype=float)
169
+
170
+ fig, ax = plt.subplots(figsize=(14, 12), dpi=130)
171
+ sc = ax.scatter(
172
+ ruler_xy[:, 0], ruler_xy[:, 1],
173
+ c=scores, cmap="RdYlGn_r",
174
+ s=45, alpha=0.85,
175
+ edgecolor="black", linewidth=0.3,
176
+ )
177
+ cbar = plt.colorbar(sc, ax=ax, shrink=0.8)
178
+ cbar.set_label("ruler score (high = more severe)")
179
+
180
+ # 标 rank 编号
181
+ for (x, y), r in zip(ruler_xy, ranks):
182
+ ax.annotate(str(r), (x, y),
183
+ fontsize=args.label_fontsize,
184
+ ha="center", va="center",
185
+ alpha=0.85)
186
+
187
+ # 输入点
188
+ ax.scatter([input_xy[0]], [input_xy[1]],
189
+ marker="*", s=750, c="red",
190
+ edgecolor="black", linewidth=1.5,
191
+ zorder=10, label="INPUT")
192
+ ax.annotate("INPUT", input_xy,
193
+ fontsize=12, fontweight="bold", color="red",
194
+ xytext=(10, 10), textcoords="offset points")
195
+
196
+ ax.set_title("t-SNE: input + 200 ruler items (Qwen3-Embedding-8B)")
197
+ ax.set_xlabel("t-SNE 1")
198
+ ax.set_ylabel("t-SNE 2")
199
+ ax.legend(loc="best")
200
+ plt.tight_layout()
201
+ plt.savefig(args.output, dpi=130, bbox_inches="tight")
202
+ print(f" saved: {args.output}")
203
+
204
+ # ---- 5) Top-5 最相似 ruler items ----
205
+ sims = embeddings[1:] @ embeddings[0] # cosine since L2 normalized
206
+ top5 = np.argsort(-sims)[:5]
207
+ print("\nTop-5 nearest ruler items by cosine similarity:")
208
+ for idx in top5:
209
+ it = items[idx]
210
+ print(f" rank={it['rank']:>3} score={it['score']:.2f} "
211
+ f"sim={sims[idx]:.4f} id={it['item_id']}")
212
+
213
+
214
+ if __name__ == "__main__":
215
+ main()