"""ConSec-style word sense disambiguation with a DeBERTa-v3 backbone.

Each candidate sense gloss is scored against a pooled representation of the
marked target span: the glosses (for the candidates and for already
disambiguated context words) are packed into the second segment of the
tokenizer input, each prefixed by a [GLOSS] marker.
"""

from collections.abc import Generator, Iterable
from dataclasses import dataclass
from enum import StrEnum

import torch
import torch.nn as nn
from nltk.tokenize import TreebankWordDetokenizer
from transformers import (
    AutoConfig,
    AutoModel,
    BatchEncoding,
    DebertaV2Model,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.modeling_outputs import TokenClassifierOutput


class ModelURI(StrEnum):
    BASE = "microsoft/deberta-v3-base"
    LARGE = "microsoft/deberta-v3-large"


class ConSec(PreTrainedModel):
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        if config.init_basemodel:
            # First construction: load pretrained weights and grow the
            # embedding matrix for the added marker tokens.
            self.BaseModel = AutoModel.from_pretrained(
                config.name_or_path, device_map="auto", dtype=torch.bfloat16
            )
            self.config.vocab_size += 2
            self.BaseModel.resize_token_embeddings(self.config.vocab_size)
        else:
            # Reloading a saved checkpoint: build the bare architecture and
            # let from_pretrained fill in the weights.
            self.BaseModel = DebertaV2Model(config)
        config.init_basemodel = False
        self.loss = nn.CrossEntropyLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI) -> "ConSec":
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True
        return cls(config)

    def add_special_tokens(self, start: int, end: int, gloss: int) -> None:
        """Record the ids of the [START], [END] and [GLOSS] marker tokens."""
        self.config.start_token = start
        self.config.end_token = end
        self.config.gloss_token = gloss

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> TokenClassifierOutput:
        base_model_output = self.BaseModel(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        token_vectors = base_model_output.last_hidden_state
        # Sum-pool the hidden states of the [START] ... [END] span into a
        # single entity vector per sequence.
        selection = torch.zeros_like(input_ids, dtype=token_vectors.dtype)
        starts = (input_ids == self.config.start_token).nonzero()
        ends = (input_ids == self.config.end_token).nonzero()
        for startpos, endpos in zip(starts, ends, strict=True):
            selection[startpos[0], startpos[1] : endpos[1] + 1] = 1.0
        entity_vectors = torch.einsum("ijk,ij->ik", token_vectors, selection)
        gloss_vectors = self.gloss_vectors(input_ids, starts, position_ids, token_vectors)
        # One logit per candidate gloss: the dot product of the entity vector
        # with each gloss vector.
        logits = torch.einsum("ij,ikj->ik", entity_vectors, gloss_vectors)
        return TokenClassifierOutput(
            logits=logits,
            loss=self.loss(logits, labels) if labels is not None else None,
            hidden_states=base_model_output.hidden_states if output_hidden_states else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )

    def gloss_vectors(
        self,
        input_ids: torch.Tensor,
        starts: torch.Tensor,
        position_ids: torch.Tensor,
        token_vectors: torch.Tensor,
    ) -> torch.Tensor:
        with self.device:
            # For each [START] marker, collect the [GLOSS] tokens that share
            # its position id; these mark the candidate glosses for that span.
            vectors = [
                token_vectors[
                    i,
                    (position_ids[i] == position_ids[i, j])
                    & (input_ids[i] == self.config.gloss_token),
                ]
                for (i, j) in starts
            ]
            # Zero-pad along the candidate axis so the batch can be stacked.
            # The pad dtype follows the hidden states rather than being
            # hard-coded to bfloat16, so the randomly initialised (float32)
            # branch of __init__ also works.
            maxlen = max(vector.shape[0] for vector in vectors)
            return torch.stack(
                [
                    torch.cat(
                        [
                            vector,
                            torch.zeros(
                                (maxlen - vector.shape[0], vector.shape[1]),
                                dtype=vector.dtype,
                            ),
                        ]
                    )
                    for vector in vectors
                ]
            )
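
# A minimal wiring sketch (not part of the original source): how a tokenizer
# carrying the three marker tokens and a ConSec model are assumed to fit
# together. The add_tokens call and the function name are assumptions; the
# original only shows that the markers come out of
# tokenizer.get_added_vocab(). DeBERTa-v3 checkpoints ship an embedding
# matrix slightly larger than the tokenizer's vocabulary, which appears to be
# why __init__ above only grows it by two rows for three new tokens.
def example_setup(base_id: ModelURI = ModelURI.BASE) -> tuple["ConSec", PreTrainedTokenizer]:
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(base_id)
    tokenizer.add_tokens(["[START]", "[END]", "[GLOSS]"])  # hypothetical setup step
    model = ConSec.from_base(base_id)
    added = tokenizer.get_added_vocab()
    model.add_special_tokens(added["[START]"], added["[END]"], added["[GLOSS]"])
    return model, tokenizer
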
chunk["words"]] words.append("[START]") words.extend(sentence[site["span"]]["words"]) words.append("[END]") words.extend([word for chunk in sentence[site["span"]+1:] for word in chunk["words"]]) yield (words, sentence[site["span"]]["candidates"], site["span"]) def json_labeller(sentence,tags): for tag in tags: sentence[tag["index"]]["label"]=tag["label"] return sentence class ConSecTagger: def __init__(self,model, tokenizer, ontology, sequencer=json_sequencer, labeller=json_labeller): self.model = model self.tokenizer = tokenizer special_tokens = self.tokenizer.get_added_vocab() self.start_token = special_tokens["[START]"] self.gloss_token = special_tokens["[GLOSS]"] self.sequencer = sequencer self.detokenizer = TreebankWordDetokenizer() self.glosses = {synset.concept:synset.definition for synset in ontology} self.label=labeller def __call__(self,sentence): already_tagged = [] for (words,candidates,index) in self.sequencer(sentence): text = self.detokenizer.detokenize(words) glosses = [''] glosses.extend([self.glosses[candidate] for candidate in candidates]) glosses.extend([self.glosses[previous["label"]] for previous in already_tagged]) with self.model.device: tokens = self.tokenizer(text,"[GLOSS] ".join(glosses), return_tensors="pt") length = tokens.input_ids.shape[1] positions = torch.arange(length) place = (tokens.input_ids==self.start_token).nonzero(as_tuple=True)[1].item() wordpos = tokens.token_to_word(place) gloss_positions = [index.item() for index in (tokens.input_ids==self.gloss_token).nonzero(as_tuple=True)[1]] gloss_positions.append(length) n_candidates = len(candidates) for (i,position) in enumerate(gloss_positions[:-1]): if i