Upload ruler_tsne.py with huggingface_hub
Browse files- ruler_tsne.py +215 -0
ruler_tsne.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""把输入文本 + ruler 200 条样本一起送进 Qwen3-Embedding-8B,做 t-SNE 可视化。
|
| 3 |
+
|
| 4 |
+
用法:
|
| 5 |
+
# 输入文本直接给字符串
|
| 6 |
+
python ruler_tsne.py --input-text $'user_0: hi\nuser_1: hello'
|
| 7 |
+
|
| 8 |
+
# 输入文本从文件读
|
| 9 |
+
python ruler_tsne.py --input-text /path/to/conv.txt
|
| 10 |
+
|
| 11 |
+
# 自定义路径 / 输出 / 批大小 / 最大长度
|
| 12 |
+
python ruler_tsne.py \
|
| 13 |
+
--input-text /path/to/conv.txt \
|
| 14 |
+
--ruler /mnt/.../ruler_items.json \
|
| 15 |
+
--model /mnt/.../Qwen3-Embedding-8B \
|
| 16 |
+
--output ruler_tsne.png \
|
| 17 |
+
--max-length 4096 \
|
| 18 |
+
--batch-size 4 \
|
| 19 |
+
--perplexity 20
|
| 20 |
+
|
| 21 |
+
输出:
|
| 22 |
+
- ruler_tsne.png -- 二维散点图,200 个 ruler 点按 score 着色(红=严重→绿=不严重),
|
| 23 |
+
每个点上标 rank 编号;输入文本用红色五角星标 INPUT。
|
| 24 |
+
- 控制台同时打印 top-5 最相似的 ruler items(按 cosine 相似度)。
|
| 25 |
+
"""
|
| 26 |
+
import argparse
|
| 27 |
+
import json
|
| 28 |
+
import sys
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
|
| 31 |
+
import numpy as np
|
| 32 |
+
import torch
|
| 33 |
+
import torch.nn.functional as F
|
| 34 |
+
from torch import Tensor
|
| 35 |
+
from transformers import AutoTokenizer, AutoModel
|
| 36 |
+
|
| 37 |
+
import matplotlib
|
| 38 |
+
matplotlib.use("Agg")
|
| 39 |
+
import matplotlib.pyplot as plt
|
| 40 |
+
from sklearn.manifold import TSNE
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Default value of the --model option: location handed to `from_pretrained`
# for both the tokenizer and the embedding model.
DEFAULT_MODEL = "/mnt/bn/tns-algo-ue-my/biaowu/WorkSpace/Models/Qwen3-Embedding-8B"
|
| 44 |
+
# Default value of the --ruler option: JSON file read by `load_ruler_items`.
DEFAULT_RULER = "/mnt/bn/tns-algo-ue-my/biaowu/aipf_dm_metric/ranking_moderation/data/dm/youth_sexual_and_physical_abuse_aigt_v009/ranking_bucket/ruler_items.json"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Pick one hidden-state vector per row: the one at the last attended position.

    Args:
        last_hidden_states: Hidden states indexed as ``[row, position, ...]``
            (assumed to be ``(batch, seq, hidden)`` -- TODO confirm with caller).
        attention_mask: Mask indexed as ``[row, position]``; its entries are
            summed, so it is presumably 1 for real tokens and 0 for padding.

    Returns:
        A tensor holding, for every row, the hidden state at the chosen position.
    """
    row_count = attention_mask.shape[0]
    final_column_total = attention_mask[:, -1].sum()

    if final_column_total != row_count:
        # At least one row ends in padding: look up each row's own last
        # attended index (number of attended tokens minus one).
        last_positions = attention_mask.sum(dim=1) - 1
        row_index = torch.arange(
            last_hidden_states.shape[0], device=last_hidden_states.device
        )
        return last_hidden_states[row_index, last_positions]

    # Every row attends to its final position (batch is left-padded or
    # unpadded), so the last position can be taken for all rows at once.
    return last_hidden_states[:, -1]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@torch.no_grad()
