Feature Extraction
Transformers
Safetensors
English
bert
retrieval
constbert
colbert
multi-vector
embedding
custom_code
text-embeddings-inference
Instructions to use pinecone/ConstBERT with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use pinecone/ConstBERT with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="pinecone/ConstBERT", trust_remote_code=True)

# Load model directly
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("pinecone/ConstBERT", trust_remote_code=True)
model = AutoModel.from_pretrained("pinecone/ConstBERT", trust_remote_code=True)
```
- Notebooks
- Google Colab
- Kaggle
| from dataclasses import dataclass | |
| import __main__ | |
| import os | |
| import ujson | |
| from huggingface_hub import hf_hub_download | |
| import dataclasses | |
| import datetime | |
| from typing import Any | |
| from dataclasses import dataclass, fields | |
| import socket | |
| import git | |
| import time | |
| import torch | |
| import sys | |
def torch_load_dnn(path):
    """
    Load a serialized checkpoint onto the CPU (used for the deprecated ``*.dnn`` format).

    ``path`` may be an ``http:``/``https:`` URL, fetched with
    ``torch.hub.load_state_dict_from_url``, or a local file read with ``torch.load``.
    Either way the result is mapped to CPU memory.
    """
    # NOTE(review): both loaders unpickle the file contents -- only pass trusted checkpoints.
    is_remote = path.startswith(("http:", "https:"))
    loader = torch.hub.load_state_dict_from_url if is_remote else torch.load
    return loader(path, map_location='cpu')
class dotdict(dict):
    """
    dict subclass whose keys are also reachable as attributes (``d.key``).

    Missing keys raise AttributeError on attribute access/deletion (not KeyError),
    so ``hasattr``, ``getattr(d, name, default)``, ``copy.deepcopy`` and pickling
    behave as they do for ordinary objects.

    Credit: derek73 @ https://stackoverflow.com/questions/2352181
    """

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails, so dict methods still win.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    __setattr__ = dict.__setitem__

    def __delattr__(self, key):
        try:
            del self[key]
        except KeyError:
            raise AttributeError(key) from None
def get_metadata_only():
    """
    Collect provenance metadata for the current process.

    Returns:
        dotdict with ``hostname``, ``current_datetime`` and ``cmd``. When the
        working directory is inside a git repository it also carries
        ``git_hash`` and ``git_commit_datetime``, plus ``git_branch`` unless
        HEAD is detached.
    """
    args = dotdict()

    args.hostname = socket.gethostname()

    try:
        # Discover the repository once instead of re-walking parent directories per field.
        repo = git.Repo(search_parent_directories=True)
    except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError):
        repo = None  # not inside a git checkout: the git fields are simply omitted

    if repo is not None:
        # Closing the Repo releases any helper git processes/handles it may have opened.
        with repo:
            try:
                args.git_branch = repo.active_branch.name
            except TypeError:
                # GitPython raises TypeError for a detached HEAD (e.g. CI checkouts);
                # the commit is still well defined, so only the branch is skipped.
                pass
            commit = repo.head.object
            args.git_hash = commit.hexsha
            args.git_commit_datetime = str(commit.committed_datetime)

    # NOTE(review): '%l' is a glibc strftime extension and may not be supported on every platform.
    args.current_datetime = time.strftime('%b %d, %Y ; %l:%M%p %Z (%z)')
    args.cmd = ' '.join(sys.argv)

    return args
def timestamp(daydir=False):
    """
    Return the current local time as a string.

    ``daydir=False`` -> ``YYYY-MM-DD_HH.MM.SS``
    ``daydir=True``  -> ``YYYY-MM/DD/HH.MM.SS`` (day and time become path components)
    """
    if daydir:
        day_sep, time_sep = '/', '/'
    else:
        day_sep, time_sep = '-', '_'
    now = datetime.datetime.now()
    return now.strftime(f"%Y-%m{day_sep}%d{time_sep}%H.%M.%S")
@dataclass
class DefaultVal:
    """
    Marker wrapping a field's default value, so a config can tell "left at the
    default" apart from "explicitly assigned" (see ``CoreConfig.__post_init__``).
    """

    val: Any

    def __hash__(self):
        # Hash via repr so unhashable defaults (e.g. lists) still work; a hashable
        # default is also what lets dataclasses accept this object as a field default.
        return hash(repr(self.val))

    def __eq__(self, other):
        # The previous version evaluated the comparison but never returned it.
        if not isinstance(other, DefaultVal):
            return NotImplemented
        return self.val == other.val
@dataclass
class RunSettings:
    """
    The defaults here have a special status in Run(), which initially calls assign_defaults(),
    so these aren't soft defaults in that specific context.
    """

    overwrite: bool = DefaultVal(False)

    root: str = DefaultVal(os.path.join(os.getcwd(), 'experiments'))
    experiment: str = DefaultVal('default')

    index_root: str = DefaultVal(None)
    # Evaluated once when the class body runs (import time), so every config
    # created in this process shares the same default name.
    name: str = DefaultVal(timestamp(daydir=True))

    rank: int = DefaultVal(0)
    nranks: int = DefaultVal(1)
    amp: bool = DefaultVal(True)

    # Un-annotated class attribute (not a dataclass field), queried once at import time.
    total_visible_gpus = torch.cuda.device_count()
    gpus: int = DefaultVal(total_visible_gpus)

    avoid_fork_if_possible: bool = DefaultVal(False)

    @property
    def gpus_(self):
        """
        Normalize ``gpus`` to a sorted, de-duplicated list of device ids.

        Accepts a count (``int`` -> ``range(count)``), a comma-separated string,
        or an iterable of ids. Asserts every id is below ``total_visible_gpus``.
        """
        value = self.gpus

        if isinstance(value, int):
            value = list(range(value))

        if isinstance(value, str):
            value = value.split(',')

        value = sorted(set(map(int, value)))

        assert all(device_idx in range(0, self.total_visible_gpus) for device_idx in value), value

        return value

    @property
    def index_root_(self):
        """Explicit ``index_root`` if set, else ``<root>/<experiment>/indexes/``."""
        return self.index_root or os.path.join(self.root, self.experiment, 'indexes/')

    @property
    def script_name_(self):
        """
        Dotted, module-style name of the ``__main__`` script (without ``.py``).

        The path is made relative to the cwd when the script lives under it,
        otherwise relative to its common prefix with ``root``. Returns ``'none'``
        when ``__main__`` has no ``__file__`` (e.g. an interactive session).
        """
        if '__file__' in dir(__main__):
            cwd = os.path.abspath(os.getcwd())
            script_path = os.path.abspath(__main__.__file__)
            root_path = os.path.abspath(self.root)

            if script_path.startswith(cwd):
                script_path = script_path[len(cwd):]

            else:
                try:
                    commonpath = os.path.commonpath([script_path, root_path])
                    script_path = script_path[len(commonpath):]
                except ValueError:
                    # commonpath() raises ValueError when the paths cannot share a
                    # prefix (e.g. different drives on Windows); keep the full path.
                    pass

            assert script_path.endswith('.py')
            script_name = script_path.replace('/', '.').strip('.')[:-3]

            assert len(script_name) > 0, (script_name, script_path, cwd)

            return script_name

        return 'none'

    @property
    def path_(self):
        """Run directory: ``<root>/<experiment>/<script_name_>/<name>``."""
        return os.path.join(self.root, self.experiment, self.script_name_, self.name)

    @property
    def device_(self):
        """Device id for this rank, chosen round-robin from ``gpus_``."""
        return self.gpus_[self.rank % self.nranks]
@dataclass
class TokenizerSettings:
    """Query/document marker-token settings (defaults: ``[unused0]``/``[unused1]`` and ``[Q]``/``[D]``)."""

    query_token_id: str = DefaultVal("[unused0]")
    doc_token_id: str = DefaultVal("[unused1]")
    query_token: str = DefaultVal("[Q]")
    doc_token: str = DefaultVal("[D]")
@dataclass
class ResourceSettings:
    """Resource locations/identifiers: checkpoint, triples, collection, queries and index name."""

    # BaseConfig.load_from_checkpoint sets this to the resolved checkpoint path.
    checkpoint: str = DefaultVal(None)
    triples: str = DefaultVal(None)
    collection: str = DefaultVal(None)
    queries: str = DefaultVal(None)
    index_name: str = DefaultVal(None)
    # BaseConfig.load_from_checkpoint sets this to the path / repo id it was given.
    name_or_path: str = DefaultVal(None)
@dataclass
class DocSettings:
    """Document-side settings (embedding dim, max document length, punctuation masking)."""

    dim: int = DefaultVal(128)
    doc_maxlen: int = DefaultVal(220)
    mask_punctuation: bool = DefaultVal(True)
@dataclass
class QuerySettings:
    """Query-side settings (max query length, mask-token attention, interaction type)."""

    query_maxlen: int = DefaultVal(32)
    attend_to_mask_tokens: bool = DefaultVal(False)
    interaction: str = DefaultVal('colbert')
@dataclass
class TrainingSettings:
    """Training hyper-parameters."""

    similarity: str = DefaultVal('cosine')

    bsize: int = DefaultVal(32)

    accumsteps: int = DefaultVal(1)

    lr: float = DefaultVal(3e-06)

    maxsteps: int = DefaultVal(500_000)

    save_every: int = DefaultVal(None)

    # Also declared on IndexingSettings with the same default.
    resume: bool = DefaultVal(False)

    # NEW:
    warmup: int = DefaultVal(None)

    warmup_bert: int = DefaultVal(None)

    relu: bool = DefaultVal(False)

    nway: int = DefaultVal(2)

    use_ib_negatives: bool = DefaultVal(False)

    reranker: bool = DefaultVal(False)

    distillation_alpha: float = DefaultVal(1.0)

    ignore_scores: bool = DefaultVal(False)

    # Formerly DefaultVal('bert-base-uncased').
    model_name: str = DefaultVal(None)
@dataclass
class IndexingSettings:
    """Index construction settings."""

    index_path: str = DefaultVal(None)

    index_bsize: int = DefaultVal(64)

    nbits: int = DefaultVal(1)

    kmeans_niters: int = DefaultVal(4)

    # Also declared on TrainingSettings with the same default.
    resume: bool = DefaultVal(False)

    @property
    def index_path_(self):
        """
        Explicit ``index_path`` if set, else ``<index_root_>/<index_name>``.

        ``index_root_`` and ``index_name`` are not defined on this mixin; they
        come from RunSettings and ResourceSettings once combined into ColBERTConfig.
        """
        return self.index_path or os.path.join(self.index_root_, self.index_name)
@dataclass
class SearchSettings:
    """Search-time settings."""

    ncells: int = DefaultVal(None)
    centroid_score_threshold: float = DefaultVal(None)
    ndocs: int = DefaultVal(None)
    load_index_with_mmap: bool = DefaultVal(False)
@dataclass
class CoreConfig:
    """
    Shared behaviour for dataclass configs whose fields default to ``DefaultVal``.

    ``self.assigned`` maps field names to True for every field that was set
    explicitly (via the constructor, ``set``/``configure`` or ``assign_defaults``).
    """

    def __post_init__(self):
        """
        Unwrap ``DefaultVal`` placeholders and record explicitly assigned fields.

        Source: https://stackoverflow.com/a/58081120/1493011
        """
        self.assigned = {}

        for field in fields(self):
            field_val = getattr(self, field.name)

            # A field still holding its DefaultVal wrapper -- or an explicit None --
            # takes the wrapped default value. Assumes every field's default is a DefaultVal.
            if isinstance(field_val, DefaultVal) or field_val is None:
                setattr(self, field.name, field.default.val)

            # Anything that is not the DefaultVal wrapper (including an explicit None)
            # counts as caller-assigned.
            if not isinstance(field_val, DefaultVal):
                self.assigned[field.name] = True

    def assign_defaults(self):
        """Reset every field to its DefaultVal default and mark it as assigned."""
        for field in fields(self):
            setattr(self, field.name, field.default.val)
            self.assigned[field.name] = True

    def configure(self, ignore_unrecognized=True, **kw_args):
        """
        Apply each keyword argument through ``set``.

        Returns:
            The set of keys that were not recognized (and therefore not applied).
        """
        # TODO: Take a config object, not kw_args.
        ignored = set()

        for key, value in kw_args.items():
            if not self.set(key, value, ignore_unrecognized):
                ignored.add(key)

        return ignored

    def set(self, key, value, ignore_unrecognized=False):
        """
        Set attribute ``key`` to ``value`` and mark it as assigned.

        Returns True when applied and False when ``key`` is not an attribute and
        ``ignore_unrecognized`` is true; otherwise an unknown key raises.
        """
        if hasattr(self, key):
            setattr(self, key, value)
            self.assigned[key] = True
            return True

        if not ignore_unrecognized:
            raise Exception(f"Unrecognized key `{key}` for {type(self)}")

        return False

    def help(self):
        """Print the exported config as indented JSON."""
        print(ujson.dumps(self.export(), indent=4))

    def __export_value(self, v):
        # Objects exposing provenance() are exported via that; lists/dicts with more
        # than 100 entries are summarized to keep the exported JSON small.
        v = v.provenance() if hasattr(v, 'provenance') else v

        if isinstance(v, list) and len(v) > 100:
            v = (f"list with {len(v)} elements starting with...", v[:3])

        if isinstance(v, dict) and len(v) > 100:
            v = (f"dict with {len(v)} keys starting with...", list(v.keys())[:3])

        return v

    def export(self):
        """Return the dataclass fields as a dict, with large values summarized."""
        return {k: self.__export_value(v) for k, v in dataclasses.asdict(self).items()}
@dataclass
class BaseConfig(CoreConfig):
    """Alternate constructors (merge, legacy args, JSON file, checkpoint, index) and JSON persistence."""

    @classmethod
    def from_existing(cls, *sources):
        """
        Build a new config from the explicitly assigned fields of ``sources``.

        ``None`` sources are skipped; when several sources assign the same
        field, the later source wins.
        """
        kw_args = {}

        for source in sources:
            if source is None:
                continue

            local_kw_args = dataclasses.asdict(source)
            kw_args.update({k: local_kw_args[k] for k in source.assigned})

        return cls(**kw_args)

    @classmethod
    def from_deprecated_args(cls, args):
        """
        Build a default config and apply the recognized keys of mapping ``args``.

        Returns:
            ``(config, ignored)`` where ``ignored`` is the set of unrecognized keys.
        """
        obj = cls()
        ignored = obj.configure(ignore_unrecognized=True, **args)

        return obj, ignored

    @classmethod
    def from_path(cls, name):
        """
        Load a config from the JSON file ``name``.

        If the JSON has a top-level ``"config"`` key, only that sub-object is used.
        Returns the same ``(config, ignored)`` pair as ``from_deprecated_args``.
        """
        with open(name) as f:
            args = ujson.load(f)

        if "config" in args:
            args = args["config"]

        # The new, non-deprecated version functions the same at this level.
        return cls.from_deprecated_args(args)

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path):
        """
        Load the config stored alongside a checkpoint.

        * ``*.dnn`` (deprecated format): read the ``"arguments"`` entry of the saved object.
        * otherwise: try ``checkpoint_path`` as a Hugging Face Hub repo id and download
          its ``artifact.metadata``; if that fails, treat it as a local directory.

        Returns:
            The loaded config, or None when no ``artifact.metadata`` is found
            (e.g. for a plain model name such as ``'bert-base-uncased'``).
        """
        if checkpoint_path.endswith(".dnn"):
            dnn = torch_load_dnn(checkpoint_path)
            config, _ = cls.from_deprecated_args(dnn.get("arguments", {}))

            # TODO: FIXME: Decide if the line below will have any unintended consequences. We don't want to overwrite those!
            config.set("checkpoint", checkpoint_path)

            return config

        name_or_path = checkpoint_path

        try:
            downloaded = hf_hub_download(
                repo_id=checkpoint_path, filename="artifact.metadata"
            )
            # Keep the containing directory, with a trailing separator. Splitting the
            # path on "artifact" would cut it short if the cache or repo path itself
            # contained that substring.
            checkpoint_path = os.path.join(os.path.dirname(downloaded), "")
        except Exception:
            # Deliberate best effort: on any Hub failure (not a repo id, no network, ...)
            # fall through and treat checkpoint_path as a local directory.
            pass

        loaded_config_path = os.path.join(checkpoint_path, "artifact.metadata")

        if os.path.exists(loaded_config_path):
            loaded_config, _ = cls.from_path(loaded_config_path)
            loaded_config.set("checkpoint", checkpoint_path)
            loaded_config.set("name_or_path", name_or_path)

            return loaded_config

        return None  # can happen if checkpoint_path is something like 'bert-base-uncased'

    @classmethod
    def load_from_index(cls, index_path):
        """Load the config stored with an index: ``metadata.json``, falling back to ``plan.json``."""
        # FIXME: We should start here with initial_config = ColBERTConfig(config, Run().config).
        # This should allow us to say initial_config.index_root. Then, below, set config = Config(..., initial_c)
        # default_index_root = os.path.join(Run().root, Run().experiment, 'indexes/')
        # index_path = os.path.join(default_index_root, index_path)

        # CONSIDER: No more plan/metadata.json. Only metadata.json to avoid weird issues when loading an index.

        try:
            metadata_path = os.path.join(index_path, "metadata.json")
            loaded_config, _ = cls.from_path(metadata_path)
        except (OSError, ValueError):
            # metadata.json missing/unreadable (OSError) or not valid JSON (ValueError):
            # fall back to plan.json. Other errors propagate instead of being masked.
            metadata_path = os.path.join(index_path, "plan.json")
            loaded_config, _ = cls.from_path(metadata_path)

        return loaded_config

    def save(self, path, overwrite=False):
        """
        Write this config plus run metadata to ``path`` as indented JSON.

        Asserts that ``path`` does not already exist unless ``overwrite`` is true.
        """
        assert overwrite or not os.path.exists(path), path

        # Build the whole payload before opening the file: open(path, "w") truncates
        # it, so a failure in export()/get_metadata_only() would otherwise leave an
        # empty file behind (destroying an existing one when overwrite=True).
        args = self.export()
        args["meta"] = get_metadata_only()
        args["meta"]["version"] = "colbert-v0.4"
        # TODO: Add git_status details.. It can't be too large! It should be a path that Runs() saves on exit, maybe!
        payload = ujson.dumps(args, indent=4) + "\n"

        with open(path, "w") as f:
            f.write(payload)

    def save_for_checkpoint(self, checkpoint_path):
        """Write ``artifact.metadata`` into the checkpoint directory, overwriting any existing one."""
        assert not checkpoint_path.endswith(
            ".dnn"
        ), f"{checkpoint_path}: We reserve *.dnn names for the deprecated checkpoint format."

        output_config_path = os.path.join(checkpoint_path, "artifact.metadata")
        self.save(output_config_path, overwrite=True)
@dataclass
class ColBERTConfig(RunSettings, ResourceSettings, DocSettings, QuerySettings, TrainingSettings,
                    IndexingSettings, SearchSettings, BaseConfig, TokenizerSettings):
    """
    Complete configuration: the union of all settings mixins' fields plus the
    constructors and persistence helpers of ``BaseConfig``.
    """