| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | import chess |
| | import os |
| | import chess.engine as eng |
| | import torch.multiprocessing as mp |
| | import random |
| | from pathlib import Path |
| |
|
| | |
# Run-wide settings for the self-play / Stockfish training loop.
CONFIG = {
    # Stockfish UCI binary, launched at module import time.
    "stockfish_path": "/Users/aaronvattay/Downloads/stockfish/stockfish-macos-m1-apple-silicon",
    # Primary value-network checkpoint: loaded at start-up, rewritten after training.
    "model_path": "chessy_model.pth",
    # Snapshot saved before training steps; used when model_path is missing.
    "backup_model_path": "chessy_modelt-1.pth",
    # Apple-silicon GPU when present, otherwise CPU. A bare torch.device("mps")
    # constructs fine but fails on the first .to() where MPS is unavailable.
    "device": torch.device("mps" if torch.backends.mps.is_available() else "cpu"),
    "learning_rate": 1e-4,
    # Games per epoch, split evenly across three modes (remainder not played).
    "num_games": 30,
    "num_epochs": 10,
    # Stockfish thinking time per move (passed to chess.engine.Limit(time=...)).
    "stockfish_time_limit": 1.0,
    # Forwarded to search() as `depth`.
    "search_depth": 1,
    # Number of candidate moves sampled per bot move (a count, not a probability).
    "epsilon": 4,
}

device = CONFIG["device"]
| |
|
def board_to_tensor(board):
    """Encode ``board`` as a ``(1, 64)`` long tensor of piece codes.

    Position ``i`` holds the code for python-chess square ``i``:
    0 for an empty square, 1-6 for White P, N, B, R, Q, K and
    7-12 for Black p, n, b, r, q, k.
    """
    codes = {symbol: code for code, symbol in enumerate("PNBRQKpnbrqk", start=1)}
    encoded = [0] * 64
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            encoded[square] = codes[piece.symbol()]
    # Leading dimension of 1 acts as the batch axis for the networks.
    return torch.tensor(encoded, dtype=torch.long).unsqueeze(0)
