# ArmBench-TextEmbed / model_handler.py
import json
import os
import struct
from typing import Dict, List
import pandas as pd
import requests
from huggingface_hub import HfApi, hf_hub_download
# Required metrics for embedding evaluation
REQUIRED_METRICS = [
"mteb_avg",
"sts_spearman",
"retrieval_top20",
"msmarco_top10",
]
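# For reference, a per-model results.json is expected to look roughly like the sketch
# below. This is inferred from the parsing logic in ModelHandler further down; only the
# keys actually read there are listed, and all values are illustrative placeholders:
#
# {
#     "mteb_avg": 0.0,
#     "sts_spearman": 0.0,
#     "retrieval_top20": 0.0,
#     "retrieval_translit_top20": 0.0,
#     "msmarco_top10": 0.0,
#     "msmarco_translit_top10": 0.0,
#     "mteb_detailed": {"...": 0.0},
#     "msmarco_detailed": {"...": 0.0},
#     "sts_detailed": {"...": 0.0},
#     "retrieval_detailed": {"...": 0.0},
#     "retrieval_translit_detailed": {"...": 0.0},
#     "msmarco_translit_detailed": {"...": 0.0},
#     "model_url": "https://huggingface.co/<org>/<model>",   # optional
#     "model_size": "109M"                                   # optional
# }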
def format_params(num_params):
"""Format parameter count as human-readable string."""
if num_params >= 1e9:
return f"{num_params / 1e9:.1f}B"
else:
return f"{num_params / 1e6:.0f}M"
def get_model_url(model_name):
"""Get the model URL from HuggingFace."""
return f"https://huggingface.co/{model_name}"
def get_model_size(model_name):
"""Fetch model size from HuggingFace API."""
try:
url = f"https://huggingface.co/api/models/{model_name}"
response = requests.get(url, timeout=10)
if response.status_code == 200:
data = response.json()
            # Prefer the parameter count reported under safetensors, then fall back to other fields
safetensors = data.get("safetensors")
if safetensors and "total" in safetensors:
num_params = safetensors["total"]
return format_params(num_params)
num_params = data.get("num_parameters")
if num_params:
return format_params(num_params)
# Fallback: read actual param count from safetensors header
num_params = get_params_from_safetensors(model_name)
if num_params:
return format_params(num_params)
return None
except Exception as e:
print(f"Error fetching size for {model_name}: {e}")
return None
def get_params_from_safetensors(model_name):
"""Read safetensors header to get actual parameter count."""
try:
tree_url = f"https://huggingface.co/api/models/{model_name}/tree/main"
resp = requests.get(tree_url, timeout=10)
if resp.status_code != 200:
return None
files = resp.json()
safetensor_files = [f for f in files if f.get("path", "").endswith(".safetensors")]
if not safetensor_files:
return None
total_params = 0
for sf in safetensor_files:
file_url = f"https://huggingface.co/{model_name}/resolve/main/{sf['path']}"
# Get header size (first 8 bytes)
headers = {"Range": "bytes=0-7"}
resp = requests.get(file_url, headers=headers, timeout=10, allow_redirects=True)
if resp.status_code != 206 or len(resp.content) < 8:
                return None  # Range request failed (e.g. gated repo, or the Range header was not honored)
header_size = struct.unpack("<Q", resp.content[:8])[0]
# Get header JSON
headers = {"Range": f"bytes=8-{8 + header_size - 1}"}
resp = requests.get(file_url, headers=headers, timeout=10, allow_redirects=True)
metadata = resp.json()
# Calculate params from tensor shapes
for key, info in metadata.items():
if key == "__metadata__":
continue
shape = info.get("shape", [])
params = 1
for dim in shape:
params *= dim
total_params += params
return total_params
except Exception:
return None
class ModelHandler:
def __init__(self, model_infos_path="model_results.json"):
self.api = HfApi()
self.model_infos_path = model_infos_path
self.model_infos = self._load_model_infos()
    def _load_model_infos(self) -> List[Dict]:
if os.path.exists(self.model_infos_path):
with open(self.model_infos_path) as f:
return json.load(f)
return []
def _save_model_infos(self):
print("Saving model infos")
with open(self.model_infos_path, "w") as f:
json.dump(self.model_infos, f, indent=4)
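    # The local cache (model_results.json by default) is a list of entries shaped roughly
    # like the sketch below; values are illustrative:
    #
    # [
    #     {
    #         "model_name": "<org>/<model>",
    #         "results": { ...contents of the repo's results.json... },
    #         "model_url": "https://huggingface.co/<org>/<model>",
    #         "model_size": "109M"        # only present when it could be determined
    #     }
    # ]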
def get_embedding_benchmark_data(self) -> pd.DataFrame:
"""Fetch embedding benchmark results from HuggingFace models with ArmBench-TextEmbed tag."""
# Try to fetch new models from HuggingFace, but gracefully handle network errors
try:
models = self.api.list_models(filter="ArmBench-TextEmbed")
model_names = {model["model_name"] for model in self.model_infos}
repositories = [model.modelId for model in models]
for repo_id in repositories:
try:
files = [f for f in self.api.list_repo_files(repo_id) if f == "results.json"]
if not files:
continue
model_name = repo_id
if model_name not in model_names:
result_path = hf_hub_download(repo_id, filename="results.json")
with open(result_path) as f:
results = json.load(f)
# Build model entry with metadata
entry = {
"model_name": model_name,
"results": results
}
                        # Prefer metadata supplied in results.json; otherwise derive it here
                        entry["model_url"] = results.get("model_url") or get_model_url(model_name)
                        model_size = results.get("model_size") or get_model_size(model_name)
                        if model_size:
                            entry["model_size"] = model_size
self.model_infos.append(entry)
except Exception as e:
print(f"Error loading {repo_id} - {e}")
continue
self._save_model_infos()
except Exception as e:
print(f"Failed to fetch from HuggingFace: {e}. Using local data.")
# Build dataframe from results
data = []
for model in self.model_infos:
model_name = model["model_name"]
results = model.get("results", {})
row = {"model_name": model_name}
# Extract model metadata
if "model_url" in model:
row["model_url"] = model["model_url"]
if "model_size" in model:
row["model_size"] = model["model_size"]
            # Extract key metrics
            metric_keys = (
                "mteb_avg",
                "sts_spearman",
                "retrieval_top20",
                "retrieval_translit_top20",
                "msmarco_top10",
                "msmarco_translit_top10",
            )
            for key in metric_keys:
                if key in results:
                    row[key] = results[key]
            # Only keep the row if it has at least one metric (not just metadata)
            if any(key in row for key in metric_keys):
                data.append(row)
return pd.DataFrame(data)
def get_detailed_results(self) -> Dict:
"""Get all detailed results for MTEB, MS MARCO, STS, Retrieval, and translit benchmarks."""
mteb_data = []
msmarco_data = []
sts_data = []
retrieval_data = []
retrieval_translit_data = []
msmarco_translit_data = []
for model in self.model_infos:
model_name = model["model_name"]
results = model.get("results", {})
# MTEB detailed
if "mteb_detailed" in results:
row = {"model_name": model_name, **results["mteb_detailed"]}
mteb_data.append(row)
# MS MARCO detailed
if "msmarco_detailed" in results:
row = {"model_name": model_name, **results["msmarco_detailed"]}
msmarco_data.append(row)
# STS detailed
if "sts_detailed" in results:
row = {"model_name": model_name, **results["sts_detailed"]}
sts_data.append(row)
# Retrieval detailed
if "retrieval_detailed" in results:
row = {"model_name": model_name, **results["retrieval_detailed"]}
retrieval_data.append(row)
# Retrieval translit detailed
if "retrieval_translit_detailed" in results:
row = {"model_name": model_name, **results["retrieval_translit_detailed"]}
retrieval_translit_data.append(row)
# MS MARCO translit detailed
if "msmarco_translit_detailed" in results:
row = {"model_name": model_name, **results["msmarco_translit_detailed"]}
msmarco_translit_data.append(row)
return {
"mteb": pd.DataFrame(mteb_data) if mteb_data else pd.DataFrame(),
"msmarco": pd.DataFrame(msmarco_data) if msmarco_data else pd.DataFrame(),
"sts": pd.DataFrame(sts_data) if sts_data else pd.DataFrame(),
"retrieval": pd.DataFrame(retrieval_data) if retrieval_data else pd.DataFrame(),
"retrieval_translit": pd.DataFrame(retrieval_translit_data) if retrieval_translit_data else pd.DataFrame(),
"msmarco_translit": pd.DataFrame(msmarco_translit_data) if msmarco_translit_data else pd.DataFrame(),
}
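
# Minimal usage sketch (assumes either a local model_results.json cache or network access
# to the Hub; "model_results.json" is simply the default path used above):
if __name__ == "__main__":
    handler = ModelHandler()
    # Leaderboard table: one row per model with its headline metrics.
    leaderboard_df = handler.get_embedding_benchmark_data()
    print(leaderboard_df)
    # Per-benchmark detail tables, keyed by benchmark name.
    for name, df in handler.get_detailed_results().items():
        print(name, df.shape)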