PROBLEM I'm writing a Monte-Carlo tree search algorithm to play chess in Python. I replaced the simulation stage with a custom evaluation function. My code looks correct to me, but it behaves strangely: it recognizes instant wins easily enough but cannot recognize checkmate-in-two and checkmate-in-three positions. Any ideas?
WHAT I'VE TRIED I tried giving it more time to search but it still cannot find the best move even when it leads to a guaranteed win in two moves. However, I noticed that results improve when I turn off the custom evaluation and use classic Monte Carlo Tree Search simulation. (To turn off custom evaluation, just don't pass any arguments into the Agent constructor.) But I really need it to work with custom evaluation because I am working on a machine learning technique for board evaluation.
I tried printing out the results of the searches to see which moves the algorithm thinks are good. It consistently ranks the best move in mate-in-2 and mate-in-3 situations among the worst. The rankings are based on the number of times the move was explored (which is how MCTS picks the best moves).
MY CODE I've included the whole code because everything is relevant to the problem. To run this code, you may need to install python-chess (pip install python-chess).
I've struggled with this for more than a week and it's getting frustrating. Any ideas?
import math
import random
import time
import chess
import chess.engine
class Node:
    """A node in a Monte-Carlo search tree."""

    def __init__(self, state, parent, action):
        """Initializes a node structure for a Monte-Carlo search tree.

        state  -- board position held by this node (e.g. a chess.Board)
        parent -- parent Node, or None for the root
        action -- move that led from *parent* to this node (None for the root)
        """
        self.state = state
        self.parent = parent
        self.action = action
        # Moves not yet expanded into children; shuffled once so that
        # expansion order is random.
        self.unexplored_actions = list(self.state.legal_moves)
        random.shuffle(self.unexplored_actions)
        # Side to move in this position.
        self.colour = self.state.turn
        self.children = []
        self.w = 0  # accumulated win score (sum of result expectations)
        self.n = 0  # number of simulations that passed through this node

    def __repr__(self):
        """Concise representation so printed search results are readable."""
        return f'Node(action={self.action}, n={self.n}, w={self.w})'
class Agent:
    """Monte-Carlo tree search agent.

    By default leaf nodes are evaluated with a classic random-playout
    simulation; a custom static-evaluation callable may be supplied instead.
    """

    def __init__(self, custom_evaluation=None):
        """Initializes a Monte-Carlo tree search agent.

        custom_evaluation -- optional callable taking a Node and returning a
        chess.engine.PovWdl; when omitted, random simulation is used.
        """
        # Explicit None comparison so a falsy (but valid) callable object is
        # not silently discarded.  The instance attribute shadows the
        # _evaluate method defined below.
        if custom_evaluation is not None:
            self._evaluate = custom_evaluation

    def mcts(self, state, time_limit=float('inf'), node_limit=float('inf')):
        """Runs Monte-Carlo tree search from *state*.

        Returns the most-visited child Node of the root; its .action
        attribute is the recommended move.

        NOTE: supply at least one finite limit — with both limits left at
        infinity the loop never terminates.
        """
        nodes_searched = 0
        start_time = time.time()
        # Initialize the root node.
        root = Node(state, None, None)
        while (time.time() - start_time) < time_limit and nodes_searched < node_limit:
            # Selection: descend to a node that is expandable or terminal.
            leaf = self._select(root)
            # Expansion: add one new child unless the leaf is terminal.
            if leaf.unexplored_actions:
                child = self._expand(leaf)
            else:
                child = leaf
            # Evaluation: random simulation or custom static evaluation.
            result = self._evaluate(child)
            # Backpropagation: update statistics along the path to the root.
            self._backpropagate(child, result)
            nodes_searched += 1
        if not root.children:
            # Terminal root or a zero-iteration budget: raise a clear error
            # instead of max()'s cryptic "empty sequence" complaint.
            raise ValueError('mcts() found no moves to explore from this state')
        return max(root.children, key=lambda node: node.n)

    def _uct(self, node):
        """Returns the Upper Confidence Bound 1 value of *node*.

        node.w accumulates results from the perspective of the player to
        move *at* the node, while UCB is computed for the parent choosing
        among its children, so the score is inverted (node.n - node.w).
        """
        c = math.sqrt(2)  # standard UCT exploration constant
        # Wins from the choosing parent's perspective.
        w = node.n - node.w
        n = node.n
        N = node.parent.n
        try:
            ucb = (w / n) + (c * math.sqrt(math.log(N) / n))
        except ZeroDivisionError:
            # Unvisited nodes are infinitely attractive.
            ucb = float('inf')
        return ucb

    def _select(self, node):
        """Returns a leaf node that either has unexplored actions or is terminal."""
        while (not node.unexplored_actions) and node.children:
            # Descend to the child with the highest UCB.
            node = max(node.children, key=self._uct)
        return node

    def _expand(self, node):
        """Adds one child node to the tree and returns it."""
        # Pick an unexplored action (the list was shuffled at node creation).
        action = node.unexplored_actions.pop()
        # Play it on a copy so the node's own state stays untouched.
        state_copy = node.state.copy()
        state_copy.push(action)
        child = Node(state_copy, node, action)
        node.children.append(child)
        return child

    def _evaluate(self, node):
        """Returns an evaluation of *node*.

        Default implementation: classic random-playout simulation.  The
        constructor may replace this attribute with a custom evaluation.
        """
        return self._simulate(node)

    def _simulate(self, node):
        """Randomly plays out to the end of the game and returns a static
        evaluation of the terminal state."""
        board = node.state.copy()
        while not board.is_game_over():
            # Pick and play a uniformly random move.
            board.push(random.choice(list(board.legal_moves)))
        return self._calculate_static_evaluation(board)

    def _backpropagate(self, node, result):
        """Updates the statistics of *node* and all of its ancestors.

        Each node's win total is accumulated from the perspective of the
        player to move at that node (see _uct for the matching inversion).
        """
        # Update the evaluated node itself first.
        node.w += result.pov(node.colour).expectation()
        node.n += 1
        # Then walk up to the root.
        while node.parent is not None:
            node.parent.w += result.pov(node.parent.colour).expectation()
            node.parent.n += 1
            node = node.parent

    def _calculate_static_evaluation(self, board):
        """Returns a static evaluation (PovWdl, White's point of view) of a
        *terminal* board state."""
        result = board.result(claim_draw=True)
        if result == '1-0':
            wdl = chess.engine.Wdl(wins=1000, draws=0, losses=0)
        elif result == '0-1':
            wdl = chess.engine.Wdl(wins=0, draws=0, losses=1000)
        else:
            wdl = chess.engine.Wdl(wins=0, draws=1000, losses=0)
        return chess.engine.PovWdl(wdl, chess.WHITE)