| |
|
class NN1(nn.Module):
    """Value network: per-square embedding, one self-attention layer, deep MLP.

    Input is a ``(batch, 64)`` long tensor of piece codes in ``[0, 12]``. The
    first Linear takes 4096 = 64 squares x 64 embedding dims. Output is
    ``(batch, 4)``; elsewhere in this file only column 0 is read.
    """

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(13, 64)
        self.attention = nn.MultiheadAttention(embed_dim=64, num_heads=16)
        self.neu = 512

        width = self.neu
        # Module order (and so state_dict keys "neurons.<i>.*") must match
        # saved checkpoints: input layer, 12 hidden layers, 2-layer head.
        layers = [nn.Linear(4096, width), nn.ReLU()]
        for _ in range(12):
            layers.extend([nn.Linear(width, width), nn.ReLU()])
        layers.extend([nn.Linear(width, 64), nn.ReLU(), nn.Linear(64, 4)])
        self.neurons = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``(batch, 64)`` piece codes to ``(batch, 4)`` outputs."""
        embedded = self.embedding(x)
        # MultiheadAttention defaults to (seq, batch, embed) layout.
        seq_first = embedded.permute(1, 0, 2)
        attended, _ = self.attention(seq_first, seq_first, seq_first)
        batch_first = attended.permute(1, 0, 2).contiguous()
        flat = batch_first.view(batch_first.size(0), -1)
        return self.neurons(flat)
| |
|
class Policy(nn.Module):
    """Policy network producing move logits for a single position.

    Each square is embedded (13 piece codes -> 32 dims). One self-attention
    layer runs over the 64 squares, and the flattened result goes through an
    MLP to 29275 logits. This script indexes the logits via ``uci_to_index``,
    so the output order presumably follows that vocabulary file — TODO
    confirm against the policy's training code.
    """

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(13, 32)
        self.attention = nn.MultiheadAttention(embed_dim=32, num_heads=16)
        self.neu = 256
        self.neurons = nn.Sequential(
            nn.Linear(64 * 32, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, 128),
            nn.ReLU(),
            nn.Linear(128, 29275),
        )

    def forward(self, x):
        """Return logits of shape ``(1, 29275)`` for one position.

        Args:
            x: a FEN string, or a ``chess.Board`` (used as-is, not modified).
        """
        board = x if isinstance(x, chess.Board) else chess.Board(x)
        color = board.turn
        # board_to_tensor returns a CPU tensor; place it beside the weights so
        # the module works after .to("mps") / .to("cuda").
        tokens = board_to_tensor(board).to(self.embedding.weight.device)
        tokens = self.embedding(tokens)
        tokens = tokens.permute(1, 0, 2)  # (seq, batch, embed) for attention
        attn_output, _ = self.attention(tokens, tokens, tokens)
        flat = attn_output.permute(1, 0, 2).contiguous()
        flat = flat.view(flat.size(0), -1)
        # NOTE(review): `color` is a bool (board.turn), so with Black to move
        # every logit becomes 0 and the softmax in `search` is uniform. Kept
        # for compatibility with the saved policy weights — confirm whether
        # this multiply should be removed.
        return self.neurons(flat) * color
| |
|
# Value network trained by this script, plus its optimizer.
model = NN1().to(device)
optimizer = optim.Adam(model.parameters(), lr=CONFIG["learning_rate"])

# Pre-trained policy network used as a move prior.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only load checkpoints you trust. If this file is a plain state_dict,
# weights_only=True would be safer — confirm.
policy = Policy().to(device)
polweight = torch.load(
    "NeoChess/chessy_policy.pth", map_location=device, weights_only=False
)
policy.load_state_dict(polweight)


def _load_value_weights():
    """Load ``model`` from the primary checkpoint, else the backup, else keep random init."""
    candidates = (("model", "model_path"), ("backup model", "backup_model_path"))
    for label, key in candidates:
        try:
            model.load_state_dict(torch.load(CONFIG[key], map_location=device))
        except FileNotFoundError:
            continue
        print(f"Loaded {label} from {CONFIG[key]}")
        return
    print("No model file found, starting from scratch.")


_load_value_weights()

model.train()
criterion = nn.MSELoss()
# Stockfish opponent; started at import, so spawned workers get their own.
engine = eng.SimpleEngine.popen_uci(CONFIG["stockfish_path"])
lim = eng.Limit(time=CONFIG["stockfish_time_limit"])
| |
|
def get_evaluation(board):
    """Score ``board`` for the side to move using the value network.

    The network's output ``[0][0]`` is treated as a White-perspective score,
    so it is negated when Black is to move. No gradients are tracked.
    """
    with torch.no_grad():
        white_view = model(board_to_tensor(board).to(device))[0][0].item()
    return white_view if board.turn == chess.WHITE else -white_view
| |
|
# Move vocabulary: each stripped line of the file maps to its 0-based line
# index, which is used to pick the matching policy logit for a move.
# The path can be overridden with CONFIG["moves_path"]; the default points
# into one specific torchrl install.
# NOTE(review): the file is named san_moves.txt but is looked up with UCI
# strings (str(move)) — confirm its format.
_MOVES_FILE = CONFIG.get(
    "moves_path",
    "/usr/local/python/3.12.1/lib/python3.12/site-packages/torchrl/envs/custom/san_moves.txt",
)
# Explicit encoding: the default text encoding is locale-dependent.
with open(_MOVES_FILE, "r", encoding="utf-8") as f:
    uci_to_index = {line.strip(): i for i, line in enumerate(f)}
| |
|
| |
|
def search(board, depth, policy_net=policy, simulations=100, temperature=1.0, device="cpu"):
    """Score each legal move by value-net evaluation weighted by the policy prior.

    Each legal move is played, evaluated once with ``get_evaluation``, and
    undone. The value is negated because ``get_evaluation`` reports the score
    for the side to move after the push, i.e. the opponent. It is then
    multiplied by the policy network's softmax probability for that move.

    Args:
        board: position to search; it is left unchanged on return.
        depth: currently unused — only the immediate moves are scored.
        policy_net: module called with a FEN string, returning logits of
            shape ``(1, N)`` indexed through ``uci_to_index``.
        simulations: ignored. The value network is deterministic, so repeated
            evaluations of one position all return the same value.
        temperature: softmax temperature applied to the policy logits.
        device: ignored; kept for call-site compatibility.

    Returns:
        ``(best_move, move_scores)`` where ``move_scores`` maps each legal
        move to its score. Returns ``(None, {})`` when there are no legal moves.

    Raises:
        KeyError: if a legal move's UCI string is missing from ``uci_to_index``.
    """
    legal = list(board.legal_moves)
    if not legal:
        return None, {}

    with torch.no_grad():
        logits = policy_net(board.fen()).squeeze(0)
        probs = torch.softmax(logits / temperature, dim=-1).cpu()

    move_scores = {}
    for move in legal:
        board.push(move)
        try:
            value = -get_evaluation(board)
        finally:
            board.pop()
        # NOTE(review): multiplying a signed value by a probability favours
        # *unlikely* moves when value < 0; consider value + log(prob) instead.
        move_scores[move] = value * probs[uci_to_index[move.uci()]].item()

    best_move = max(move_scores, key=move_scores.get)
    return best_move, move_scores
| |
|
| | |
| |
|
def _choose_bot_move(board):
    """Pick the bot's move for ``board``, or return None if no move was scored.

    ``search`` scores every legal move for the side to move. A softmax over
    those scores gives sampling weights, ``CONFIG["epsilon"]`` distinct
    candidates are drawn without replacement, and the best-scoring candidate
    is returned (exploration that still favours strong moves).
    """
    _, move_scores = search(board, depth=CONFIG["search_depth"])
    if not move_scores:
        return None
    moves = list(move_scores)
    scores = torch.tensor(
        [float(score) for score in move_scores.values()],
        dtype=torch.float32,
        device=device,
    )
    probs = torch.softmax(scores, dim=0)
    num_candidates = min(CONFIG["epsilon"], len(moves))
    candidates = torch.multinomial(probs, num_samples=num_candidates, replacement=False)
    best = candidates[torch.argmax(scores[candidates])]
    return moves[best.item()]


def game_gen(engine_side):
    """Play one game and record the positions where the bot moved.

    Args:
        engine_side: colour Stockfish plays (``chess.WHITE`` / ``chess.BLACK``),
            or ``None`` for self-play (``board.turn != None`` is always true,
            so the bot then moves for both sides).

    Returns:
        ``(data, c, mc)``:
        - ``data`` is a list of ``{'fen', 'move_number'}`` dicts, one per bot
          move, recorded before that move is played.
        - ``c`` is 10.0 for ``1-0``, -10.0 for ``0-1``, and 0 otherwise.
        - ``mc`` is the number of plies played.
    """
    data = []
    mc = 0
    board = chess.Board()
    while not board.is_game_over():
        if board.turn != engine_side:
            move = _choose_bot_move(board)
            if move is None:
                break
            data.append({
                'fen': board.fen(),
                'move_number': mc,
            })
        else:
            move = engine.play(board, lim).move
            if move is None:  # engine produced no move; stop the game here
                break
        board.push(move)
        mc += 1

    result = board.result()
    c = 0
    if result == '1-0':
        c = 10.0
    elif result == '0-1':
        c = -10.0
    return data, c, mc
def train(data, c, mc):
    """Fit the value network to one game's outcome, one position per step.

    Each recorded position gets the target ``c * move_number / mc``, a linear
    ramp from 0 at the first ply towards ``c`` at the final ply. The loss is
    ``criterion`` (MSE) on output ``[0][0]``, with one optimizer step per
    position. The weights are saved to ``CONFIG["model_path"]`` afterwards.
    """
    for record in data:
        position = board_to_tensor(chess.Board(record['fen'])).to(device)
        goal = torch.tensor(
            c * record['move_number'] / mc, dtype=torch.float32
        ).to(device)
        prediction = model(position)[0][0]
        step_loss = criterion(prediction, goal)

        optimizer.zero_grad()
        step_loss.backward()
        optimizer.step()

    print(f"Saving model to {CONFIG['model_path']}")
    torch.save(model.state_dict(), CONFIG["model_path"])
def main():
    """Run the play-then-train loop for ``CONFIG["num_epochs"]`` epochs.

    Each epoch:
    - snapshots the model to ``CONFIG["backup_model_path"]``;
    - plays ``CONFIG["num_games"] // 3`` games in each of three modes
      (self-play, Stockfish as White, Stockfish as Black) in one process pool;
    - trains on the games, interleaved as self/white/black triples.

    Games beyond a multiple of 3 are not played. The module-level Stockfish
    engine is shut down even if an epoch raises.
    """
    # Set once, before any pool exists. Spawned workers re-import this module,
    # so each starts its own engine and loads the latest saved weights.
    mp.set_start_method('spawn', force=True)
    try:
        for _ in range(CONFIG["num_epochs"]):
            per_mode = CONFIG['num_games'] // 3
            print(f"Saving backup model to {CONFIG['backup_model_path']}")
            torch.save(model.state_dict(), CONFIG["backup_model_path"])

            # Submit every game in one call so all three modes run
            # concurrently instead of each mode waiting for the previous one.
            sides = [None] * per_mode + [chess.WHITE] * per_mode + [chess.BLACK] * per_mode
            with mp.Pool(processes=mp.cpu_count()) as pool:
                games = pool.starmap(game_gen, [(side,) for side in sides])
            self_games = games[:per_mode]
            white_games = games[per_mode:2 * per_mode]
            black_games = games[2 * per_mode:]

            results = [
                game
                for trio in zip(self_games, white_games, black_games)
                for game in trio
            ]
            for data, c, mc in results:
                print(f"Saving backup model to {CONFIG['backup_model_path']}")
                torch.save(model.state_dict(), CONFIG["backup_model_path"])
                if data:
                    train(data, c, mc)
        print("Training complete.")
    finally:
        engine.quit()


if __name__ == "__main__":
    main()