import torch.nn as nn
from transformers import XLMRobertaModel
from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput


class Smish(nn.Module):
    """Smish activation: x * tanh(log(1 + sigmoid(x)))."""

    def forward(self, x):
        return x * (x.sigmoid() + 1).log().tanh()


class NoRefER(XLMRobertaPreTrainedModel):
    """Referenceless quality scorer built on XLM-RoBERTa.

    A small head on top of the pooled encoder output maps each input to a
    single quality logit; training compares a positive and a negative
    hypothesis pairwise.
    """

    def __init__(self, config):
        super().__init__(config)
        hidden_size = 32  # width of the intermediate layer in the scoring head
        self.config = config
        self.roberta = XLMRobertaModel(config)
        # Scoring head: pooled encoder output -> single quality logit.
        self.dense = nn.Sequential(
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, hidden_size, bias=False),
            nn.Dropout(config.hidden_dropout_prob),
            Smish(),
            nn.Linear(hidden_size, 1, bias=False),
        )

        self.post_init()

    def forward(self, positive_input_ids, positive_attention_mask,
                negative_input_ids, negative_attention_mask, labels, weight=None):
        # Score each hypothesis in the positive batch. The attention mask is
        # passed through so padding tokens are ignored by the encoder.
        positive_inputs = {
            "input_ids": positive_input_ids,
            "attention_mask": positive_attention_mask,
        }
        positive = self.dense(self.roberta(**positive_inputs).pooler_output).squeeze(-1)

        # Score each hypothesis in the negative batch.
        negative_inputs = {
            "input_ids": negative_input_ids,
            "attention_mask": negative_attention_mask,
        }
        negative = self.dense(self.roberta(**negative_inputs).pooler_output).squeeze(-1)

        # Pairwise objective: BCE on the score difference against the pair label.
        if weight is None:
            bce = nn.BCEWithLogitsLoss()
        else:
            # Rescale per-sample weights so they average to 1 over the batch.
            bs = len(positive)
            weights = (weight.float() * bs) / weight.sum()
            bce = nn.BCEWithLogitsLoss(weight=weights)
        loss = bce(positive - negative, labels.float())
        return SequenceClassifierOutput(
            loss=loss,
            logits=positive.sigmoid() - negative.sigmoid(),
        )

    def score(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,  # accepted for API compatibility, not used
        output_attentions=None,
        output_hidden_states=None,
    ):
        # Inference helper: return a quality score in (0, 1) per input sequence.
        h = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        ).pooler_output

        return self.dense(h).sigmoid().squeeze(-1)
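

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module): load a
# fine-tuned NoRefER checkpoint, tokenize a few ASR hypotheses, and rank them
# by the referenceless score returned by `score`. The checkpoint path is a
# placeholder assumption (point it at the actual trained weights), and the
# example sentences are made up for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from transformers import AutoTokenizer

    checkpoint = "path/to/norefer-checkpoint"  # placeholder, not a real model id
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = NoRefER.from_pretrained(checkpoint)
    model.eval()

    hypotheses = [
        "the meeting starts at nine tomorrow morning",
        "the meeting start at nine to morrow morning",
    ]
    batch = tokenizer(hypotheses, padding=True, truncation=True, return_tensors="pt")

    with torch.no_grad():
        scores = model.score(batch["input_ids"], attention_mask=batch["attention_mask"])

    # Higher score means the model judges the hypothesis better, without a reference.
    for text, s in sorted(zip(hypotheses, scores.tolist()), key=lambda p: -p[1]):
        print(f"{s:.4f}  {text}")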