File size: 3,223 Bytes
a3e24f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109

import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import json
import safetensors.torch


class AdderLayer(nn.Module):
    """Linear projection, ReLU, then pairwise L2 normalisation.

    The last dimension of the activations is viewed as consecutive pairs
    ``(x[2k], x[2k + 1])`` and every pair is rescaled to unit Euclidean
    norm.

    Args:
        i_dim: Size of the last dimension of the input.
        o_dim: Size of the last dimension of the output. It has to be even,
            because the output is regrouped into pairs in ``forward``.
    """

    def __init__(self, i_dim: int, o_dim: int):
        super().__init__()
        self.linear = nn.Linear(i_dim, o_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer.

        Args:
            x: Tensor whose last dimension has size ``i_dim``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            size ``o_dim``. Each consecutive pair has unit L2 norm, except
            pairs that ReLU zeroed out completely, which stay all-zero.
        """
        x = F.relu(self.linear(x))
        x_shape = x.shape
        pairs = x.reshape(x_shape[:-1] + (-1, 2))
        norm = pairs.norm(dim=-1, keepdim=True)
        # A pair with both entries clipped to 0 by ReLU has norm 0; dividing
        # by it would give NaN and poison every following layer. Divide such
        # pairs by 1 instead (they stay 0); all other pairs are unchanged.
        safe_norm = torch.where(norm > 0, norm, torch.ones_like(norm))
        return (pairs / safe_norm).reshape(x_shape)


class AdderEncoder(nn.Module):
    """Expand integers into a two-slot-per-bit indicator encoding.

    For bit position ``k`` (least significant bit first), output slot ``2k``
    is 1 when the bit is clear and slot ``2k + 1`` is 1 when the bit is set;
    the other slot of the pair is 0.

    Args:
        bits: Number of low-order bits to encode.
    """

    def __init__(self, bits: int = 32):
        super().__init__()
        self.bits = bits

    def extra_repr(self) -> str:
        """Show the configured bit width in the module's repr."""
        return f'bits={self.bits}'

    def forward(self, x: torch.LongTensor) -> torch.Tensor:
        """Encode ``x``.

        Args:
            x: Integer tensor of any shape.

        Returns:
            Tensor of shape ``x.shape + (2 * bits,)`` on ``x``'s device, in
            torch's default floating-point dtype.
        """
        # Shift every element by every bit position at once via broadcasting.
        positions = torch.arange(self.bits, device=x.device, dtype=x.dtype)
        bit_values = (x.unsqueeze(-1) >> positions) & 1

        encoded = torch.zeros(x.shape + (2 * self.bits,), device=x.device)
        encoded[..., 0::2] = 1 - bit_values  # even slots: "bit is 0"
        encoded[..., 1::2] = bit_values      # odd slots: "bit is 1"
        return encoded


class AdderDecoder(nn.Module):
    """Collapse a two-slot-per-bit encoding back into integers.

    Bit ``k`` of the result (least significant bit first) is set when slot
    ``2k + 1`` is strictly greater than slot ``2k``; ties decode to 0.

    Args:
        bits: Number of bits to decode; ``forward`` reads the first
            ``2 * bits`` entries of the last dimension.
    """

    def __init__(self, bits: int = 32):
        super().__init__()
        self.bits = bits

    def forward(self, x: torch.Tensor) -> torch.LongTensor:
        """Decode ``x``.

        Args:
            x: Tensor whose last dimension holds at least ``2 * bits``
                entries laid out as (slot-for-0, slot-for-1) pairs.

        Returns:
            ``torch.long`` tensor of shape ``x.shape[:-1]`` on ``x``'s device.
        """
        o = torch.zeros(x.shape[:-1], device=x.device, dtype=torch.long)
        for i in range(self.bits):
            v = x[..., i * 2 + 1] > x[..., i * 2]
            # Cast the boolean mask to long before shifting so the result
            # does not depend on implicit bool -> int promotion rules.
            o = o | (v.long() << i)
        return o

    def extra_repr(self) -> str:
        """Show the configured bit width in the module's repr."""
        return f'bits={self.bits}'


class Adder(nn.Module):
    """Stack of ``AdderLayer`` modules wrapped in bit encoders and decoders.

    ``forward`` encodes two integer operands plus a zero carry-in, runs the
    encoding through the layers, and decodes ``bits`` result bits plus one
    carry-out bit. Judging by the naming, the pretrained weights are meant
    to make the result equal ``a + b``; that depends entirely on the loaded
    weights.

    Args:
        layer_dims: Widths of the successive activations; each consecutive
            pair of entries defines one ``AdderLayer``. ``forward`` feeds
            ``4 * bits + 2`` features into the first layer and splits the
            last layer's output into ``2 * bits`` and ``2`` features, so the
            first and last entries have to match those sizes.
        bits: Bit width of the operands and of the result.
    """

    def __init__(self, layer_dims: list[int], bits: int = 32):
        super().__init__()
        self.layer_dims = layer_dims
        self.bits = bits
        self.encoder = AdderEncoder(bits)
        self.encoder_c = AdderEncoder(1)
        self.decoder = AdderDecoder(bits)
        self.decoder_c = AdderDecoder(1)
        self.layers = nn.ModuleList()
        for i_dim, o_dim in zip(layer_dims[:-1], layer_dims[1:]):
            self.layers.append(AdderLayer(i_dim, o_dim))

    @property
    def config(self) -> dict:
        """Constructor keyword arguments that recreate this architecture."""
        return {
            # Copy so callers cannot mutate the model's own list through
            # the returned dict.
            'layer_dims': list(self.layer_dims),
            'bits': self.bits
        }

    @classmethod
    def from_pretrained(cls, filepath: str):
        """Build a model from a directory and freeze its parameters.

        Args:
            filepath: Directory containing ``config.json`` (the constructor
                keyword arguments) and ``model.safetensors`` (the state
                dict).

        Returns:
            The loaded model with ``requires_grad`` disabled on every
            parameter.
        """
        with open(os.path.join(filepath, 'config.json'), 'r',
                  encoding='utf-8') as f:
            config = json.load(f)
        model = cls(**config)
        state_dict = safetensors.torch.load_file(
            os.path.join(filepath, 'model.safetensors'))
        model.load_state_dict(state_dict)
        model.requires_grad_(False)
        return model

    def _check_operand(self, name: str, value: torch.Tensor) -> None:
        """Raise ``ValueError`` unless every element is in ``[0, 2**bits)``."""
        # Combine the two bounds element-wise. A chained comparison
        # (``0 <= value < n``) would call bool() on a tensor and fail for
        # anything but single-element inputs.
        in_range = (value >= 0) & (value < 2 ** self.bits)
        if not bool(in_range.all()):
            raise ValueError(
                f"{name} must satisfy 0 <= {name} < 2**{self.bits}")

    def forward(self, a: torch.LongTensor, b: torch.LongTensor) -> torch.LongTensor:
        """Run the network on two integer operands.

        Args:
            a: Integer tensor with values in ``[0, 2**bits)``.
            b: Integer tensor with values in ``[0, 2**bits)``; its shape
                must be broadcastable with ``a``.

        Returns:
            ``torch.long`` tensor of the broadcast shape of ``a`` and ``b``
            holding the decoded result bits.

        Raises:
            ValueError: If an operand is out of range, or if the decoded
                carry-out bit is set for any element.
        """
        self._check_operand('a', a)
        self._check_operand('b', b)
        a, b = torch.broadcast_tensors(a, b)
        # The carry-in is always 0; give it the operands' shape so the
        # concatenation below also works for batched inputs.
        carry_in = torch.zeros_like(a)
        x = torch.cat(
            [self.encoder(a), self.encoder(b), self.encoder_c(carry_in)],
            dim=-1)
        for m in self.layers:
            x = m(x)
        x, c = x.split([2 * self.bits, 2], dim=-1)
        x = self.decoder(x)
        c = self.decoder_c(c)
        if (c > 0).any():
            raise ValueError("Carry out is not 0")
        return x


def main():
    """Load the weights stored next to this file and print one demo result."""
    model_dir = os.path.dirname(__file__)
    adder = Adder.from_pretrained(model_dir)
    lhs = torch.tensor(3123)
    rhs = torch.tensor(5929)
    print(adder(lhs, rhs))


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()