Module ai.rl.alpha_zero

Implementation of the AlphaZero algorithm, based on Monte Carlo Tree Search for zero-sum games.

The AlphaZero algorithm consists of two core components: the LearnerWorker and the SelfPlayWorker. The SelfPlayWorkers generate rollouts of the policy that are then passed on to the LearnerWorker, where network updates are performed. They are configured via the LearnerConfig and SelfPlayConfig, respectively.

In addition to these, there are two logging servers used for visualization in TensorBoard: LearnerLogger and SelfPlayLogger.

For a basic implementation of the AlphaZero algorithm, see the train() method. This method uses pure Python multiprocessing. However, replacing the Python queues with another transport that implements the same queue interface should allow the algorithm to run across multiple machines as well. When running across machines, the network parameters must also be communicated; in the train() method, the network is simply placed in shared memory.

When evaluating a model, use the mcts() method.
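
As a rough sketch of how the pieces fit together (the simulator factory and network below are hypothetical placeholders; the config objects are assumed to be constructible with their defaults, so consult LearnerConfig and SelfPlayConfig for the actual fields), a training run might look like:

import torch
from ai.rl import alpha_zero as az

simulator_factory = MySimulatorFactory()  # hypothetical simulators.Factory implementation
network = MyNetwork()                     # hypothetical nn.Module with policy and value heads
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)

az.train(
    simulator_factory,
    self_play_workers=4,
    learner_config=az.LearnerConfig(),
    self_play_config=az.SelfPlayConfig(),
    network=network,
    optimizer=optimizer,
    save_path="checkpoints",
    save_period=600,   # checkpoint roughly every ten minutes
    train_time=3600,   # stop after one hour
)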

Expand source code
"""Implementation of the AlphaZero algorithm, based on Monte Carlo Tree Search for
zero sum games.

The AlphaZero algorithm consists of two core components: the `LearnerWorker` and the
`SelfPlayWorker`. The `SelfPlayWorker`s generate rollouts of the policy that are then
passed onto the `LearnerWorker` where network updates are performed. These are
configurable by the `LearnerConfig` and `SelfPlayConfig`.

In addition to these, there are two logging servers used for visualization to
tensorboard: `LearnerLogger` and `SelfPlayLogger`.

For a basic implementation of the AlphaZero algorithm, see the `train` method. This
method uses pure python multiprocessing. However, replacing the python queues by some
other transportation protocol (implementing the queue interface) should allow the
algorithm to run across multiple machines as well. If running across machines, network
parameters must be communicated as well. In the `train` method, the network is simply
put into shared memory.

When evaluating a model, use the `mcts` method.
"""

from ._core.learner_worker import LearnerWorker, LearnerConfig
from ._core.self_play_worker import SelfPlayWorker, SelfPlayConfig
from ._core.logger import Logger
from ._core.mcts import mcts, MCTSConfig
from ._core.mcts_node import MCTSNode
from ._train import train
from . import networks

__all__ = [
    "mcts",
    "LearnerWorker",
    "LearnerConfig",
    "SelfPlayWorker",
    "SelfPlayConfig",
    "Logger",
    "MCTSConfig",
    "MCTSNode",
    "train",
    "networks"
]

Sub-modules

ai.rl.alpha_zero.networks

Example implementations of networks for some simulators.

Functions

def mcts(state: numpy.ndarray, action_mask: numpy.ndarray, simulator: Base, network: torch.nn.modules.module.Module, config: MCTSConfig, root_node: MCTSNode = None) ‑> MCTSNode

Runs the Monte Carlo Tree Search algorithm.

Args

state : np.ndarray
Start state.
action_mask : np.ndarray
Start action mask.
simulator : Simulator
Simulator.
network : nn.Module
Network.
config : MCTSConfig
Configuration. The number of MCTS steps is taken from config.simulations (50 by default).
root_node : MCTSNode, optional
If not None, this node is used as root. Useful when the tree has previously been traversed, i.e. previously computed children are maintained instead of erasing the already computed tree. Defaults to None.

Returns

MCTSNode
Root node.
Expand source code
def mcts(
    state: np.ndarray,
    action_mask: np.ndarray,
    simulator: simulators.Base,
    network: nn.Module,
    config: MCTSConfig,
    root_node: MCTSNode = None,
) -> MCTSNode:
    """Runs the Monte Carlo Tree Search algorithm.

    Args:
        state (np.ndarray): Start state.
        action_mask (np.ndarray): Start action mask.
        simulator (Simulator): Simulator.
        network (nn.Module): Network.
        config (MCTSConfig): Configuration. The number of MCTS steps is taken from
            `config.simulations` (50 by default).
        root_node (MCTSNode, optional): If not None, this node is used as root. Useful
            when the tree has previously been traversed, i.e. previously computed
            children are maintained instead of erasing the already computed tree.
            Defaults to None.

    Returns:
        MCTSNode: Root node.
    """

    if not simulator.deterministic:
        raise ValueError(
            "Cannot run Monte Carlo Tree Search using a stochastic simulator."
        )

    root = (
        MCTSNode(state, action_mask, simulator, network, config=config)
        if root_node is None
        else root_node
    )
    if not np.array_equal(state, root.state):
        raise ValueError("Given state and state of the root node differ.")
    if not np.array_equal(action_mask, root.action_mask):
        raise ValueError("Given action mask and action mask of the root node differ.")
    root.rootify()

    for _ in range(config.simulations):
        node = root
        while not node.is_leaf:
            node = node.select()

        node.expand()
        node.backup()
    return root
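
As a sketch of how mcts() might be used at evaluation time (the simulator, network and config objects are assumed to already exist; the simulator calls mirror those used by the self play worker further down), the search tree can be reused between moves by passing the previously selected child back in as root_node:

import numpy as np

state = simulator.reset()
action_mask = simulator.action_space.as_discrete.action_mask(state)

root = None
terminal = False
while not terminal:
    # Run the search, reusing the subtree computed for the previous move, if any.
    root = mcts(state, action_mask, simulator, network, config, root_node=root)

    # Act greedily with respect to the visit-count policy.
    action = int(np.argmax(root.action_policy))

    state, reward, terminal, _ = simulator.step(state, action)
    action_mask = simulator.action_space.as_discrete.action_mask(state)

    # Descend into the chosen child so the next call continues from its subtree.
    root = root.children[action]
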
def train(simulator: Factory, self_play_workers: int, learner_config: LearnerConfig, self_play_config: SelfPlayConfig, network: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer, save_path: str = None, save_period: int = -1, train_time: int = -1)

Starts training an AlphaZero model.

Args

simulator : simulators.Factory
Simulator factory spawning simulators on which to train the model.
self_play_workers : int
Number of self play workers to spawn.
learner_config : LearnerConfig
Configuration for the learner worker.
self_play_config : SelfPlayConfig
Configuration for the self play worker.
network : nn.Module
Network.
optimizer : optim.Optimizer
Optimizer.
save_path : str, optional
Path to where to store training checkpoints. If None, no checkpoints are stored. Defaults to None.
save_period : int, optional
Time (in seconds) between checkpoints. If less than zero, no saves are made. Defaults to -1.
train_time : int, optional
Training time in seconds. If less than zero, training is run until the process is interrupted. Defaults to -1.
Expand source code
def train(
    simulator: simulators.Factory,
    self_play_workers: int,
    learner_config: LearnerConfig,
    self_play_config: SelfPlayConfig,
    network: nn.Module,
    optimizer: optim.Optimizer,
    save_path: str = None,
    save_period: int = -1,
    train_time: int = -1
):
    """Starts training an AlphaZero model.

    Args:
        simulator (simulators.Factory): Simulator factory spawning simulators
            on which to train the model.
        self_play_workers (int): Number of self play workers to spawn.
        learner_config (LearnerConfig): Configuration for the learner worker.
        self_play_config (SelfPlayConfig): Configuration for the self play worker.
        network (nn.Module): Network.
        optimizer (optim.Optimizer): Optimizer.
        save_path (str, optional): Path to where to store training checkpoints. If None,
            no checkpoints are stored. Defaults to None.
        save_period (int, optional): Time (in seconds) between checkpoints. If less than
            zero, no saves are made. Defaults to -1.
        train_time (int, optional): Training time in seconds. If less than zero,
            training is run until the process is interrupted. Defaults to -1.
    """
    network.share_memory()
    logger = Logger()
    log_port = logger.start()
    log_client = logging.Client("127.0.0.1", log_port)

    sample_queue = Queue(maxsize=2000)

    self_play_workers = [
        SelfPlayWorker(
            simulator,
            network,
            self_play_config,
            sample_queue,
            log_client=log_client,
        )
        for _ in range(self_play_workers)
    ]

    learner_worker = LearnerWorker(
        network,
        optimizer,
        learner_config,
        sample_queue,
        log_client=log_client,
        save_path=save_path,
        save_period=save_period
    )

    learner_worker.start()
    for worker in self_play_workers:
        worker.start()

    start = perf_counter()
    while train_time < 0 or perf_counter() - start < train_time:
        sleep(10)

    learner_worker.terminate()
    for worker in self_play_workers:
        worker.terminate()

    learner_worker.join()
    for worker in self_play_workers:
        worker.join()
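
The module docstring notes that the Python queues used here can be replaced by any other transport implementing the same queue interface. The workers only rely on get(timeout=...) raising queue.Empty and put(item, timeout=...) raising queue.Full, so a minimal adapter around some hypothetical network transport could look like the sketch below (the transport object and its send/recv methods are assumptions, not part of this library):

import pickle
import queue

class RemoteSampleQueue:
    """Queue-interface adapter over a hypothetical network transport."""

    def __init__(self, transport):
        self._transport = transport

    def put(self, item, timeout=None):
        # Serialize the (states, masks, policies, z) sample tuple and ship it off.
        if not self._transport.send(pickle.dumps(item), timeout=timeout):
            raise queue.Full()

    def get(self, timeout=None):
        data = self._transport.recv(timeout=timeout)
        if data is None:
            raise queue.Empty()
        return pickle.loads(data)

Keep in mind that, as noted in the module docstring, the network parameters must also be communicated between machines, since share_memory() only shares the network across processes on a single host.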

Classes

class LearnerConfig

Configuration of the Learner process.

class LearnerWorker (network: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer, config: LearnerConfig, sample_queue: Queue, log_client: Client = None, save_path: str = None, save_period: int = -1)

Process objects represent activity that is run in a separate process

The class is analogous to threading.Thread

Ancestors

  • multiprocessing.context.Process
  • multiprocessing.process.BaseProcess

Methods

def run(self)

Method to be run in sub-process; can be overridden in sub-class

Expand source code
def run(self):
    batch_states, batch_masks, batch_policies, batch_z = [], [], [], []
    L = 0  # number of samples accumulated towards the current batch
    self.last_save_time = perf_counter()
    while True:
        try:
            states, masks, policies, z = self.sample_queue.get(timeout=5)
            N = states.shape[0]
            # Chunk the incoming samples so that full batches of exactly
            # config.batch_size samples are assembled before each update.
            while N > 0:
                M = min(self.config.batch_size - L, N)
                batch_states.append(states[:M])
                states = states[M:]
                batch_masks.append(masks[:M])
                masks = masks[M:]
                batch_policies.append(policies[:M])
                policies = policies[M:]
                batch_z.append(z[:M])
                z = z[M:]
                N -= M
                L += M

                if L >= self.config.batch_size:
                    self.train_step(
                        torch.cat(batch_states),
                        torch.cat(batch_masks),
                        torch.cat(batch_policies),
                        torch.cat(batch_z),
                    )
                    batch_states, batch_masks, batch_policies, batch_z = (
                        [],
                        [],
                        [],
                        [],
                    )
                    L = 0
        except Empty:
            # No samples arrived within the timeout; keep polling.
            continue
def save(self, states: torch.Tensor, masks: torch.Tensor)
Expand source code
def save(self, states: Tensor, masks: Tensor):
    if self.save_path is None:
        return

    save_dir = os.path.join(self.save_path, str(int(time())))
    os.makedirs(save_dir, exist_ok=False)

    model = jit.trace(self.network, (states, masks))
    jit.save(model, os.path.join(save_dir, "network.pt"))
def train_step(self, states: torch.Tensor, masks: torch.Tensor, policies: torch.Tensor, z: torch.Tensor)
Expand source code
def train_step(self, states: Tensor, masks: Tensor, policies: Tensor, z: Tensor):
    p, v = self.network(states, masks)
    loggedp = torch.where(
        torch.isinf(p), torch.zeros_like(p), torch.log_softmax(p, dim=1)
    )

    loss = (z - v.view(-1)).square().mean() - (policies * loggedp).sum(dim=1).mean()
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    if self.log_client is not None:
        self.log_client.log("Training/Loss", loss.detach().item())

    if (
        self.save_period > 0
        and perf_counter() - self.last_save_time > self.save_period
    ):
        self.save(states, masks)
        self.last_save_time = perf_counter()
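
For reference, the loss computed above is the usual AlphaZero objective, minus the explicit weight-decay term from the paper (which would typically be delegated to the optimizer): a mean squared error between the value head v and the game outcome z, plus the cross entropy between the search policy pi and the policy head p. Logits that the network emits as infinite (e.g. for masked actions) are given zero log-probability by the torch.where above and therefore do not contribute. In the notation of the code, with B = config.batch_size and log p the log-softmax of the policy logits:

\mathcal{L} = \frac{1}{B} \sum_{b=1}^{B} (z_b - v_b)^2
            - \frac{1}{B} \sum_{b=1}^{B} \sum_{a} \pi_{b,a} \log p_{b,a}
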
class Logger

Logging server for the LearnerWorker.

class MCTSConfig

Configuration for the Monte Carlo Tree Search.

Subclasses

  • SelfPlayConfig

class MCTSNode (state: np.ndarray, action_mask: np.ndarray, simulator: simulators.Base, network: nn.Module, parent: MCTSNode = None, config: mcts.MCTSConfig = None, action: int = None, reward: float = None, terminal: bool = None)

Node in the MCTS algorithm.

Args

state : np.ndarray
State of this node.
action_mask : np.ndarray
Action mask of the state.
simulator : Simulator
Simulator used in the rollout.
network : nn.Module
Network used in the rollout.
parent : MCTSNode, optional
Parent node. Defaults to None.
config : MCTSConfig, optional
Configuration of the MCTS. Defaults to None.
action : int, optional
Action that led to this node. Defaults to None.
reward : float, optional
Reward obtained upon transitioning to this node. Defaults to None.
terminal : bool, optional
Whether or not this node is in a terminal state. Defaults to None.

Instance variables

var action : Optional[int]

The action that led to this node. If this node is the root, then this value is None.

Expand source code
@property
def action(self) -> Optional[int]:
    """The action that led to this node. If this node is the root, then this value
    is `None`."""
    return self._action
var action_mask : numpy.ndarray

The action mask of this node.

Expand source code
@property
def action_mask(self) -> np.ndarray:
    """The action mask of this node."""
    return self._action_mask
var action_policy : numpy.ndarray

Action policy calculated in this node.

Expand source code
@property
def action_policy(self) -> np.ndarray:
    """Action policy calculated in this node."""
    distribution = np.power(self._N, 1 / self._config.T)
    return distribution / np.sum(distribution)
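
In other words, the returned policy is the visit-count distribution sharpened by the temperature T taken from the MCTSConfig:

\pi(a) = \frac{N(a)^{1/T}}{\sum_{b} N(b)^{1/T}}

With T = 1 the policy is proportional to the visit counts, while T approaching 0 tends towards a greedy choice of the most visited action.
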
var children : Optional[List[Optional[MCTSNode]]]

The children of this node. If this node has not been expanded, then None is returned; otherwise the list of (possible) children is returned. If not None, the list consists of MCTSNodes at indices representing legal actions. Illegal action indices are None. In other words, the child at index i corresponds to the node reached by action i.

Expand source code
@property
def children(self) -> Optional[List[Optional[MCTSNode]]]:
    """The children of this node. If this node has not been expanded, then
    `None` is returned, otherwise the list of (possible children) is returned. If
    not None, the list consists of `MCTSNode`s on indices representing legal
    actions. Illegal action indices are `None`. In other words, the child at index
    `i` corresponds to the node retrieved by action `i`.
    """
    return self._children
var is_leaf : bool

True if this node is a leaf node.

Expand source code
@property
def is_leaf(self) -> bool:
    """True if this node is a leaf node."""
    return self._children is None
var is_root : bool

True if this node is the root.

Expand source code
@property
def is_root(self) -> bool:
    """True if this node is the root."""
    return self._parent is None
var is_terminal : bool

True if this node is a terminal state.

Expand source code
@property
def is_terminal(self) -> bool:
    """True if this node is a terminal state."""
    return self._terminal
var parent : Optional[MCTSNode]

The parent of this node.

Expand source code
@property
def parent(self) -> Optional[MCTSNode]:
    """The parent of this node."""
    return self._parent
var reward : Optional[float]

The reward obtained on the transition into this node. If this node is the root, then this value is None.

Expand source code
@property
def reward(self) -> Optional[float]:
    """The reward obtained on the transition into this node. If this node is the
    root, then this value is `None`."""
    return self._reward
var state : numpy.ndarray

The state of this node.

Expand source code
@property
def state(self) -> np.ndarray:
    """The state of this node."""
    return self._state

Methods

def add_noise(self)

Adds Dirichlet noise to the prior probabilities.

Expand source code
def add_noise(self):
    """Adds dirchlet noise to the prior probability."""
    d = np.random.dirichlet(
        self._config.alpha * np.ones(self._action_mask.shape[0])[self._action_mask]
    )
    self._P[self._action_mask] = (1 - self._config.epsilon) * self._P[
        self._action_mask
    ] + self._config.epsilon * d
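
This is the standard AlphaZero root-noise scheme: for each legal action the prior is blended with a sample from a Dirichlet distribution, with alpha and epsilon taken from the MCTSConfig:

P(a) \leftarrow (1 - \varepsilon)\, P(a) + \varepsilon\, \eta_a,
\qquad \eta \sim \mathrm{Dir}(\alpha)
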
def backup(self)

Runs the backpropagation from this node up to the root.

Expand source code
def backup(self):
    """Runs the backpropagation from this node up to the root."""
    if self.is_root:
        return

    if self.is_terminal:
        self.parent._backpropagate(self._action, self._reward)
    elif self._config.zero_sum_game:
        self.parent._backpropagate(self._action, -self._V)
    else:
        self.parent._backpropagate(
            self.action, self._reward + self._config.discount_factor * self._V
        )
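
The value handed to the parent thus depends on the situation: the raw reward for a terminal node, the negated value estimate in a zero-sum game (reflecting the change of perspective between the two players), and a discounted one-step return otherwise:

v_{\text{parent}} =
\begin{cases}
r & \text{terminal node} \\
-V & \text{zero-sum game} \\
r + \gamma V & \text{otherwise, with } \gamma = \text{config.discount\_factor}
\end{cases}
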
def expand(self)

Expands the node. If the node has already been expanded, then this is a no-op.

Expand source code
def expand(self):
    """Expands the node. If the node has already been expanded, then this is a
    no-op.
    """

    if self._expanded:
        return

    self._init_pv()

    if not self.is_terminal:
        actions = np.arange(self._action_mask.shape[0])[self._action_mask]
        states = np.expand_dims(self._state, 0)
        states = np.repeat(states, actions.shape[0], axis=0)
        next_states, rewards, terminals, _ = self._simulator.step_bulk(
            states, actions
        )
        next_masks = (
            self._simulator.action_space
            .as_discrete
            .action_mask_bulk(next_states)
        )

        self._children = [None] * self._action_mask.shape[0]
        for next_state, next_mask, reward, terminal, action in zip(
            next_states, next_masks, rewards, terminals, actions
        ):
            self._children[action] = MCTSNode(
                next_state,
                next_mask,
                self._simulator,
                self._network,
                parent=self,
                config=self._config,
                action=action,
                reward=reward,
                terminal=terminal,
            )

        self._N = np.zeros(self._action_mask.shape[0])
        self._W = np.zeros(self._action_mask.shape[0])

    self._expanded = True
def rootify(self)

Converts this node to a root node, cutting ties with all parents, while maintaining its children.

Expand source code
def rootify(self):
    """Converts this node to a root node, cutting ties with all parents, while
    maintaining its children."""

    if self.is_terminal:
        raise ValueError("Cannot rootify a terminal state.")

    self._parent = None
    self._action = None
    self._reward = None
    self._terminal = False

    if self._P is not None:
        self.add_noise()
def select(self) ‑> MCTSNode

Traverses one step in the tree from this node according to the selection policy.

Raises

ValueError
If this node is in a terminal state.

Returns

MCTSNode
The node selected.
Expand source code
def select(self) -> MCTSNode:
    """Traverses one step in the tree from this node according to the selection
    policy.

    Raises:
        ValueError: If this node is in a terminal state.

    Returns:
        MCTSNode: The node selected.
    """
    if self.is_terminal:
        raise ValueError("Cannot select action from terminal state.")

    Q = np.zeros_like(self._N)
    mask = self._N > 0
    Q[mask] = self._W[mask] / self._N[mask]
    if np.any(self._N > 0):
        U = self._P * np.sqrt(np.sum(self._N)) / (1 + self._N)
    else:
        U = self._P
    QU = Q + self._config.c * U
    QU[~self._action_mask] = -np.inf
    return self._children[np.argmax(QU)]
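
This is the PUCT selection rule used by AlphaZero: each legal action is scored by its current mean value plus an exploration bonus that grows with the prior and the total visit count and shrinks with the action's own visit count. Illegal actions are excluded by setting their score to negative infinity, and, as in the else branch above, the bonus reduces to the prior alone while no child has been visited yet:

a^{*} = \arg\max_{a} \left( Q(a) + c \, P(a) \, \frac{\sqrt{\sum_{b} N(b)}}{1 + N(a)} \right),
\qquad Q(a) = \frac{W(a)}{N(a)} \;\; (Q(a) = 0 \text{ while } N(a) = 0)
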
class SelfPlayConfig

Configuration for the self play worker.

Ancestors

  • MCTSConfig

class SelfPlayWorker (simulator: Factory, network: torch.nn.modules.module.Module, config: SelfPlayConfig, sample_queue: Queue, log_client: Client = None)

Process objects represent activity that is run in a separate process

The class is analogous to threading.Thread

Ancestors

  • multiprocessing.context.Process
  • multiprocessing.process.BaseProcess

Methods

def run(self) ‑> None

Method to be run in sub-process; can be overridden in sub-class

Expand source code
def run(self) -> None:
    while True:
        try:
            self.sample_queue.put(self.run_episode(), timeout=5)
        except Full:
            continue
def run_episode(self)
Expand source code
def run_episode(self):
    states, action_masks, action_policies = [], [], []
    terminal = False
    reward = 0
    state = self.simulator.reset()
    action_mask = self.simulator.action_space.as_discrete.action_mask(state)

    with torch.no_grad():
        start_prior, start_value = self.network(
            torch.as_tensor(state, dtype=torch.float).unsqueeze_(0),
            torch.as_tensor(action_mask).unsqueeze_(0),
        )
    start_value = start_value.item()
    start_prior = start_prior.softmax(dim=1).squeeze_(0)
    first_action_policy = None
    first_action = None

    root = None
    while not terminal:
        root = mcts(
            state,
            action_mask,
            self.simulator,
            self.network,
            self.config,
            root_node=root,
        )

        action_policy = root.action_policy
        if first_action_policy is None:
            first_action_policy = action_policy

        states.append(state)
        action_masks.append(action_mask)
        action_policies.append(torch.as_tensor(action_policy, dtype=torch.float))

        action = np.random.choice(
            self.simulator.action_space.as_discrete.size, p=action_policy
        )
        if first_action is None:
            first_action = action

        state, reward, terminal, _ = self.simulator.step(state, action)
        action_mask = self.simulator.action_space.as_discrete.action_mask(state)
        root = root.children[action]

    if self.log_client is not None:
        kl_div = -(
            first_action_policy * torch.log_softmax(start_prior, dim=0).numpy()
        ).sum()
        self.log_client.log("Episode/Reward", reward)
        self.log_client.log("Episode/Start value", start_value)
        self.log_client.log("Episode/Start KL Div", kl_div)
        self.log_client.log("Episode/First action", first_action)

    states = torch.as_tensor(np.stack(states), dtype=torch.float)
    action_masks = torch.as_tensor(np.stack(action_masks))
    action_policies = torch.as_tensor(np.stack(action_policies), dtype=torch.float)
    z = torch.ones(states.shape[0])
    i = torch.arange(1, states.shape[0] + 1, 2)
    j = torch.arange(2, states.shape[0] + 1, 2)
    z[-i] *= reward
    z[-j] *= -reward

    return states, action_masks, action_policies, z
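
The value targets z are built from the final reward by alternating its sign backwards through the trajectory, so that the outcome is credited to the states from which the eventually rewarded player acted and negated for the opponent's states. A small runnable illustration of the indexing above:

import torch

# Five recorded states; the final move earned reward +1 for the player acting last.
reward = 1.0
z = torch.ones(5)
i = torch.arange(1, 6, 2)   # 1st, 3rd and 5th positions counted from the end
j = torch.arange(2, 6, 2)   # 2nd and 4th positions counted from the end
z[-i] *= reward
z[-j] *= -reward
print(z)  # tensor([ 1., -1.,  1., -1.,  1.])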