| import torch |
| from torch import nn, Tensor |
| from typing import Optional |
| from transformers import DebertaV2PreTrainedModel, DebertaV2Model |
| from .configuration_deberta_multi import MultiHeadDebertaV2Config |
|
|
class MultiHeadDebertaForSequenceClassificationModel(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder with multiple independent 4-class classification heads.

    Each head is a linear classifier applied to the first-token ([CLS])
    representation of the encoder output, yielding logits of shape
    ``(batch, num_heads, 4)``.
    """

    config_class = MultiHeadDebertaV2Config

    def __init__(self, config):
        """Build the shared encoder and ``config.num_heads`` linear heads.

        Args:
            config: A :class:`MultiHeadDebertaV2Config`; must provide
                ``hidden_size``, ``hidden_dropout_prob``, and ``num_heads``.
        """
        super().__init__(config)
        self.deberta = DebertaV2Model(config)
        # One independent 4-way classifier per head.
        # NOTE(review): the 4 output classes are hard-coded; consider a
        # config field (e.g. num_labels_per_head) if label counts may vary.
        self.heads = nn.ModuleList(
            [nn.Linear(config.hidden_size, 4) for _ in range(config.num_heads)]
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Weight initialization as defined by DebertaV2PreTrainedModel.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
    ):
        """Encode the input and run every classification head.

        Args:
            input_ids: Token ids, shape ``(batch, seq_len)``.
            attention_mask: Padding mask, shape ``(batch, seq_len)``.

        Returns:
            The encoder's model-output object, with an additional
            ``logits`` attribute of shape ``(batch, num_heads, 4)``
            attached. (Previously annotated ``-> Tensor``, which was
            incorrect — the full output object is returned.)
        """
        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        # Last hidden states: (batch, seq_len, hidden_size).
        sequence_output = outputs[0]
        # Dropout is applied inside the loop on purpose: each head sees an
        # independent dropout mask during training.
        logits_list = [
            head(self.dropout(sequence_output[:, 0, :])) for head in self.heads
        ]
        logits = torch.stack(logits_list, dim=1)  # (batch, num_heads, 4)
        # Attach logits to the output object for downstream convenience;
        # ModelOutput.__setattr__ also registers it as a dict entry.
        outputs.logits = logits
        return outputs