def encode(texts, tokenizer, model, max_length=4096, batch_size=4):
    """Embed ``texts`` in batches and return L2-normalised vectors.

    Args:
        texts: Non-empty sequence of strings to embed.
        tokenizer: Callable invoked as ``tokenizer(batch, padding=True,
            truncation=True, max_length=..., return_tensors="pt")``; its result
            must support ``.to(device)``, ``**`` unpacking and
            ``["attention_mask"]``.
        model: Callable exposing ``.device``; its output must expose
            ``.last_hidden_state``.
        max_length: Truncation length forwarded to the tokenizer.
        batch_size: Number of texts per forward pass (must be >= 1).

    Returns:
        A float32 NumPy array with one row per input text, in input order.

    Raises:
        ValueError: If ``texts`` is empty or ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    total = len(texts)
    if total == 0:
        raise ValueError("texts must contain at least one string")

    chunks = []
    for start in range(0, total, batch_size):
        batch = texts[start:start + batch_size]
        inputs = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(model.device)
        outputs = model(**inputs)
        pooled = last_token_pool(outputs.last_hidden_state, inputs["attention_mask"])
        # Upcast before normalising: when the model runs in half precision,
        # normalising first would leave vectors unit-length only to fp16
        # accuracy, which skews the downstream dot-product "cosine" scores.
        pooled = F.normalize(pooled.float(), p=2, dim=1)
        chunks.append(pooled.cpu())
        # Drop per-batch tensors right away so peak GPU memory stays bounded
        # when batches have very different sequence lengths.
        del outputs, inputs, pooled
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print(f" encoded {min(start + batch_size, total)}/{total}", flush=True)
    return torch.cat(chunks, dim=0).numpy()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def load_ruler_items(path: str):
    """Load ruler entries from a JSON file and flatten them to plain records.

    The file may contain either a list of entries, or an object holding such a
    list under one of the keys ``"items"``, ``"ruler_items"`` or ``"data"``
    (checked in that order).

    Args:
        path: Path of the UTF-8 encoded JSON file.

    Returns:
        A list of dicts with keys ``rank``, ``score``, ``item_id`` (copied from
        the entry, ``None`` when absent) and ``text`` (the entry's
        ``item.conv_text``, else its ``conv_text``, else ``""``).

    Raises:
        ValueError: If the top-level JSON value is not one of the supported
            shapes, or an entry is not a JSON object.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    items = None
    if isinstance(data, list):
        items = data
    elif isinstance(data, dict):
        for key in ("items", "ruler_items", "data"):
            if isinstance(data.get(key), list):
                items = data[key]
                break
    # Scalars and strings also end up here, so callers always get ValueError
    # for an unsupported layout instead of an incidental TypeError.
    if items is None:
        raise ValueError("unexpected ruler json structure")

    out = []
    for position, entry in enumerate(items):
        if not isinstance(entry, dict):
            raise ValueError(
                f"ruler entry {position} is not a JSON object: {type(entry).__name__}"
            )
        inner = entry.get("item")
        if not isinstance(inner, dict):
            inner = {}
        # Prefer the nested conversation text, then the flat one.
        conv = inner.get("conv_text") or entry.get("conv_text") or ""
        out.append({
            "rank": entry.get("rank"),
            "score": entry.get("score"),
            "item_id": entry.get("item_id"),
            "text": conv,
        })
    return out
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def resolve_input(arg: str) -> str:
    """Turn the ``--input-text`` argument into the text to embed.

    Args:
        arg: ``"-"`` to read standard input, the path of an existing regular
            file to read it as UTF-8, or otherwise the literal text itself.

    Returns:
        The stripped stdin or file content, or ``arg`` unchanged when it does
        not name an existing file.
    """
    if arg == "-":
        return sys.stdin.read().strip()

    candidate = Path(arg)
    try:
        is_file = candidate.is_file()
    except (OSError, ValueError):
        # Raw conversation text is usually not a valid path (too long for the
        # filesystem, embedded NUL, ...); on some Python versions the probe
        # raises instead of returning False. Either way it is not a file.
        is_file = False

    if is_file:
        return candidate.read_text(encoding="utf-8").strip()
    return arg
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def parse_args():
    """Define the command-line options and parse them from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with attributes ``input_text``, ``ruler``,
        ``model``, ``output``, ``max_length``, ``batch_size``, ``perplexity``,
        ``label_fontsize``, ``cpu`` and ``no_flash_attn``.
    """
    parser = argparse.ArgumentParser()
    register = parser.add_argument

    # The only mandatory option: what to embed.
    register(
        "--input-text",
        required=True,
        help="原始文本字符串、文件路径,或 '-' 表示从 stdin 读",
    )

    # Plain string options that only carry a default.
    for flag, fallback in (
        ("--ruler", DEFAULT_RULER),
        ("--model", DEFAULT_MODEL),
        ("--output", "ruler_tsne.png"),
    ):
        register(flag, default=fallback)

    # Integer tuning knobs for encoding.
    for flag, fallback in (("--max-length", 4096), ("--batch-size", 4)):
        register(flag, type=int, default=fallback)

    # Plot-related numeric options.
    register("--perplexity", type=float, default=20.0)
    register(
        "--label-fontsize",
        type=float,
        default=5,
        help="rank 编号的字号,太挤就调小",
    )

    # Boolean switches.
    register("--cpu", action="store_true", help="强制走 CPU(不推荐,巨慢)")
    register(
        "--no-flash-attn",
        action="store_true",
        help="不用 flash-attn-2(环境没装就加这个)",
    )

    return parser.parse_args()
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _load_embedding_model(model_path, device, use_flash_attn):
    """Load the embedding model on ``device``, retrying without flash-attn-2 if it is unusable."""
    model_kwargs = {}
    if device == "cuda":
        model_kwargs["torch_dtype"] = torch.float16
        if use_flash_attn:
            model_kwargs["attn_implementation"] = "flash_attention_2"
    try:
        model = AutoModel.from_pretrained(model_path, **model_kwargs)
    except (ImportError, ValueError) as err:
        # The flash-attn failure surfaces while loading the model, not when
        # the option is put into the kwargs, so the fallback has to wrap the
        # load itself. Errors unrelated to flash-attn are re-raised, either
        # here or by the retry below.
        if "attn_implementation" not in model_kwargs:
            raise
        print(f" flash_attention_2 unavailable ({err}); retrying with default attention",
              flush=True)
        del model_kwargs["attn_implementation"]
        model = AutoModel.from_pretrained(model_path, **model_kwargs)
    return model.to(device).eval()


