import torch import torch.nn as nn import torch.nn.functional as F import os import json import safetensors.torch class AdderLayer(nn.Module): def __init__(self, i_dim: int, o_dim: int): super().__init__() self.linear = nn.Linear(i_dim, o_dim) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.linear(x) x = F.relu(x) x_shape = x.shape x = x.reshape(x_shape[:-1] + (-1, 2)) x = x / x.norm(dim=-1, keepdim=True) x = x.reshape(x_shape) return x class AdderEncoder(nn.Module): def __init__(self, bits: int = 32): super().__init__() self.bits = bits def forward(self, x: torch.LongTensor) -> torch.Tensor: o = torch.zeros(x.shape + (2 * self.bits,), device=x.device) for i in range(self.bits): v = (x >> i) & 1 o[..., i * 2] = 1 - v o[..., i * 2 + 1] = v return o def extra_repr(self) -> str: return f'bits={self.bits}' class AdderDecoder(nn.Module): def __init__(self, bits: int = 32): super().__init__() self.bits = bits def forward(self, x: torch.Tensor) -> torch.LongTensor: o = torch.zeros(x.shape[:-1], device=x.device, dtype=torch.long) for i in range(self.bits): v = x[..., i * 2 + 1] > x[..., i * 2] o = o | (v << i) return o class Adder(nn.Module): def __init__(self, layer_dims: list[int], bits: int = 32): super().__init__() self.layer_dims = layer_dims self.bits = bits self.encoder = AdderEncoder(bits) self.encoder_c = AdderEncoder(1) self.decoder = AdderDecoder(bits) self.decoder_c = AdderDecoder(1) self.layers = nn.ModuleList() for i_dim, o_dim in zip(layer_dims[:-1], layer_dims[1:]): self.layers.append(AdderLayer(i_dim, o_dim)) @property def config(self) -> dict: return { 'layer_dims': self.layer_dims, 'bits': self.bits } @classmethod def from_pretrained(cls, filepath: str): with open(os.path.join(filepath, 'config.json'), 'r') as f: config = json.load(f) model = cls(**config) state_dict = safetensors.torch.load_file( os.path.join(filepath, 'model.safetensors')) model.load_state_dict(state_dict) model.requires_grad_(False) return model def forward(self, a: torch.LongTensor, b: torch.LongTensor) -> torch.LongTensor: assert (0 <= a < 2 ** self.bits).all() assert (0 <= b < 2 ** self.bits).all() a = self.encoder(a) b = self.encoder(b) c = self.encoder_c(torch.tensor(0, device=a.device)) x = torch.cat([a, b, c], dim=-1) for m in self.layers: x = m(x) x, c = x.split([2 * self.bits, 2], dim=-1) x = self.decoder(x) c = self.decoder_c(c) if (c > 0).any(): raise ValueError("Carry out is not 0") return x def main(): model = Adder.from_pretrained(os.path.dirname(__file__)) print(model(torch.tensor(3123), torch.tensor(5929))) if __name__ == "__main__": main()