from collections.abc import Generator, Iterable
from enum import StrEnum

from nltk.tokenize import TreebankWordDetokenizer

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    DebertaV2Model,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.modeling_outputs import TokenClassifierOutput

class ModelURI(StrEnum):
    BASE = "microsoft/deberta-v3-base"
    LARGE = "microsoft/deberta-v3-large"

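# ModelURI members are plain strings (StrEnum), so they can be passed straight
# to the transformers loaders, e.g. AutoConfig.from_pretrained(ModelURI.BASE).

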
class ConSec(PreTrainedModel):
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        if config.init_basemodel:
            self.BaseModel = AutoModel.from_pretrained(
                config.name_or_path,
                device_map="auto",
                dtype=torch.bfloat16,
            )
            # Make room in the embedding matrix for the tokens added to the
            # tokenizer.
            self.config.vocab_size += 2
            self.BaseModel.resize_token_embeddings(self.config.vocab_size)
        else:
            self.BaseModel = DebertaV2Model(config)
        config.init_basemodel = False
        self.loss = nn.CrossEntropyLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI) -> "ConSec":
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True
        return cls(config)

    def add_special_tokens(self, start: int, end: int, gloss: int) -> None:
        # Record the ids of the [START], [END] and [GLOSS] markers so that
        # forward() can locate them in the input.
        self.config.start_token = start
        self.config.end_token = end
        self.config.gloss_token = gloss

    def forward(self,
                input_ids: torch.Tensor | None = None,
                attention_mask: torch.Tensor | None = None,
                token_type_ids: torch.Tensor | None = None,
                position_ids: torch.Tensor | None = None,
                inputs_embeds: torch.Tensor | None = None,
                labels: torch.Tensor | None = None,
                output_attentions: bool | None = None,
                output_hidden_states: bool | None = None,
                return_dict: bool | None = None,
                **kwargs) -> TokenClassifierOutput:
        base_model_output = self.BaseModel(input_ids=input_ids,
                                           attention_mask=attention_mask,
                                           token_type_ids=token_type_ids,
                                           position_ids=position_ids,
                                           inputs_embeds=inputs_embeds,
                                           output_attentions=output_attentions,
                                           output_hidden_states=output_hidden_states,
                                           **kwargs)
        token_vectors = base_model_output.last_hidden_state
        # Build a 0/1 mask over the tokens between each [START]/[END] pair
        # (inclusive), then sum the masked token vectors into a single entity
        # vector per sequence.
        selection = torch.zeros_like(input_ids, dtype=token_vectors.dtype)
        starts = (input_ids == self.config.start_token).nonzero()
        ends = (input_ids == self.config.end_token).nonzero()
        for startpos, endpos in zip(starts, ends, strict=True):
            selection[startpos[0], startpos[1] : endpos[1] + 1] = 1.0
        entity_vectors = torch.einsum("ijk,ij->ik", token_vectors, selection)
        gloss_vectors = self.gloss_vectors(
            input_ids, starts, position_ids, token_vectors
        )
        # One logit per candidate gloss: the dot product between the entity
        # vector and the corresponding [GLOSS] token vector.
        logits = torch.einsum("ij,ikj->ik", entity_vectors, gloss_vectors)

        return TokenClassifierOutput(
            logits=logits,
            loss=self.loss(logits, labels) if labels is not None else None,
            hidden_states=base_model_output.hidden_states if output_hidden_states else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )

    def gloss_vectors(self,
                      input_ids: torch.Tensor,
                      starts: torch.Tensor,
                      position_ids: torch.Tensor,
                      token_vectors: torch.Tensor) -> torch.Tensor:
        with self.device:
            # For each sequence, collect the vectors of the [GLOSS] tokens
            # that share a position id with that sequence's [START] token;
            # the tagger assigns candidate glosses exactly that position.
            vectors = [token_vectors[i, ((position_ids[i] == position_ids[i, j])
                                         & (input_ids[i] == self.config.gloss_token))]
                       for (i, j) in starts]
            # Zero-pad to the largest candidate count so the batch stacks.
            maxlen = max(vector.shape[0] for vector in vectors)
            return torch.stack([torch.cat([vector,
                                           torch.zeros((maxlen - vector.shape[0],
                                                        vector.shape[1]),
                                                       dtype=vector.dtype)])
                                for vector in vectors])

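# A minimal wiring sketch (illustrative, not from this module): `tokenizer` is
# assumed to be a DeBERTa-v3 tokenizer that has already had the markers added,
# e.g. tokenizer.add_tokens(["[START]", "[END]", "[GLOSS]"]).
#
#     model = ConSec.from_base(ModelURI.BASE)
#     added = tokenizer.get_added_vocab()
#     model.add_special_tokens(added["[START]"],
#                              added["[END]"],
#                              added["[GLOSS]"])
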
def json_sequencer(sentence: list[dict]) -> Generator[tuple[list[str], list[str], int], None, None]:
    """Yield one disambiguation instance per annotatable chunk, easiest
    (fewest candidates) first, as (words, candidates, chunk index)."""
    sites = [{"span": i, "n_candidates": len(chunk["candidates"])}
             for (i, chunk) in enumerate(sentence)
             if "candidates" in chunk]
    for site in sorted(sites, key=lambda x: x["n_candidates"]):
        # Reassemble the sentence with the target chunk wrapped in
        # [START] ... [END] markers.
        words = [word for chunk in sentence[:site["span"]]
                 for word in chunk["words"]]
        words.append("[START]")
        words.extend(sentence[site["span"]]["words"])
        words.append("[END]")
        words.extend([word for chunk in sentence[site["span"] + 1:]
                      for word in chunk["words"]])
        yield (words,
               sentence[site["span"]]["candidates"],
               site["span"])

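# json_sequencer expects a pre-chunked sentence along these lines (a hedged
# illustration; "words" and "candidates" are the only keys it reads, and the
# candidate identifiers are whatever the ontology uses as `concept` ids):
#
#     sentence = [
#         {"words": ["The"]},
#         {"words": ["bank"], "candidates": ["bank.n.01", "bank.n.06"]},
#         {"words": ["closed", "."]},
#     ]
#
# For the "bank" chunk it would yield
#     (["The", "[START]", "bank", "[END]", "closed", "."],
#      ["bank.n.01", "bank.n.06"],
#      1)
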
def json_labeller(sentence, tags):
    """Write each predicted label back onto its chunk of the sentence."""
    for tag in tags:
        sentence[tag["index"]]["label"] = tag["label"]
    return sentence

class ConSecTagger:
    def __init__(self,
                 model: ConSec,
                 tokenizer: PreTrainedTokenizer,
                 ontology: Iterable,
                 sequencer=json_sequencer,
                 labeller=json_labeller):
        self.model = model
        self.tokenizer = tokenizer
        special_tokens = self.tokenizer.get_added_vocab()
        self.start_token = special_tokens["[START]"]
        self.gloss_token = special_tokens["[GLOSS]"]
        self.sequencer = sequencer
        self.detokenizer = TreebankWordDetokenizer()
        # The ontology is any iterable of synset-like objects exposing
        # `concept` (an identifier) and `definition` (its gloss).
        self.glosses = {synset.concept: synset.definition
                        for synset in ontology}
        self.label = labeller

    def __call__(self, sentence):
        already_tagged = []
        for (words, candidates, index) in self.sequencer(sentence):
            text = self.detokenizer.detokenize(words)
            # The leading empty string makes join() prefix every gloss,
            # including the first, with a [GLOSS] marker.
            glosses = ['']
            glosses.extend([self.glosses[candidate] for candidate in candidates])
            glosses.extend([self.glosses[previous["label"]] for previous in already_tagged])
            with self.model.device:
                tokens = self.tokenizer(text, "[GLOSS] ".join(glosses),
                                        return_tensors="pt")
                length = tokens.input_ids.shape[1]
                positions = torch.arange(length)
                place = (tokens.input_ids == self.start_token).nonzero(as_tuple=True)[1].item()
                wordpos = tokens.token_to_word(place)
                gloss_positions = [pos.item()
                                   for pos in (tokens.input_ids == self.gloss_token).nonzero(as_tuple=True)[1]]
                gloss_positions.append(length)
                n_candidates = len(candidates)
                # Remap position ids so each candidate gloss "sits at" the
                # [START] marker, while the gloss of an already-tagged word
                # sits at the word it disambiguated; gloss_vectors() uses this
                # to tell candidates apart from context glosses.
                for (i, position) in enumerate(gloss_positions[:-1]):
                    if i < n_candidates:
                        end = place + gloss_positions[i + 1] - position
                        positions[position:gloss_positions[i + 1]] = torch.arange(place, end)
                    else:
                        known = already_tagged[i - n_candidates]
                        start = tokens.word_to_tokens(known["place"]).start
                        end = start + gloss_positions[i + 1] - position
                        positions[position:gloss_positions[i + 1]] = torch.arange(start, end)
                prediction = self.model(input_ids=tokens.input_ids,
                                        attention_mask=tokens.attention_mask,
                                        token_type_ids=tokens.token_type_ids,
                                        position_ids=positions.reshape((1, length)))
            try:
                label = candidates[prediction.logits.argmax()]
            except IndexError:
                # Dump the failing instance before re-raising.
                print(text)
                print(gloss_positions)
                print([positions[pos].item() for pos in gloss_positions[:-1]])
                print(already_tagged)
                print(candidates)
                print(prediction.logits)
                print(prediction.logits.argmax())
                raise
            already_tagged.append({"label": label,
                                   "place": wordpos,
                                   "index": index})
        return self.label(sentence, already_tagged)
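

# End-to-end sketch (illustrative assumptions: AutoTokenizer comes from
# transformers, and `ontology` is any iterable of objects with `concept` and
# `definition` attributes; neither is defined in this module):
#
#     tokenizer = AutoTokenizer.from_pretrained(ModelURI.BASE)
#     tokenizer.add_tokens(["[START]", "[END]", "[GLOSS]"])
#     model = ConSec.from_base(ModelURI.BASE)
#     added = tokenizer.get_added_vocab()
#     model.add_special_tokens(added["[START]"], added["[END]"], added["[GLOSS]"])
#     tagger = ConSecTagger(model, tokenizer, ontology)
#     tagged = tagger(sentence)   # `sentence` as in the json_sequencer example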