def _plot_projection(xy, ranks, scores, title, output, label_fontsize):
    """Draw the 2-D projection (row 0 = input, remaining rows = ruler items) and save it."""
    input_xy, ruler_xy = xy[0], xy[1:]

    fig, ax = plt.subplots(figsize=(14, 12), dpi=130)
    sc = ax.scatter(
        ruler_xy[:, 0], ruler_xy[:, 1],
        c=scores, cmap="RdYlGn_r",
        s=45, alpha=0.85,
        edgecolor="black", linewidth=0.3,
    )
    cbar = plt.colorbar(sc, ax=ax, shrink=0.8)
    cbar.set_label("ruler score (high = more severe)")

    # Label every ruler point with its rank.
    for (x, y), rank in zip(ruler_xy, ranks):
        ax.annotate(str(rank), (x, y),
                    fontsize=label_fontsize,
                    ha="center", va="center",
                    alpha=0.85)

    # The input point, drawn on top of everything else.
    ax.scatter([input_xy[0]], [input_xy[1]],
               marker="*", s=750, c="red",
               edgecolor="black", linewidth=1.5,
               zorder=10, label="INPUT")
    ax.annotate("INPUT", input_xy,
                fontsize=12, fontweight="bold", color="red",
                xytext=(10, 10), textcoords="offset points")

    ax.set_title(title)
    ax.set_xlabel("t-SNE 1")
    ax.set_ylabel("t-SNE 2")
    ax.legend(loc="best")
    plt.tight_layout()
    plt.savefig(output, dpi=130, bbox_inches="tight")
    plt.close(fig)


def main():
    """Embed the input text together with the ruler items, plot a t-SNE map, print the top-5 matches.

    Steps: load the ruler, load the embedding model, encode all texts, project
    them to 2-D with t-SNE and save the scatter plot to ``--output``, then
    print the five ruler items most similar to the input.

    Raises:
        SystemExit: If the ruler file yields no items.
    """
    args = parse_args()

    # ---- 1) Load the ruler ----
    print(f"[1/4] 读 ruler: {args.ruler}")
    items = load_ruler_items(args.ruler)
    print(f" -> {len(items)} ruler items")
    if not items:
        # Fail before the expensive model load: nothing to compare against.
        raise SystemExit(f"no ruler items found in {args.ruler}")

    input_text = resolve_input(args.input_text)
    print(f" input text length: {len(input_text)} chars")
    # Row 0 is always the input; rows 1.. follow the ruler order.
    texts = [input_text] + [it["text"] for it in items]

    # ---- 2) Load the model ----
    print(f"[2/4] 加载模型: {args.model}")
    device = "cuda" if (not args.cpu and torch.cuda.is_available()) else "cpu"
    print(f" device: {device}")
    tokenizer = AutoTokenizer.from_pretrained(args.model, padding_side="left")
    model = _load_embedding_model(args.model, device, not args.no_flash_attn)

    # ---- 3) Encode ----
    print(f"[3/4] 编码 {len(texts)} 条(batch_size={args.batch_size}, max_length={args.max_length})")
    embeddings = encode(texts, tokenizer, model,
                        max_length=args.max_length, batch_size=args.batch_size)
    print(f" embeddings: {embeddings.shape}")

    # ---- 4) t-SNE + plot ----
    # scikit-learn rejects perplexity >= n_samples, so clamp for small rulers.
    perplexity = min(args.perplexity, float(len(texts) - 1))
    if perplexity < args.perplexity:
        print(f" perplexity {args.perplexity} too large for {len(texts)} samples; "
              f"using {perplexity}")
    print(f"[4/4] t-SNE (perplexity={perplexity}) + 画图")
    tsne = TSNE(n_components=2, perplexity=perplexity,
                init="pca", random_state=42, metric="cosine")
    xy = tsne.fit_transform(embeddings)

    ranks = [it["rank"] for it in items]
    # A missing score (None) becomes NaN instead of breaking the conversion.
    scores = np.array([it["score"] for it in items], dtype=float)

    title = f"t-SNE: input + {len(items)} ruler items ({Path(args.model).name})"
    _plot_projection(xy, ranks, scores, title, args.output, args.label_fontsize)
    print(f" saved: {args.output}")

    # ---- 5) Top-5 most similar ruler items ----
    sims = embeddings[1:] @ embeddings[0]  # cosine, since rows are L2-normalised
    top5 = np.argsort(-sims)[:5]
    print("\nTop-5 nearest ruler items by cosine similarity:")
    for idx in top5:
        it = items[idx]
        # Format through str / the float array so a missing rank or score
        # prints as "None" / "nan" instead of raising.
        print(f" rank={str(it['rank']):>3} score={scores[idx]:.2f} "
              f"sim={sims[idx]:.4f} id={it['item_id']}")
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
# Script entry point: run the pipeline only when the file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|