| from torch import nn |
| from transformers.modeling_utils import PreTrainedModel |
|
|
|
|
class BasePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    # Class-level flag advertising gradient-checkpointing support.
    supports_gradient_checkpointing = True

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize the weights of a single module in place.

        * ``nn.Conv2d`` / ``nn.Embedding`` / ``nn.Linear``: weights are drawn
          from a normal distribution (mean 0.0, std 0.02); a bias, when
          present, is zeroed. For an ``nn.Embedding`` with a ``padding_idx``,
          the padding row is reset to zero after the random draw.
        * ``nn.LayerNorm``: bias is zeroed and weight is set to 1.0, each only
          if the parameter exists.
        * Any other module type is left untouched.

        Args:
            module: The module whose parameters should be initialized.

        Raises:
            ValueError: If ``module`` is an ``nn.Parameter`` rather than a
                module.
        """
        if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=0.02)

            # nn.Embedding has no ``bias`` attribute, and Conv2d/Linear may be
            # built with ``bias=False``; both cases resolve to None here.
            bias = getattr(module, "bias", None)
            if bias is not None:
                bias.data.zero_()

            # The padding row gets no gradient, so random noise written by
            # normal_() above would persist forever; restore it to zero.
            if isinstance(module, nn.Embedding) and module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # weight and/or bias are None when the layer is created without
            # affine parameters (e.g. ``elementwise_affine=False``).
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Parameter):
            raise ValueError(
                "_init_weights expects an nn.Module but received a bare "
                "nn.Parameter; initialize parameters through their owning module."
            )
|
|