Instructions to use riccorl/relik-reader-deberta-v3-base-aida with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use riccorl/relik-reader-deberta-v3-base-aida with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="riccorl/relik-reader-deberta-v3-base-aida", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("riccorl/relik-reader-deberta-v3-base-aida", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from typing import Any, Dict, Optional | |
| import torch | |
| from transformers import AutoModel, PreTrainedModel | |
| from transformers.activations import ClippedGELUActivation, GELUActivation | |
| from transformers.configuration_utils import PretrainedConfig | |
| from transformers.modeling_utils import PoolerEndLogits | |
| from .configuration_relik import RelikReaderConfig | |
class RelikReaderSample:
    """Attribute-style container backed by a single dictionary.

    Every keyword given to the constructor, and every attribute assigned
    afterwards, is stored in ``self._d``. Reading a name that was never set
    returns ``None`` instead of raising. Dunder names are the exception: they
    raise ``AttributeError`` so that library protocols (for instance ``copy``
    probing for ``__deepcopy__``) observe standard behaviour.
    """

    def __init__(self, **kwargs):
        # Install the backing store through object.__setattr__ so that it is
        # not recorded as one of its own entries.
        super().__setattr__("_d", kwargs)

    def __getattr__(self, item):
        """Return the stored value for ``item``, or ``None`` if it is unknown.

        Only called when regular attribute lookup has already failed.
        """
        if item.startswith("__") and item.endswith("__"):
            # this is likely some python library-specific variable
            # (such as __deepcopy__ for copy); follow standard behaviour
            raise AttributeError(item)
        if item == "_d":
            # The backing store does not exist yet (instance created without
            # running __init__); looking it up again would recurse forever.
            raise AttributeError(item)
        return self._d.get(item)

    def __setattr__(self, key, value):
        """Store ``value`` under ``key`` in the backing dictionary."""
        if key == "_d":
            # Replacing the store itself must touch the real instance attribute.
            super().__setattr__(key, value)
            return
        # Keep a single source of truth: values live only in ``_d`` so that a
        # later re-assignment can never be shadowed by a stale instance attribute.
        self._d[key] = value
# Maps the activation name given to the projection-layer builders onto the
# module instance that is inserted into the head.
activation2functions = dict(
    relu=torch.nn.ReLU(),
    gelu=GELUActivation(),
    gelu_10=ClippedGELUActivation(-10, 10),
)
class PoolerEndLogitsBi(PoolerEndLogits):
    """``PoolerEndLogits`` subclass whose ``dense_1`` layer emits two values."""

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        # Set after the parent constructor so this two-unit layer is the one kept.
        self.dense_1 = torch.nn.Linear(config.hidden_size, 2)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_states: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        p_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """Delegate to the parent ``forward`` after adding a trailing axis to ``p_mask``."""
        mask = None if p_mask is None else p_mask.unsqueeze(-1)
        return super().forward(hidden_states, start_states, start_positions, mask)
class RelikReaderSpanModel(PreTrainedModel):
    """Reader that detects spans and scores each one against the special-symbol tokens.

    A transformer encoder produces token features. On top of it sit a start
    classifier, an end classifier that receives the selected start positions,
    and two projectors whose outputs are multiplied with the projected
    representations of the tokens selected by ``special_symbols_mask``.
    """

    config_class = RelikReaderConfig

    def __init__(self, config: RelikReaderConfig, *args, **kwargs):
        """Build the encoder and the heads described by ``config``.

        Extra positional and keyword arguments are accepted and ignored.
        """
        super().__init__(config)
        # Transformer model declaration
        self.config = config
        self.transformer_model = (
            AutoModel.from_pretrained(self.config.transformer_model)
            if self.config.num_layers is None
            else AutoModel.from_pretrained(
                self.config.transformer_model, num_hidden_layers=self.config.num_layers
            )
        )
        # make room for the additional special symbols in the embedding matrix
        self.transformer_model.resize_token_embeddings(
            self.transformer_model.config.vocab_size
            + self.config.additional_special_symbols
        )

        self.activation = self.config.activation
        self.linears_hidden_size = self.config.linears_hidden_size
        self.use_last_k_layers = self.config.use_last_k_layers

        # named entity detection layers
        self.ned_start_classifier = self._get_projection_layer(
            self.activation, last_hidden=2, layer_norm=False
        )
        self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)

        # END entity disambiguation layer
        self.ed_start_projector = self._get_projection_layer(self.activation)
        self.ed_end_projector = self._get_projection_layer(self.activation)

        # the ``training`` attribute starts from the value stored in the config
        self.training = self.config.training

        # criterion
        self.criterion = torch.nn.CrossEntropyLoss()

    def _get_projection_layer(
        self,
        activation: str,
        last_hidden: Optional[int] = None,
        input_hidden=None,
        layer_norm: bool = True,
    ) -> torch.nn.Sequential:
        """Build a Dropout-Linear-activation-Dropout-Linear head.

        Args:
            activation: key into ``activation2functions``.
            last_hidden: output width; defaults to ``linears_hidden_size``.
            input_hidden: input width; defaults to the encoder hidden size
                times ``use_last_k_layers``.
            layer_norm: append a ``LayerNorm`` over the output when true.
        """
        in_features = (
            self.transformer_model.config.hidden_size * self.use_last_k_layers
            if input_hidden is None
            else input_hidden
        )
        out_features = self.linears_hidden_size if last_hidden is None else last_hidden
        head_components = [
            torch.nn.Dropout(0.1),
            torch.nn.Linear(in_features, self.linears_hidden_size),
            activation2functions[activation],
            torch.nn.Dropout(0.1),
            torch.nn.Linear(self.linears_hidden_size, out_features),
        ]
        if layer_norm:
            head_components.append(
                torch.nn.LayerNorm(
                    out_features,
                    self.transformer_model.config.layer_norm_eps,
                )
            )
        return torch.nn.Sequential(*head_components)

    def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Push logits at positions where ``mask`` is 1 to a large negative value."""
        mask = mask.unsqueeze(-1)
        if next(self.parameters()).dtype == torch.float16:
            # 1e30 is not representable in float16, use a value close to its limit
            logits = logits * (1 - mask) - 65500 * mask
        else:
            logits = logits * (1 - mask) - 1e30 * mask
        return logits

    def _get_model_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ):
        """Run the encoder and return its token features.

        With ``use_last_k_layers > 1`` the last k hidden states are
        concatenated on the feature axis; otherwise the first output of the
        encoder is returned.
        """
        model_input = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_hidden_states": self.use_last_k_layers > 1,
        }
        if token_type_ids is not None:
            model_input["token_type_ids"] = token_type_ids

        model_output = self.transformer_model(**model_input)

        if self.use_last_k_layers > 1:
            # Fetch the hidden states by name: position 1 of the output only
            # holds them for encoders that do not also return a pooled output.
            hidden_states = getattr(model_output, "hidden_states", None)
            if hidden_states is None:
                # plain tuple output: keep the positional lookup
                hidden_states = model_output[1]
            model_features = torch.cat(
                hidden_states[-self.use_last_k_layers :], dim=-1
            )
        else:
            model_features = model_output[0]

        return model_features

    def compute_ned_end_logits(
        self,
        start_predictions,
        start_labels,
        model_features,
        prediction_mask,
        batch_size,
    ) -> Optional[torch.Tensor]:
        """Compute end logits, one row per selected start position.

        Start positions come from ``start_labels`` while ``self.training`` is
        set and from ``start_predictions`` otherwise. Returns ``None`` when no
        start position is selected.
        """
        # todo: maybe when constraining on the spans,
        #  we should not use a prediction_mask for the end tokens.
        #  at least we should not during training imo
        start_positions = start_labels if self.training else start_predictions
        start_positions_indices = (
            torch.arange(start_positions.size(1), device=start_positions.device)
            .unsqueeze(0)
            .expand(batch_size, -1)[start_positions > 0]
        ).to(start_positions.device)

        if len(start_positions_indices) > 0:
            # repeat every sequence once per start position found in it
            starts_per_row = torch.sum(start_positions > 0, dim=-1)
            expanded_features = model_features.repeat_interleave(starts_per_row, dim=0)
            expanded_prediction_mask = prediction_mask.repeat_interleave(
                starts_per_row, dim=0
            )
            end_logits = self.ned_end_classifier(
                hidden_states=expanded_features,
                start_positions=start_positions_indices,
                p_mask=expanded_prediction_mask,
            )
            return end_logits

        return None

    def compute_classification_logits(
        self,
        model_features,
        special_symbols_mask,
        prediction_mask,
        batch_size,
        start_positions=None,
        end_positions=None,
    ) -> torch.Tensor:
        """Score every token against the special-symbol representations.

        When either position tensor is missing, both are replaced by zeros
        shaped like ``prediction_mask``.
        """
        if start_positions is None or end_positions is None:
            start_positions = torch.zeros_like(prediction_mask)
            end_positions = torch.zeros_like(prediction_mask)

        model_start_features = self.ed_start_projector(model_features)
        model_end_features = self.ed_end_projector(model_features)
        # copy the end-token features onto the corresponding start positions
        model_end_features[start_positions > 0] = model_end_features[end_positions > 0]

        model_ed_features = torch.cat(
            [model_start_features, model_end_features], dim=-1
        )

        # computing ed features
        # the number of special symbols is read from the first row of the mask
        classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item()
        special_symbols_representation = model_ed_features[special_symbols_mask].view(
            batch_size, classes_representations, -1
        )

        logits = torch.bmm(
            model_ed_features,
            torch.permute(special_symbols_representation, (0, 2, 1)),
        )

        logits = self._mask_logits(logits, prediction_mask)

        return logits

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        use_predefined_spans: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        """Detect spans (unless predefined), disambiguate them and compute losses.

        Args:
            input_ids: 2-D tensor; its shape is unpacked as (batch, sequence).
            attention_mask: forwarded to the encoder.
            token_type_ids: forwarded to the encoder when not ``None``.
            prediction_mask: positions set to 1 are masked out of the logits.
            special_symbols_mask: boolean mask selecting the special-symbol tokens.
            start_labels: values > 0 mark span starts, -100 marks ignored positions.
            end_labels: values > 0 mark span ends, -100 marks ignored positions.
            use_predefined_spans: take the spans from the labels instead of
                predicting them.

        Returns:
            A dict with logits, probabilities and predictions of every stage,
            plus ``ned_start_loss``, ``ned_end_loss``, ``ed_loss`` and ``loss``
            when both label tensors are given and ``self.training`` is set.
        """
        batch_size, seq_len = input_ids.shape

        model_features = self._get_model_features(
            input_ids, attention_mask, token_type_ids
        )

        # Binary view of the start labels (1 = start, -100 = ignored). It is
        # built in both branches below because the disambiguation loss needs
        # it with predefined spans too.
        ned_start_labels = None
        if start_labels is not None:
            ned_start_labels = torch.zeros_like(start_labels)
            ned_start_labels[start_labels == -100] = -100
            ned_start_labels[start_labels > 0] = 1

        # named entity detection if required
        if use_predefined_spans:  # no need to compute spans
            ned_start_logits, ned_start_probabilities, ned_start_predictions = (
                None,
                None,
                (
                    torch.clone(start_labels)
                    if start_labels is not None
                    else torch.zeros_like(input_ids)
                ),
            )
            ned_end_logits, ned_end_probabilities, ned_end_predictions = (
                None,
                None,
                (
                    torch.clone(end_labels)
                    if end_labels is not None
                    else torch.zeros_like(input_ids)
                ),
            )
            ned_start_predictions[ned_start_predictions > 0] = 1
            ned_end_predictions[ned_end_predictions > 0] = 1
        else:  # compute spans
            # start boundary prediction
            ned_start_logits = self.ned_start_classifier(model_features)
            ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
            ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
            ned_start_predictions = ned_start_probabilities.argmax(dim=-1)

            # end boundary prediction
            ned_end_logits = self.compute_ned_end_logits(
                ned_start_predictions,
                ned_start_labels,
                model_features,
                prediction_mask,
                batch_size,
            )

            if ned_end_logits is not None:
                ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
                ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
            else:
                ned_end_logits, ned_end_probabilities = None, None
                ned_end_predictions = ned_start_predictions.new_zeros(batch_size)

            # flattening end predictions
            # (flattening can happen only if the
            # end boundaries were not predicted using the gold labels)
            if not self.training and ned_end_logits is not None:
                flattened_end_predictions = torch.zeros_like(ned_start_predictions)

                row_indices, start_positions = torch.where(ned_start_predictions > 0)
                # an end can never precede its start: clamp it to the start
                ned_end_predictions[
                    ned_end_predictions < start_positions
                ] = start_positions[ned_end_predictions < start_positions]

                # encode (row, end) as one increasing key and keep only the
                # spans whose key exceeds every earlier key
                end_spans_repeated = (row_indices + 1) * seq_len + ned_end_predictions
                cummax_values, _ = end_spans_repeated.cummax(dim=0)

                end_spans_repeated = end_spans_repeated > torch.cat(
                    (end_spans_repeated[:1], cummax_values[:-1])
                )
                end_spans_repeated[0] = True

                # drop the start predictions of the discarded spans
                ned_start_predictions[
                    row_indices[~end_spans_repeated],
                    start_positions[~end_spans_repeated],
                ] = 0

                row_indices, start_positions, ned_end_predictions = (
                    row_indices[end_spans_repeated],
                    start_positions[end_spans_repeated],
                    ned_end_predictions[end_spans_repeated],
                )

                flattened_end_predictions[row_indices, ned_end_predictions] = 1

                total_start_predictions, total_end_predictions = (
                    ned_start_predictions.sum(),
                    flattened_end_predictions.sum(),
                )

                assert (
                    total_start_predictions == 0
                    or total_start_predictions == total_end_predictions
                ), (
                    f"Total number of start predictions = {total_start_predictions}. "
                    f"Total number of end predictions = {total_end_predictions}"
                )
                ned_end_predictions = flattened_end_predictions
            else:
                ned_end_predictions = torch.zeros_like(ned_start_predictions)

        start_position, end_position = (
            (start_labels, end_labels)
            if self.training
            else (ned_start_predictions, ned_end_predictions)
        )

        # Entity disambiguation
        ed_logits = self.compute_classification_logits(
            model_features,
            special_symbols_mask,
            prediction_mask,
            batch_size,
            start_position,
            end_position,
        )
        ed_probabilities = torch.softmax(ed_logits, dim=-1)
        ed_predictions = torch.argmax(ed_probabilities, dim=-1)

        # output build
        output_dict = dict(
            batch_size=batch_size,
            ned_start_logits=ned_start_logits,
            ned_start_probabilities=ned_start_probabilities,
            ned_start_predictions=ned_start_predictions,
            ned_end_logits=ned_end_logits,
            ned_end_probabilities=ned_end_probabilities,
            ned_end_predictions=ned_end_predictions,
            ed_logits=ed_logits,
            ed_probabilities=ed_probabilities,
            ed_predictions=ed_predictions,
        )

        # compute loss if labels
        if start_labels is not None and end_labels is not None and self.training:
            # named entity detection loss

            # start
            if ned_start_logits is not None:
                ned_start_loss = self.criterion(
                    ned_start_logits.view(-1, ned_start_logits.shape[-1]),
                    ned_start_labels.view(-1),
                )
            else:
                ned_start_loss = 0

            # end
            if ned_end_logits is not None:
                ned_end_labels = torch.zeros_like(end_labels)
                ned_end_labels[end_labels == -100] = -100
                ned_end_labels[end_labels > 0] = 1

                # the target of each end-logit row is the index of its gold end token
                ned_end_loss = self.criterion(
                    ned_end_logits,
                    (
                        torch.arange(
                            ned_end_labels.size(1), device=ned_end_labels.device
                        )
                        .unsqueeze(0)
                        .expand(batch_size, -1)[ned_end_labels > 0]
                    ).to(ned_end_labels.device),
                )
            else:
                ned_end_loss = 0

            # entity disambiguation loss
            # Work on a copy so the caller's ``start_labels`` tensor is left
            # untouched; only span-start positions keep their class label.
            ed_labels = torch.clone(start_labels)
            ed_labels[ned_start_labels != 1] = -100
            ed_labels[end_labels > 0] = end_labels[end_labels > 0]
            ed_loss = self.criterion(
                ed_logits.view(-1, ed_logits.shape[-1]),
                ed_labels.view(-1),
            )

            output_dict["ned_start_loss"] = ned_start_loss
            output_dict["ned_end_loss"] = ned_end_loss
            output_dict["ed_loss"] = ed_loss

            output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss

        return output_dict
class RelikReaderREModel(PreTrainedModel):
    """Reader that detects entity spans and scores relations between them.

    A transformer encoder produces token features. Heads on top of it predict
    span starts and ends, optionally score entities against the special-symbol
    tokens, and score (subject, object, relation) triples.
    """

    config_class = RelikReaderConfig

    def __init__(self, config, *args, **kwargs):
        """Build the encoder and the heads described by ``config``.

        Extra positional and keyword arguments are accepted and ignored.
        """
        super().__init__(config)
        # Transformer model declaration
        self.config = config
        self.transformer_model = (
            AutoModel.from_pretrained(config.transformer_model)
            if config.num_layers is None
            else AutoModel.from_pretrained(
                config.transformer_model, num_hidden_layers=config.num_layers
            )
        )
        # make room for the additional special symbols in the embedding matrix
        self.transformer_model.resize_token_embeddings(
            self.transformer_model.config.vocab_size
            + config.additional_special_symbols
            + config.additional_special_symbols_types,
        )

        # named entity detection layers
        self.ned_start_classifier = self._get_projection_layer(
            config.activation, last_hidden=2, layer_norm=False
        )
        self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config)

        self.relation_disambiguation_loss = (
            config.relation_disambiguation_loss
            if hasattr(config, "relation_disambiguation_loss")
            else False
        )

        # ``forward`` concatenates start and end features (two blocks) and,
        # when entity embeddings are enabled, one more block. The projector
        # input has to grow with it, otherwise the first Linear cannot accept
        # the concatenated features.
        if self.config.entity_type_loss and self.config.add_entity_embedding:
            feature_blocks = 3
        else:
            feature_blocks = 2
        self.re_projector = self._get_projection_layer(
            config.activation,
            input_hidden=feature_blocks * self.transformer_model.config.hidden_size,
            hidden=feature_blocks * self.config.linears_hidden_size,
            last_hidden=2 * self.config.linears_hidden_size,
        )

        self.re_relation_projector = self._get_projection_layer(
            config.activation,
            input_hidden=self.transformer_model.config.hidden_size,
        )

        if self.config.entity_type_loss or self.relation_disambiguation_loss:
            self.re_entities_projector = self._get_projection_layer(
                config.activation,
                input_hidden=2 * self.transformer_model.config.hidden_size,
            )
            self.re_definition_projector = self._get_projection_layer(
                config.activation,
            )

        self.re_classifier = self._get_projection_layer(
            config.activation,
            input_hidden=config.linears_hidden_size,
            last_hidden=2,
            layer_norm=False,
        )

        # the ``training`` attribute starts from the value stored in the config
        self.training = config.training

        # criterion
        self.criterion = torch.nn.CrossEntropyLoss()
        self.criterion_type = torch.nn.BCEWithLogitsLoss()

    def _get_projection_layer(
        self,
        activation: str,
        last_hidden: Optional[int] = None,
        hidden: Optional[int] = None,
        input_hidden=None,
        layer_norm: bool = True,
    ) -> torch.nn.Sequential:
        """Build a Dropout-Linear-activation-Dropout-Linear head.

        Args:
            activation: key into ``activation2functions``.
            last_hidden: output width; defaults to ``linears_hidden_size``.
            hidden: width between the two linears; defaults to
                ``linears_hidden_size``.
            input_hidden: input width; defaults to the encoder hidden size
                times ``use_last_k_layers``.
            layer_norm: append a ``LayerNorm`` over the output when true.
        """
        in_features = (
            self.transformer_model.config.hidden_size * self.config.use_last_k_layers
            if input_hidden is None
            else input_hidden
        )
        mid_features = self.config.linears_hidden_size if hidden is None else hidden
        out_features = (
            self.config.linears_hidden_size if last_hidden is None else last_hidden
        )
        head_components = [
            torch.nn.Dropout(0.1),
            torch.nn.Linear(in_features, mid_features),
            activation2functions[activation],
            torch.nn.Dropout(0.1),
            torch.nn.Linear(mid_features, out_features),
        ]
        if layer_norm:
            head_components.append(
                torch.nn.LayerNorm(
                    out_features,
                    self.transformer_model.config.layer_norm_eps,
                )
            )
        return torch.nn.Sequential(*head_components)

    def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Push logits at positions where ``mask`` is 1 to a large negative value."""
        mask = mask.unsqueeze(-1)
        if next(self.parameters()).dtype == torch.float16:
            # 1e30 is not representable in float16, use a value close to its limit
            logits = logits * (1 - mask) - 65500 * mask
        else:
            logits = logits * (1 - mask) - 1e30 * mask
        return logits

    def _get_model_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ):
        """Run the encoder and return its token features.

        With ``use_last_k_layers > 1`` the last k hidden states are
        concatenated on the feature axis; otherwise the first output of the
        encoder is returned.
        """
        model_input = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_hidden_states": self.config.use_last_k_layers > 1,
        }
        if token_type_ids is not None:
            model_input["token_type_ids"] = token_type_ids

        model_output = self.transformer_model(**model_input)

        if self.config.use_last_k_layers > 1:
            # Fetch the hidden states by name: position 1 of the output only
            # holds them for encoders that do not also return a pooled output.
            hidden_states = getattr(model_output, "hidden_states", None)
            if hidden_states is None:
                # plain tuple output: keep the positional lookup
                hidden_states = model_output[1]
            model_features = torch.cat(
                hidden_states[-self.config.use_last_k_layers :], dim=-1
            )
        else:
            model_features = model_output[0]

        return model_features

    def compute_ned_end_logits(
        self,
        start_predictions,
        start_labels,
        model_features,
        prediction_mask,
        batch_size,
        mask_preceding: bool = False,
    ) -> Optional[torch.Tensor]:
        """Compute end logits, one row per selected start position.

        Start positions come from ``start_labels`` while ``self.training`` is
        set and from ``start_predictions`` otherwise. With ``mask_preceding``
        every position before the start is additionally masked. Returns
        ``None`` when no start position is selected.
        """
        # todo: maybe when constraining on the spans,
        #  we should not use a prediction_mask for the end tokens.
        #  at least we should not during training imo
        start_positions = start_labels if self.training else start_predictions
        start_positions_indices = (
            torch.arange(start_positions.size(1), device=start_positions.device)
            .unsqueeze(0)
            .expand(batch_size, -1)[start_positions > 0]
        ).to(start_positions.device)

        if len(start_positions_indices) > 0:
            # repeat every sequence once per start position found in it
            starts_per_row = torch.sum(start_positions > 0, dim=-1)
            expanded_features = model_features.repeat_interleave(starts_per_row, dim=0)
            expanded_prediction_mask = prediction_mask.repeat_interleave(
                starts_per_row, dim=0
            )
            if mask_preceding:
                expanded_prediction_mask[
                    torch.arange(
                        expanded_prediction_mask.shape[1],
                        device=expanded_prediction_mask.device,
                    )
                    < start_positions_indices.unsqueeze(1)
                ] = 1
            end_logits = self.ned_end_classifier(
                hidden_states=expanded_features,
                start_positions=start_positions_indices,
                p_mask=expanded_prediction_mask,
            )
            return end_logits

        return None

    def compute_relation_logits(
        self,
        model_entity_features,
        special_symbols_features,
    ) -> torch.Tensor:
        """Score every (subject, object, relation) combination.

        The projected entity features are split in half on the last axis into
        subject and object parts, combined with the projected relation
        features, and passed through ``re_classifier``.
        """
        model_subject_object_features = self.re_projector(model_entity_features)
        half = model_subject_object_features.shape[-1] // 2
        model_subject_features = model_subject_object_features[:, :, :half]
        model_object_features = model_subject_object_features[:, :, half:]
        special_symbols_start_representation = self.re_relation_projector(
            special_symbols_features
        )
        re_logits = torch.einsum(
            "bse,bde,bfe->bsdfe",
            model_subject_features,
            model_object_features,
            special_symbols_start_representation,
        )
        re_logits = self.re_classifier(re_logits)

        return re_logits

    def compute_entity_logits(
        self,
        model_entity_features,
        special_symbols_features,
    ) -> torch.Tensor:
        """Score every entity against the special-symbol representations.

        Rows of ``model_entity_features`` made entirely of -100 (the padding
        value used in ``forward``) are masked out.
        """
        model_ed_features = self.re_entities_projector(model_entity_features)
        special_symbols_ed_representation = self.re_definition_projector(
            special_symbols_features
        )

        logits = torch.bmm(
            model_ed_features,
            torch.permute(special_symbols_ed_representation, (0, 2, 1)),
        )
        logits = self._mask_logits(
            logits, (model_entity_features == -100).all(2).long()
        )
        return logits

    def compute_loss(self, logits, labels, mask=None):
        """Cross-entropy over flattened logits and labels, optionally restricted by ``mask``."""
        logits = logits.reshape(-1, logits.shape[-1])
        labels = labels.reshape(-1).long()
        if mask is not None:
            return self.criterion(logits[mask], labels[mask])
        return self.criterion(logits, labels)

    def compute_ned_type_loss(
        self,
        disambiguation_labels,
        re_ned_entities_logits,
        ned_type_logits,
        re_entities_logits,
        entity_types,
        mask,
    ):
        """Binary cross-entropy for the entity heads enabled in the config.

        Returns 0 when neither ``entity_type_loss`` nor
        ``relation_disambiguation_loss`` is enabled.
        """
        if self.config.entity_type_loss and self.relation_disambiguation_loss:
            return self.criterion_type(
                re_ned_entities_logits[disambiguation_labels != -100],
                disambiguation_labels[disambiguation_labels != -100],
            )
        if self.config.entity_type_loss:
            return self.criterion_type(
                ned_type_logits[mask],
                disambiguation_labels[:, :, :entity_types][mask],
            )
        if self.relation_disambiguation_loss:
            return self.criterion_type(
                re_entities_logits[disambiguation_labels != -100],
                disambiguation_labels[disambiguation_labels != -100],
            )
        return 0

    def compute_relation_loss(self, relation_labels, re_logits):
        """Cross-entropy over the relation logits, skipping labels equal to -100."""
        return self.compute_loss(
            re_logits, relation_labels, relation_labels.view(-1) != -100
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        special_symbols_mask_entities: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        disambiguation_labels: Optional[torch.Tensor] = None,
        relation_labels: Optional[torch.Tensor] = None,
        relation_threshold: Optional[float] = None,
        is_validation: bool = False,
        is_prediction: bool = False,
        use_predefined_spans: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        """Detect entities (unless predefined), score relations and compute losses.

        Args:
            input_ids: token ids; the first axis is the batch.
            attention_mask: forwarded to the encoder.
            token_type_ids: forwarded to the encoder when not ``None``.
            prediction_mask: positions set to 1 are masked out of the logits.
            special_symbols_mask: boolean mask selecting the relation symbols.
            special_symbols_mask_entities: boolean mask selecting the
                entity-type symbols.
            start_labels: values > 0 mark entity starts, -100 marks ignored
                positions.
            end_labels: values > 0 mark entity ends; indexed with ``.all(2)``,
                so it has at least three axes.
            disambiguation_labels: targets for the entity heads.
            relation_labels: targets for the relation head, -100 is ignored.
            relation_threshold: probability above which a relation is
                predicted; defaults to ``config.threshold``.
            is_validation: use the predicted spans and mask the start logits.
            is_prediction: same as ``is_validation`` and also skips the losses.
            use_predefined_spans: take the spans from the labels, which must
                then be provided.

        Returns:
            A dict with logits, probabilities and predictions of every stage.
            The loss entries are added when start, end and relation labels are
            all given and ``is_prediction`` is false.
        """
        relation_threshold = (
            self.config.threshold if relation_threshold is None else relation_threshold
        )

        batch_size = input_ids.shape[0]

        model_features = self._get_model_features(
            input_ids, attention_mask, token_type_ids
        )

        # named entity detection
        if use_predefined_spans:
            ned_start_logits, ned_start_probabilities, ned_start_predictions = (
                None,
                None,
                torch.zeros_like(start_labels),
            )
            ned_end_logits, ned_end_probabilities, ned_end_predictions = (
                None,
                None,
                torch.zeros_like(end_labels),
            )

            ned_start_predictions[start_labels > 0] = 1
            ned_end_predictions[end_labels > 0] = 1
            # drop the rows of end labels that consist only of -100
            ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
            # NOTE: this aliases ``start_labels`` and binarises it in place
            ned_start_labels = start_labels
            ned_start_labels[start_labels > 0] = 1
        else:
            # start boundary prediction
            ned_start_logits = self.ned_start_classifier(model_features)
            if is_validation or is_prediction:
                ned_start_logits = self._mask_logits(
                    ned_start_logits, prediction_mask
                )  # why?
            ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
            ned_start_predictions = ned_start_probabilities.argmax(dim=-1)

            # end boundary prediction
            ned_start_labels = (
                torch.zeros_like(start_labels) if start_labels is not None else None
            )

            # start_labels contain entity id at their position, we just need 1 for start of entity
            if ned_start_labels is not None:
                ned_start_labels[start_labels == -100] = -100
                ned_start_labels[start_labels > 0] = 1

            # compute end logits only if there are any start predictions.
            # For each start prediction, n end predictions are made
            ned_end_logits = self.compute_ned_end_logits(
                ned_start_predictions,
                ned_start_labels,
                model_features,
                prediction_mask,
                batch_size,
                True,
            )

            if ned_end_logits is not None:
                # For each start prediction, n end predictions are made based on
                # binary classification ie. argmax at each position.
                ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
                ned_end_predictions = ned_end_probabilities.argmax(dim=-1)
            else:
                ned_end_logits, ned_end_probabilities = None, None
                ned_end_predictions = torch.zeros_like(ned_start_predictions)

            if is_prediction or is_validation:
                end_preds_count = ned_end_predictions.sum(1)
                # If there are no end predictions for a start prediction, remove the start prediction
                if (end_preds_count == 0).any() and (ned_start_predictions > 0).any():
                    ned_start_predictions[ned_start_predictions == 1] = (
                        end_preds_count != 0
                    ).long()
                    ned_end_predictions = ned_end_predictions[end_preds_count != 0]

        if end_labels is not None:
            # boolean indexing returns a new tensor, so the in-place edits made
            # further down do not reach the caller's tensor
            end_labels = end_labels[~(end_labels == -100).all(2)]

        start_position, end_position = (
            (start_labels, end_labels)
            if (not is_prediction and not is_validation)
            else (ned_start_predictions, ned_end_predictions)
        )

        start_counts = (start_position > 0).sum(1)
        if (start_counts > 0).any():
            ned_end_predictions = ned_end_predictions.split(start_counts.tolist())

        # Defaults for everything the entity heads may leave unset, so that the
        # output dict and the loss section never meet an unbound name.
        use_entity_heads = bool(
            self.config.entity_type_loss or self.relation_disambiguation_loss
        )
        re_ned_entities_logits = None
        entity_types = 0
        model_subject_features = None
        (
            ned_type_logits,
            ned_type_probabilities,
            re_entities_logits,
            re_entities_probabilities,
        ) = (None, None, None, None)
        ned_type_predictions = torch.zeros(
            [batch_size, 1], dtype=torch.long, device=input_ids.device
        )
        re_entities_predictions = torch.zeros(
            [batch_size, 1], dtype=torch.long, device=input_ids.device
        )

        # We can only predict relations if we have start and end predictions
        has_entities = bool((end_position > 0).sum() > 0)
        if has_entities:
            ends_count = (end_position > 0).sum(1)
            model_subject_features = torch.cat(
                [
                    torch.repeat_interleave(
                        model_features[start_position > 0], ends_count, dim=0
                    ),  # start position features
                    torch.repeat_interleave(model_features, start_counts, dim=0)[
                        end_position > 0
                    ],  # end position features
                ],
                dim=-1,
            )
            # number of (start, end) pairs per batch element
            ents_count = torch.nn.utils.rnn.pad_sequence(
                torch.split(ends_count, start_counts.tolist()),
                batch_first=True,
                padding_value=0,
            ).sum(1)
            # regroup the flat pair features per batch element, padding with -100
            model_subject_features = torch.nn.utils.rnn.pad_sequence(
                torch.split(model_subject_features, ents_count.tolist()),
                batch_first=True,
                padding_value=-100,
            )

            # entity disambiguation. Here relation_disambiguation_loss would only be useful to
            # reduce the number of candidate relations for the next step, but currently unused.
            if use_entity_heads:
                re_ned_entities_logits = self.compute_entity_logits(
                    model_subject_features,
                    model_features[
                        special_symbols_mask | special_symbols_mask_entities
                    ].view(batch_size, -1, model_features.shape[-1]),
                )
                # the number of entity types is read from the first row of the mask
                entity_types = torch.sum(special_symbols_mask_entities, dim=1)[0].item()
                ned_type_logits = re_ned_entities_logits[:, :, :entity_types]
                re_entities_logits = re_ned_entities_logits[:, :, entity_types:]

                if self.config.entity_type_loss:
                    ned_type_probabilities = torch.sigmoid(ned_type_logits)
                    ned_type_predictions = ned_type_probabilities.argmax(dim=-1)

                    if self.config.add_entity_embedding:
                        special_symbols_representation = model_features[
                            special_symbols_mask_entities
                        ].view(batch_size, entity_types, -1)

                        # probability-weighted sum of the type representations
                        entities_representation = torch.einsum(
                            "bsp,bpe->bse",
                            ned_type_probabilities,
                            special_symbols_representation,
                        )
                        model_subject_features = torch.cat(
                            [model_subject_features, entities_representation], dim=-1
                        )
                re_entities_probabilities = torch.sigmoid(re_entities_logits)
                re_entities_predictions = re_entities_probabilities.round()

            # Compute relation logits
            re_logits = self.compute_relation_logits(
                model_subject_features,
                model_features[special_symbols_mask].view(
                    batch_size, -1, model_features.shape[-1]
                ),
            )

            re_probabilities = torch.softmax(re_logits, dim=-1)
            # we set a threshold instead of argmax in case it needs to be tweaked
            re_predictions = re_probabilities[:, :, :, :, 1] > relation_threshold
            re_probabilities = re_probabilities[:, :, :, :, 1]
        else:
            # placeholders with one slot per relation symbol of the first row
            placeholder_shape = [batch_size, 1, 1, int(special_symbols_mask.sum(1)[0])]
            re_logits, re_probabilities, re_predictions = (
                torch.zeros(
                    placeholder_shape, dtype=torch.long, device=input_ids.device
                ),
                torch.zeros(
                    placeholder_shape, dtype=torch.long, device=input_ids.device
                ),
                torch.zeros(
                    placeholder_shape, dtype=torch.long, device=input_ids.device
                ),
            )

        # output build
        output_dict = dict(
            batch_size=batch_size,
            ned_start_logits=ned_start_logits,
            ned_start_probabilities=ned_start_probabilities,
            ned_start_predictions=ned_start_predictions,
            ned_end_logits=ned_end_logits,
            ned_end_probabilities=ned_end_probabilities,
            ned_end_predictions=ned_end_predictions,
            ned_type_logits=ned_type_logits,
            ned_type_probabilities=ned_type_probabilities,
            ned_type_predictions=ned_type_predictions,
            re_entities_logits=re_entities_logits,
            re_entities_probabilities=re_entities_probabilities,
            re_entities_predictions=re_entities_predictions,
            re_logits=re_logits,
            re_probabilities=re_probabilities,
            re_predictions=re_predictions,
        )

        if (
            start_labels is not None
            and end_labels is not None
            and relation_labels is not None
            and is_prediction is False
        ):
            # A stage that produced no logits (predefined spans, or no start
            # position at all) contributes 0 instead of failing on ``None``.
            if ned_start_logits is not None:
                ned_start_loss = self.compute_loss(ned_start_logits, ned_start_labels)
            else:
                ned_start_loss = 0

            end_labels[end_labels > 0] = 1
            if ned_end_logits is not None:
                ned_end_loss = self.compute_loss(ned_end_logits, end_labels)
            else:
                ned_end_loss = 0

            ned_type_loss = 0
            relation_loss = 0
            if has_entities:
                if use_entity_heads:
                    ned_type_loss = self.compute_ned_type_loss(
                        disambiguation_labels,
                        re_ned_entities_logits,
                        ned_type_logits,
                        re_entities_logits,
                        entity_types,
                        (model_subject_features != -100).all(2),
                    )
                relation_loss = self.compute_relation_loss(relation_labels, re_logits)
            # without entities ``re_logits`` is only an integer placeholder, so
            # the relation and type losses stay at 0

            # compute loss. We can skip the relation loss if we are in the first epochs (optional)
            if use_entity_heads:
                output_dict["loss"] = (
                    ned_start_loss + ned_end_loss + relation_loss + ned_type_loss
                ) / 4
                output_dict["ned_type_loss"] = ned_type_loss
            else:
                output_dict["loss"] = ((1 / 20) * (ned_start_loss + ned_end_loss)) + (
                    (9 / 10) * relation_loss
                )
            output_dict["ned_start_loss"] = ned_start_loss
            output_dict["ned_end_loss"] = ned_end_loss
            output_dict["re_loss"] = relation_loss

        return output_dict