Spaces:
Running
Running
| import json | |
| import os | |
| import struct | |
| from typing import Dict, List | |
| import pandas as pd | |
| import requests | |
| from huggingface_hub import HfApi, hf_hub_download | |
# Required metrics for embedding evaluation.
# NOTE(review): not referenced anywhere in this file's visible code —
# presumably consumed by a validation step elsewhere; confirm before removing.
REQUIRED_METRICS = [
    "mteb_avg",
    "sts_spearman",
    "retrieval_top20",
    "msmarco_top10",
]
def format_params(num_params):
    """Render a raw parameter count as a compact string, e.g. '350M' or '7.2B'."""
    if num_params < 1e9:
        # Below one billion: report whole millions.
        return f"{num_params / 1e6:.0f}M"
    # One billion and up: report billions with one decimal place.
    return f"{num_params / 1e9:.1f}B"
def get_model_url(model_name):
    """Build the public HuggingFace page URL for *model_name*."""
    base = "https://huggingface.co"
    return f"{base}/{model_name}"
def get_model_size(model_name):
    """Look up a model's parameter count via the HuggingFace API.

    Returns a human-readable size string (e.g. '7.2B') or None when the
    API call fails or no parameter count can be determined.
    """
    try:
        api_url = f"https://huggingface.co/api/models/{model_name}"
        resp = requests.get(api_url, timeout=10)
        if resp.status_code != 200:
            return None
        info = resp.json()
        # Preferred source: the safetensors total reported by the API.
        st = info.get("safetensors")
        if st and "total" in st:
            return format_params(st["total"])
        # Next: the generic parameter field; finally, read the actual
        # safetensors headers over HTTP as a last resort.
        count = info.get("num_parameters") or get_params_from_safetensors(model_name)
        return format_params(count) if count else None
    except Exception as e:
        print(f"Error fetching size for {model_name}: {e}")
        return None
def get_params_from_safetensors(model_name):
    """Sum tensor element counts by reading safetensors headers via HTTP ranges.

    Returns the total parameter count across all .safetensors files in the
    repo, or None when the repo is unreachable, has no safetensors files, or
    range requests fail (e.g. gated models).
    """
    try:
        listing = requests.get(
            f"https://huggingface.co/api/models/{model_name}/tree/main", timeout=10
        )
        if listing.status_code != 200:
            return None
        paths = [
            entry["path"]
            for entry in listing.json()
            if entry.get("path", "").endswith(".safetensors")
        ]
        if not paths:
            return None
        total = 0
        for path in paths:
            file_url = f"https://huggingface.co/{model_name}/resolve/main/{path}"
            # The first 8 bytes hold a little-endian u64: the JSON header length.
            head = requests.get(
                file_url, headers={"Range": "bytes=0-7"}, timeout=10, allow_redirects=True
            )
            if head.status_code != 206 or len(head.content) < 8:
                return None  # Likely gated model
            (header_len,) = struct.unpack("<Q", head.content[:8])
            # Fetch the JSON header itself, which maps tensor name -> shape/dtype.
            body = requests.get(
                file_url,
                headers={"Range": f"bytes=8-{8 + header_len - 1}"},
                timeout=10,
                allow_redirects=True,
            )
            for tensor_name, tensor_info in body.json().items():
                if tensor_name == "__metadata__":
                    continue
                count = 1
                for dim in tensor_info.get("shape", []):
                    count *= dim
                total += count
        return total
    except Exception:
        return None
