from typing import Literal

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel


class ProbeConfig(PretrainedConfig):
    """Configuration for a probe head that maps embeddings to a single probability."""

    model_type = "linear_probe"

    def __init__(
        self,
        embedding_dim: int = 768,
        dropout: float = 0.0,
        layer_index: int = -1,
        probe_type: Literal["linear", "nonlinear"] = "linear",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim  # width of the input embeddings
        self.dropout = dropout  # applied before the head; 0.0 disables it
        self.layer_index = layer_index  # which hidden layer to probe (-1: last)
        self.probe_type = probe_type


class ProbeModel(PreTrainedModel):
    config_class = ProbeConfig

    def __init__(self, config: ProbeConfig):
        super().__init__(config)
        self.dropout = nn.Dropout(config.dropout) if config.dropout > 0 else None
        if config.probe_type == "linear":
            self.probe = nn.Linear(config.embedding_dim, 1)
        else:
            # The original listing declared probe_type but ignored it; this is
            # an assumed minimal nonlinear head: one hidden layer with a ReLU.
            self.probe = nn.Sequential(
                nn.Linear(config.embedding_dim, config.embedding_dim),
                nn.ReLU(),
                nn.Linear(config.embedding_dim, 1),
            )

    def forward(
        self,
        embeddings: torch.Tensor,
        **kwargs,  # extra arguments (e.g. labels) are accepted and ignored
    ) -> torch.Tensor:
        if self.dropout is not None:
            embeddings = self.dropout(embeddings)
        logits = self.probe(embeddings)
        # Squash to probabilities in [0, 1]; for training, returning the raw
        # logits and using BCEWithLogitsLoss is the numerically stabler choice.
        return torch.sigmoid(logits)
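

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original listing): register the
# classes with the Auto* API, run the probe on a dummy batch of embeddings,
# and round-trip it through save_pretrained/from_pretrained. The directory
# name "probe_checkpoint" is a placeholder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModel

    # Let AutoConfig/AutoModel resolve model_type "linear_probe" to our classes.
    AutoConfig.register("linear_probe", ProbeConfig)
    AutoModel.register(ProbeConfig, ProbeModel)

    config = ProbeConfig(embedding_dim=768, dropout=0.1, probe_type="linear")
    model = ProbeModel(config).eval()

    # A dummy batch of 4 embeddings standing in for hidden states taken from
    # the layer selected by config.layer_index.
    embeddings = torch.randn(4, config.embedding_dim)
    with torch.no_grad():
        probs = model(embeddings)
    print(probs.shape)  # torch.Size([4, 1]); values lie in [0, 1]

    model.save_pretrained("probe_checkpoint")
    reloaded = ProbeModel.from_pretrained("probe_checkpoint")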