def custom_evaluation(node):
    """Returns a static evaluation (chess.engine.PovWdl, from White's point
    of view) of the board state held by *node*."""
    board = node.state
    # Terminal positions get an exact win/draw/loss result.
    if board.is_game_over(claim_draw=True):
        result = board.result(claim_draw=True)
        if result == '1-0':
            wdl = chess.engine.Wdl(wins=1000, draws=0, losses=0)
        elif result == '0-1':
            wdl = chess.engine.Wdl(wins=0, draws=0, losses=1000)
        else:
            wdl = chess.engine.Wdl(wins=0, draws=1000, losses=0)
        return chess.engine.PovWdl(wdl, chess.WHITE)
    # Material balance in centipawns, White-positive.  Table-driven instead
    # of ten copy-pasted lines; kings are omitted (both are always present).
    piece_values = {
        chess.PAWN: 100,
        chess.KNIGHT: 300,
        chess.BISHOP: 300,
        chess.ROOK: 500,
        chess.QUEEN: 900,
    }
    material_balance = 0
    for piece_type, value in piece_values.items():
        material_balance += value * len(board.pieces(piece_type, chess.WHITE))
        material_balance -= value * len(board.pieces(piece_type, chess.BLACK))
    # TODO: Evaluate mobility.
    mobility = 0
    # Aggregate values.
    centipawn_evaluation = material_balance + mobility
    # Convert evaluation from centipawns to a win/draw/loss model.
    # NOTE(review): with a large material edge (e.g. two spare rooks in the
    # mate-test positions) this conversion saturates near a certain win, so
    # every non-mating line scores almost as well as a forced mate and the
    # UCT exploration term dominates the tiny difference — worth confirming
    # as the cause of the missed mate-in-2/3 lines.
    wdl = chess.engine.Cp(centipawn_evaluation).wdl(model='lichess')
    return chess.engine.PovWdl(wdl, chess.WHITE)
# Forced-mate test positions for White, of increasing depth.
m1 = chess.Board('8/8/7k/8/8/8/5R2/6R1 w - - 0 1')
# WHITE can win in one move. Best move is f2-h2.
m2 = chess.Board('8/6k1/8/8/8/8/1K2R3/5R2 w - - 0 1')
# WHITE can win in two moves. Best move is e2-g2.
m3 = chess.Board('8/8/5k2/8/8/8/3R4/4R3 w - - 0 1')
# WHITE can win in three moves. Best move is d2-f2.

agent = Agent(custom_evaluation)
best = agent.mcts(m2, time_limit=30)
# mcts() returns a Node; report the chosen move and its visit statistics
# rather than the bare object (which has no informative default repr).
print(best.action, best.n, best.w)