class ModelHandler:
    """Load, cache, and tabulate embedding-benchmark results for models
    tagged ``ArmBench-TextEmbed`` on the HuggingFace Hub.

    Results are cached in a local JSON file (``model_infos_path``) so the
    leaderboard still renders when the Hub is unreachable.
    """

    # Summary metric keys surfaced in the leaderboard dataframe.
    _METRIC_KEYS = [
        "mteb_avg",
        "sts_spearman",
        "retrieval_top20",
        "retrieval_translit_top20",
        "msmarco_top10",
        "msmarco_translit_top10",
    ]
    # Output table name -> key of the per-benchmark detail dict in results.
    _DETAILED_KEYS = {
        "mteb": "mteb_detailed",
        "msmarco": "msmarco_detailed",
        "sts": "sts_detailed",
        "retrieval": "retrieval_detailed",
        "retrieval_translit": "retrieval_translit_detailed",
        "msmarco_translit": "msmarco_translit_detailed",
    }

    def __init__(self, model_infos_path="model_results.json"):
        self.api = HfApi()
        self.model_infos_path = model_infos_path
        self.model_infos = self._load_model_infos()

    def _load_model_infos(self) -> List:
        """Return the cached model list from disk, or [] when no cache exists."""
        if os.path.exists(self.model_infos_path):
            with open(self.model_infos_path) as f:
                return json.load(f)
        return []

    def _save_model_infos(self):
        """Persist the in-memory model list to the cache file."""
        print("Saving model infos")
        with open(self.model_infos_path, "w") as f:
            json.dump(self.model_infos, f, indent=4)

    def get_embedding_benchmark_data(self) -> pd.DataFrame:
        """Fetch embedding benchmark results from HuggingFace models with ArmBench-TextEmbed tag.

        Returns one row per model with its metadata (url/size) and any of the
        summary metrics present in its results.json.
        """
        # Try to fetch new models from HuggingFace, but gracefully handle network errors.
        try:
            models = self.api.list_models(filter="ArmBench-TextEmbed")
            known = {model["model_name"] for model in self.model_infos}
            for repo_id in [m.modelId for m in models]:
                try:
                    # Skip already-cached repos before touching the network again.
                    if repo_id in known:
                        continue
                    if "results.json" not in self.api.list_repo_files(repo_id):
                        continue
                    result_path = hf_hub_download(repo_id, filename="results.json")
                    with open(result_path) as f:
                        results = json.load(f)
                    entry = {"model_name": repo_id, "results": results}
                    # Only fetch metadata the results file doesn't already carry.
                    if "model_url" not in results:
                        entry["model_url"] = get_model_url(repo_id)
                    if "model_size" not in results:
                        model_size = get_model_size(repo_id)
                        if model_size:
                            entry["model_size"] = model_size
                    self.model_infos.append(entry)
                    # Guard against duplicate repo ids in the listing.
                    known.add(repo_id)
                except Exception as e:
                    print(f"Error loading {repo_id} - {e}")
                    continue
            self._save_model_infos()
        except Exception as e:
            print(f"Failed to fetch from HuggingFace: {e}. Using local data.")
        # Build one dataframe row per cached model.
        data = []
        for model in self.model_infos:
            results = model.get("results", {})
            row = {"model_name": model["model_name"]}
            # BUGFIX: metadata can live either on the entry or inside results
            # (the fetch loop above stores it on the entry only when results
            # lacks it). Previously only entry-level keys were read, so
            # results-held model_url/model_size were silently dropped.
            for meta in ("model_url", "model_size"):
                value = model.get(meta, results.get(meta))
                if value is not None:
                    row[meta] = value
            for key in self._METRIC_KEYS:
                if key in results:
                    row[key] = results[key]
            # Only keep models that contribute something beyond their name.
            if len(row) > 1:
                data.append(row)
        return pd.DataFrame(data)

    def get_detailed_results(self) -> Dict:
        """Get all detailed results for MTEB, MS MARCO, STS, Retrieval, and translit benchmarks.

        Returns a dict mapping benchmark name -> DataFrame (empty DataFrame
        when no model reports that benchmark).
        """
        tables = {name: [] for name in self._DETAILED_KEYS}
        for model in self.model_infos:
            results = model.get("results", {})
            for name, key in self._DETAILED_KEYS.items():
                if key in results:
                    tables[name].append({"model_name": model["model_name"], **results[key]})
        return {
            name: pd.DataFrame(rows) if rows else pd.DataFrame()
            for name, rows in tables.items()
        }