| |
| |
| |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import torch.nn as nn |
|
|
# Single-character glyph for each piece code, as rendered by Board.__str__:
# lowercase marks a man, uppercase a king, '.' an empty dark square.
PIECE_SYMBOLS = {
    0: '.',  # empty
    1: 'w',  # white man
    2: 'W',  # white king
    3: 'b',  # black man
    4: 'B',  # black king
}
|
|
def _dark(sq):
    """Return True when row-major square index ``sq`` is a dark (playable) square."""
    row, col = divmod(sq, 8)
    return (row + col) % 2 == 1
def _promotes(sq, c):
    """Return True if ``sq`` lies on the far row for side ``c``.

    Side 1 reaches it on row 0, side -1 on row 7; any other ``c`` never does.
    """
    row = sq // 8
    if c == 1:
        return row == 0
    if c == -1:
        return row == 7
    return False
|
|
class Move:
    """One draughts move: origin square, destination square, captured squares, king flag."""

    __slots__ = ('from_sq', 'to_sq', 'captures', 'is_king_move')

    def __init__(self, f, t, cap=None, king=False, *, is_king_move=None):
        """Create a move from ``f`` to ``t``.

        Args:
            f: origin square index.
            t: destination square index.
            cap: iterable of captured square indices (copied; defaults to empty).
            king: True when the moving piece is a king.
            is_king_move: keyword alias for ``king``; when given it takes
                precedence. Board's move generators pass this name, which the
                previous signature rejected with a TypeError.
        """
        self.from_sq = f
        self.to_sq = t
        # Copy so the move never shares a mutable list with its caller.
        self.captures = list(cap) if cap is not None else []
        self.is_king_move = king if is_king_move is None else is_king_move

    def __repr__(self):
        return f'Move({self.from_sq}→{self.to_sq}, cap={self.captures})'
