|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import os |
| import json |
| import safetensors.torch |
|
|
|
|
class AdderLayer(nn.Module):
    """Linear projection followed by ReLU and pairwise L2 normalisation.

    The activations along the last dimension are regrouped into consecutive
    pairs and each pair is scaled to unit Euclidean length.

    Args:
        i_dim: Size of the input feature dimension.
        o_dim: Size of the output feature dimension. It has to be even,
            because ``forward`` reshapes the output into pairs.
    """

    def __init__(self, i_dim: int, o_dim: int):
        super().__init__()
        self.linear = nn.Linear(i_dim, o_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer.

        Args:
            x: Tensor whose last dimension has size ``i_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``o_dim``. A pair whose two activations are
            both zero after the ReLU stays ``(0, 0)``.
        """
        x = F.relu(self.linear(x))
        x_shape = x.shape
        x = x.reshape(x_shape[:-1] + (-1, 2))
        # F.normalize divides by max(norm, eps), so a pair that ReLU zeroed
        # out entirely yields zeros instead of the NaN produced by 0 / 0.
        x = F.normalize(x, p=2.0, dim=-1)
        return x.reshape(x_shape)
|
|
|
|
class AdderEncoder(nn.Module):
    """Encode integers as pairs of complementary bit indicators.

    Bit ``i`` of each input value (least-significant bit first) occupies
    output positions ``2 * i`` and ``2 * i + 1``: the pair is ``(1, 0)`` when
    the bit is clear and ``(0, 1)`` when it is set.

    Args:
        bits: Number of low-order bits to encode.
    """

    def __init__(self, bits: int = 32):
        super().__init__()
        self.bits = bits

    def forward(self, x: torch.LongTensor) -> torch.Tensor:
        """Return the encoding of ``x``.

        Args:
            x: Integer tensor of any shape.

        Returns:
            Tensor of shape ``x.shape + (2 * bits,)`` in the default floating
            dtype, on the same device as ``x``.
        """
        encoded = torch.zeros((*x.shape, 2 * self.bits), device=x.device)
        # Extract all bits at once: a trailing axis on x broadcasts against
        # the vector of shift amounts 0 .. bits - 1.
        shifts = torch.arange(self.bits, device=x.device)
        bit_values = (x.unsqueeze(-1) >> shifts) & 1
        # Even slots hold the "bit is 0" indicator, odd slots "bit is 1".
        encoded[..., 0::2] = 1 - bit_values
        encoded[..., 1::2] = bit_values
        return encoded

    def extra_repr(self) -> str:
        """Show the configured bit width in the module repr."""
        return f'bits={self.bits}'
|
|
|
|
class AdderDecoder(nn.Module):
    """Decode pairs of activations into integers.

    Positions ``2 * i`` and ``2 * i + 1`` of the last dimension form the pair
    for bit ``i`` (least-significant bit first). The bit is set when the
    second element is strictly greater than the first; ties decode to 0.

    Args:
        bits: Number of bit pairs to read from the last dimension.
    """

    def __init__(self, bits: int = 32):
        super().__init__()
        self.bits = bits

    def forward(self, x: torch.Tensor) -> torch.LongTensor:
        """Return the integers represented by ``x``.

        Args:
            x: Tensor whose last dimension holds at least ``2 * bits``
                values; only the first ``2 * bits`` are read.

        Returns:
            ``torch.long`` tensor of shape ``x.shape[:-1]`` on the same
            device as ``x``.
        """
        o = torch.zeros(x.shape[:-1], device=x.device, dtype=torch.long)
        for i in range(self.bits):
            v = x[..., i * 2 + 1] > x[..., i * 2]
            # Cast the boolean mask explicitly so the shift is carried out
            # on 64-bit integers rather than relying on type promotion.
            o = o | (v.to(torch.long) << i)
        return o

    def extra_repr(self) -> str:
        """Show the configured bit width in the module repr."""
        return f'bits={self.bits}'
|
|
|
|
class Adder(nn.Module):
    """Network that maps two encoded integers to one decoded integer.

    Both operands are bit-pair encoded, concatenated with an encoded carry-in
    of zero, passed through a stack of ``AdderLayer`` modules and decoded
    back into an integer plus a one-bit carry-out.

    Args:
        layer_dims: Feature sizes of the layer stack; consecutive entries
            give the input and output size of each ``AdderLayer``. For
            ``forward`` to work, the first entry has to equal
            ``4 * bits + 2`` and the last ``2 * bits + 2``.
        bits: Bit width of the operands and of the result.
    """

    def __init__(self, layer_dims: list[int], bits: int = 32):
        super().__init__()
        self.layer_dims = layer_dims
        self.bits = bits
        self.encoder = AdderEncoder(bits)
        self.encoder_c = AdderEncoder(1)
        self.decoder = AdderDecoder(bits)
        self.decoder_c = AdderDecoder(1)
        self.layers = nn.ModuleList()
        for i_dim, o_dim in zip(layer_dims[:-1], layer_dims[1:]):
            self.layers.append(AdderLayer(i_dim, o_dim))

    @property
    def config(self) -> dict:
        """Constructor arguments as a JSON-serialisable dict."""
        return {
            'layer_dims': self.layer_dims,
            'bits': self.bits
        }

    @classmethod
    def from_pretrained(cls, filepath: str):
        """Build a model from a directory holding config and weights.

        Args:
            filepath: Directory containing ``config.json`` (keyword
                arguments for the constructor) and ``model.safetensors``
                (the state dict).

        Returns:
            The loaded model with gradients disabled on all parameters.
        """
        with open(os.path.join(filepath, 'config.json'), 'r') as f:
            config = json.load(f)
        model = cls(**config)
        state_dict = safetensors.torch.load_file(
            os.path.join(filepath, 'model.safetensors'))
        model.load_state_dict(state_dict)
        model.requires_grad_(False)
        return model

    def forward(self, a: torch.LongTensor, b: torch.LongTensor) -> torch.LongTensor:
        """Run the network on two integer tensors.

        Args:
            a: Integer tensor with values in ``[0, 2 ** bits)``.
            b: Integer tensor with values in ``[0, 2 ** bits)``; it is
                broadcast against ``a``.

        Returns:
            ``torch.long`` tensor with the broadcast shape of ``a`` and
            ``b`` holding the decoded result.

        Raises:
            ValueError: If an operand has a value outside
                ``[0, 2 ** bits)``, or if the decoded carry-out is non-zero
                for any element.
        """
        upper = 2 ** self.bits
        for name, value in (("a", a), ("b", b)):
            # Element-wise range test; a chained comparison would call
            # bool() on a tensor and fail for anything but a single element.
            if ((value < 0) | (value >= upper)).any():
                raise ValueError(
                    f"{name} must be in the range [0, 2 ** {self.bits})")

        a, b = torch.broadcast_tensors(a, b)
        # The carry-in is zero for every element; give it the operands'
        # shape so the concatenation below also works for batched input.
        carry_in = torch.zeros_like(a)

        x = torch.cat(
            [self.encoder(a), self.encoder(b), self.encoder_c(carry_in)],
            dim=-1)
        for m in self.layers:
            x = m(x)
        x, c = x.split([2 * self.bits, 2], dim=-1)
        x = self.decoder(x)
        c = self.decoder_c(c)
        if (c > 0).any():
            raise ValueError("Carry out is not 0")
        return x
|
|
|
|
def main():
    """Load the model stored next to this file and print one example."""
    model_dir = os.path.dirname(__file__)
    adder = Adder.from_pretrained(model_dir)
    lhs = torch.tensor(3123)
    rhs = torch.tensor(5929)
    result = adder(lhs, rhs)
    print(result)


# Run the example only when the file is executed as a script.
if __name__ == "__main__":
    main()