|
|
class Board:
    """8x8 draughts position: 64 row-major squares plus the side to move.

    Piece codes: 0 empty, 1 white man, 2 white king, 3 black man, 4 black king.
    ``turn`` is 1 when the side owning codes 1/2 moves and -1 for codes 3/4.
    Only dark squares ((row + col) odd) are used.
    """

    # Diagonal steps on the 64-square index. Any step that would wrap across a
    # board edge lands on a light square, so the _dark() test doubles as an
    # edge guard.
    _DIRECTIONS = (-9, -7, 7, 9)

    def __init__(self):
        self.pieces = np.zeros(64, dtype=np.int8)
        self.turn = 1
        self.reset()

    def reset(self):
        """Clear the board, place 12 men per side in the start position, white to move."""
        self.pieces[:] = 0
        for sq in [1, 3, 5, 7, 8, 10, 12, 14, 17, 19, 21, 23]:
            self.pieces[sq] = 3
        for sq in [40, 42, 44, 46, 49, 51, 53, 55, 56, 58, 60, 62]:
            self.pieces[sq] = 1
        self.turn = 1

    def copy(self):
        """Return an independent copy (pieces array is duplicated)."""
        b = Board()
        b.pieces = self.pieces.copy()
        b.turn = self.turn
        return b

    def _dark(self, sq):
        """True if square index ``sq`` is a dark square."""
        return (sq // 8 + sq % 8) % 2 == 1

    @staticmethod
    def _own_pieces(color):
        """Piece codes belonging to side ``color`` (1 -> (1, 2), otherwise (3, 4))."""
        return (1, 2) if color == 1 else (3, 4)

    @staticmethod
    def _enemy_pieces(color):
        """Piece codes belonging to the opponent of side ``color``."""
        return (3, 4) if color == 1 else (1, 2)

    def _man_captures(self, sq, color, captured=None):
        """List capture sequences for a man of side ``color`` standing on ``sq``.

        Men capture in all four diagonal directions by jumping an adjacent
        enemy piece onto the empty square right behind it. Captured pieces stay
        on the board during the search and are recorded in ``captured``, so a
        piece cannot be jumped twice.

        Returns:
            List of ``(landing_square, captured_set)`` for every landing square
            reached, including intermediate landings of a multi-jump.
            NOTE(review): intermediate landings are offered as complete moves,
            so a sequence may stop early; confirm this is the intended rule.
        """
        captured = captured or set()
        enemy = self._enemy_pieces(color)
        res = []
        for d in self._DIRECTIONS:
            mid = sq + d
            dst = mid + d
            if not (0 <= mid < 64 and 0 <= dst < 64):
                continue
            if not (self._dark(mid) and self._dark(dst)):
                continue
            # The old test `in (3+color, 4+color)` produced (4, 5) / (2, 3):
            # white could never take a man and black could take its own men.
            if mid in captured or self.pieces[mid] not in enemy:
                continue
            if self.pieces[dst] != 0:
                continue
            new_cap = captured | {mid}
            res.append((dst, new_cap))
            res.extend(self._man_captures(dst, color, new_cap))
        return res

    def _king_captures(self, sq, color, captured=None):
        """List capture sequences for a flying king of side ``color`` on ``sq``.

        Along each diagonal the king slides over empty squares, jumps exactly
        one enemy piece not already captured, and may land on any empty square
        beyond it before a second piece blocks the line.

        Returns:
            List of ``(landing_square, captured_set)``, including intermediate
            landings (see the NOTE on ``_man_captures``).
        """
        captured = captured or set()
        enemy = self._enemy_pieces(color)
        res = []
        for d in self._DIRECTIONS:
            jumped = None
            cur = sq + d
            while 0 <= cur < 64 and self._dark(cur):
                occupant = self.pieces[cur]
                if occupant != 0:
                    # A second piece, an already-captured piece, or a friendly
                    # piece ends the ray.
                    if jumped is not None or cur in captured or occupant not in enemy:
                        break
                    jumped = cur
                elif jumped is not None and cur not in captured:
                    new_cap = captured | {jumped}
                    res.append((cur, new_cap))
                    res.extend(self._king_captures(cur, color, new_cap))
                cur += d
        return res

    def _captures(self):
        """All capturing moves for the side to move."""
        color = self.turn
        own = self._own_pieces(color)
        moves = []
        for sq in range(64):
            p = self.pieces[sq]
            if p not in own:
                continue
            is_man = p in (1, 3)
            # Lift the piece while searching: its origin square is vacant for
            # the rest of the capture sequence.
            self.pieces[sq] = 0
            try:
                if is_man:
                    caps = self._man_captures(sq, color)
                else:
                    caps = self._king_captures(sq, color)
            finally:
                self.pieces[sq] = p
            for to, cap in caps:
                moves.append(Move(sq, to, sorted(cap), king=not is_man))
        return moves

    def _quiet(self):
        """All non-capturing moves for the side to move.

        Men step one square diagonally forward (toward row 0 for side 1, row 7
        for side -1); kings slide any distance along empty diagonals.
        """
        color = self.turn
        own = self._own_pieces(color)
        forward = (-9, -7) if color == 1 else (9, 7)
        moves = []
        for sq in range(64):
            p = self.pieces[sq]
            if p not in own:
                continue
            if p in (1, 3):
                for d in forward:
                    dst = sq + d
                    if 0 <= dst < 64 and self._dark(dst) and self.pieces[dst] == 0:
                        moves.append(Move(sq, dst))
            else:
                for d in self._DIRECTIONS:
                    dst = sq + d
                    while 0 <= dst < 64 and self._dark(dst) and self.pieces[dst] == 0:
                        moves.append(Move(sq, dst, king=True))
                        dst += d
        return moves

    def legal_moves(self):
        """Captures if any exist (capturing is mandatory), otherwise quiet moves."""
        caps = self._captures()
        return caps if caps else self._quiet()

    def make_move(self, move):
        """Apply ``move`` in place, remove captured pieces, and pass the turn.

        A man (code 1 or 3) reaching its far row (row 0 for code 1, row 7 for
        code 3) becomes the matching king. Promotion is keyed on the piece
        itself, not ``move.is_king_move``, so a king can never be incremented
        into another side's piece code.
        """
        p = self.pieces[move.from_sq]
        self.pieces[move.from_sq] = 0
        row = move.to_sq // 8
        if (p == 1 and row == 0) or (p == 3 and row == 7):
            p += 1
        self.pieces[move.to_sq] = p
        for cap_sq in move.captures:
            self.pieces[cap_sq] = 0
        self.turn = -self.turn

    def is_terminal(self):
        """Return ``(True, winner)`` when the side to move has no legal move, else ``(False, 0)``.

        The winner is the opponent of the side to move (``-self.turn``).
        """
        legal = self.legal_moves()
        if not legal:
            return True, -self.turn
        return False, 0

    def __str__(self):
        """Render 8 rows of glyphs; light squares are shown as a blank."""
        rows = []
        for r in range(8):
            row = [PIECE_SYMBOLS[int(self.pieces[r * 8 + c])] if self._dark(r * 8 + c) else ' '
                   for c in range(8)]
            rows.append(" ".join(row))
        return "\n".join(rows)
|
|
| |
| |
class ResidualBlock(nn.Module):
    """Conv-BN-ReLU-Conv-BN with an identity skip, followed by a final ReLU.

    Both convolutions are 3x3 with padding 1 and ``ch`` channels in and out,
    so the output has the same shape as the input.
    """

    def __init__(self, ch=64):
        super().__init__()
        # Registration order is kept so state_dict keys and RNG use are unchanged.
        self.conv1 = nn.Conv2d(ch, ch, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(ch)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(ch)

    def forward(self, x):
        residual = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        return F.relu(x + residual)
|
|
class ChekaNet(nn.Module):
    """Residual tower with a per-square policy head and a tanh value head.

    Input: (batch, 5, 8, 8) planes. Output: ``(policy_logits, value)`` with
    shapes (batch, 64) and (batch, 1); the value lies in [-1, 1].
    """

    def __init__(self, blocks=3, channels=64):
        super().__init__()
        self.conv_in = nn.Conv2d(5, channels, 3, padding=1, bias=False)
        self.bn_in = nn.BatchNorm2d(channels)
        self.residuals = nn.Sequential(*(ResidualBlock(channels) for _ in range(blocks)))
        self.policy_conv = nn.Conv2d(channels, 1, 1)
        self.value_conv = nn.Conv2d(channels, 1, 1)
        self.value_fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        trunk = self.residuals(F.relu(self.bn_in(self.conv_in(x))))
        # (B, 1, 8, 8) -> (B, 64), one logit per square in row-major order.
        policy_logits = self.policy_conv(trunk).flatten(start_dim=1)
        # value_fc starts with Flatten, so the singleton channel needs no squeeze.
        value = self.value_fc(self.value_conv(trunk))
        return policy_logits, value
|
|
def board_to_tensor(b):
    """Encode a position as a float32 tensor of shape (5, 8, 8).

    Planes 0-3 are one-hot masks of piece codes 1-4 over the row-major 8x8
    grid. Plane 4 is all ones when ``b.turn == 1`` and all zeros otherwise.
    """
    grid = np.asarray(b.pieces)[:64].reshape(8, 8)
    planes = np.zeros((5, 8, 8), np.float32)
    codes = np.arange(1, 5).reshape(4, 1, 1)
    planes[:4] = grid == codes
    planes[4] = 1.0 if b.turn == 1 else 0.0
    return torch.from_numpy(planes)
| |
| import math, random |
|
|
class MCTSNode:
    """MCTS tree node that owns a private copy of its position.

    Attributes: ``P`` prior, ``N`` visit count, ``W`` summed backed-up value,
    ``children`` mapping move -> child node, ``parent`` (None at the root).
    """

    def __init__(self, board, parent=None, prior=0):
        self.board = board.copy()
        self.parent = parent
        self.P = prior
        self.N = 0
        self.W = 0.0
        self.children = {}

    def Q(self):
        """Mean backed-up value; the 1e-8 keeps an unvisited node at 0."""
        return self.W / (self.N + 1e-8)

    def U(self, c_puct=1.0):
        """PUCT exploration bonus; requires a parent (never called on the root)."""
        parent_visits = self.parent.N
        return c_puct * self.P * math.sqrt(parent_visits) / (1 + self.N)

    def is_leaf(self):
        """True until children have been added."""
        return not self.children
|
|
def expand_leaf(node, net, device):
    """Expand ``node`` with one child per legal move and return its leaf value.

    ``backup`` adds the returned value to this node's W and flips its sign at
    each ancestor. ``select_move`` lets the parent pick the child with the
    highest Q. So the value must be expressed for the player who just moved
    into ``node``.

    Args:
        node: leaf MCTSNode; its ``board`` must provide legal_moves/turn.
        net: model returning ``(policy_logits, value)`` for a batch of one
            ``board_to_tensor`` encoding.
        device: torch device the input tensor is moved to.

    Returns:
        1.0 when the side to move has no legal moves (that side loses, as in
        Board.is_terminal). Otherwise the network's value estimate.
    """
    board = node.board
    legal = board.legal_moves()
    if not legal:
        # The side to move is stuck, so the player who moved here won. The old
        # color-dependent +/-1 was in white's absolute frame; after the
        # alternating-sign backup it scored black's winning moves as losses.
        return 1.0
    tensor = board_to_tensor(board).unsqueeze(0).to(device)
    with torch.no_grad():
        logits, v = net(tensor)
    logits = logits[0].cpu().numpy()
    # NOTE(review): v is used as-is, i.e. assumed to already be from the
    # perspective of the player who just moved. If the net is trained on
    # side-to-move values this should be -v; confirm against the training code.
    value = v.item()
    # One prior per legal move (softmax over each move's destination logit),
    # so priors sum to 1 even when several moves share a destination square.
    move_logits = np.array([logits[m.to_sq] for m in legal], dtype=np.float64)
    priors = np.exp(move_logits - move_logits.max())
    priors /= priors.sum()
    for m, prior in zip(legal, priors):
        # MCTSNode copies the board itself, so no extra copy is needed here.
        child = MCTSNode(board, parent=node, prior=float(prior))
        child.board.make_move(m)
        node.children[m] = child
    return value
|
|
def backup(node, v):
    """Propagate leaf value ``v`` from ``node`` up to the root.

    Every node on the path gains one visit and has the current value added to
    W. The value's sign flips at each step up, alternating between players.
    """
    current, value = node, v
    while current is not None:
        current.N += 1
        current.W += value
        current, value = current.parent, -value
|
|
def select_move(board, net, device, sims=400, c_puct=1.0, temp=0.0):
    """Run ``sims`` PUCT simulations from ``board`` and choose a root move.

    Args:
        board: position to search; copied into the root node.
        net, device: forwarded to ``expand_leaf``.
        sims: number of simulations.
        c_puct: exploration constant passed to ``MCTSNode.U``.
        temp: 0 picks the most-visited root move (first one on ties);
            ``temp > 0`` samples with probability proportional to N ** (1 / temp).

    Returns:
        ``(move, root)``: the chosen move and the searched root node.

    Raises:
        ValueError: if the root has no children (no legal moves, or sims < 1).
    """
    root = MCTSNode(board)
    for _ in range(sims):
        node = root
        while not node.is_leaf():
            node = max(node.children.values(),
                       key=lambda child: child.Q() + child.U(c_puct))
        value = expand_leaf(node, net, device)
        backup(node, value)
    if not root.children:
        raise ValueError('select_move: root has no children (no legal moves or sims < 1)')
    moves = list(root.children)
    visits = np.array([root.children[m].N for m in moves], dtype=np.float64)
    if temp == 0:
        # argmax returns the first maximum, matching the old max() tie-break.
        move = moves[int(np.argmax(visits))]
    else:
        top = visits.max()
        if top > 0:
            # Normalize before exponentiating: same distribution as
            # N ** (1/temp), but no overflow to inf/nan for small temp.
            weights = (visits / top) ** (1.0 / temp)
        else:
            # No child visited yet (e.g. sims == 1): fall back to uniform.
            weights = np.ones_like(visits)
        weights /= weights.sum()
        move = random.choices(moves, weights=weights.tolist())[0]
    return move, root
|
|