Module ai.rl.alpha_zero
Implementation of the AlphaZero algorithm, based on Monte Carlo Tree Search for zero-sum games.
The AlphaZero algorithm consists of two core components: the LearnerWorker and the SelfPlayWorker. The SelfPlayWorkers generate rollouts of the current policy, which are passed on to the LearnerWorker, where the network updates are performed. These are configured through LearnerConfig and SelfPlayConfig.
In addition to these, there are two logging servers used for visualization in TensorBoard: LearnerLogger and SelfPlayLogger.
For a basic implementation of the AlphaZero algorithm, see the train() function. It uses pure Python multiprocessing; however, replacing the Python queues with any other transport that implements the same queue interface should allow the algorithm to run across multiple machines as well. When running across machines, the network parameters must also be communicated; in train(), the network is simply placed in shared memory.
When evaluating a model, use the mcts() function.
Expand source code
"""Implementation of the AlphaZero algorithm, based on Monte Carlo Tree Search for
zero sum games.
The AlphaZero algorithm consists of two core components: the `LearnerWorker` and the
`SelfPlayWorker`. The `SelfPlayWorker`s generate rollouts of the policy that are then
passed onto the `LearnerWorker` where network updates are performed. These are
configurable by the `LearnerConfig` and `SelfPlayConfig`.
In addition to these, there are two logging servers used for visualization to
tensorboard: `LearnerLogger` and `SelfPlayLogger`.
For a basic implementation of the AlphaZero algorithm, see the `train` method. This
method uses pure python multiprocessing. However, replacing the python queues by some
other transportation protocol (implementing the queue interface) should allow the
algorithm to run across multiple machines as well. If running across machines, network
parameters must be communicated as well. In the `train` method, the network is simply
put into shared memory.
When evaluating a model, use the `mcts` method.
"""
from ._core.learner_worker import LearnerWorker, LearnerConfig
from ._core.self_play_worker import SelfPlayWorker, SelfPlayConfig
from ._core.logger import Logger
from ._core.mcts import mcts, MCTSConfig
from ._core.mcts_node import MCTSNode
from ._train import train
from . import networks
__all__ = [
"mcts",
"LearnerWorker",
"LearnerConfig",
"SelfPlayWorker",
"SelfPlayConfig",
"Logger",
"MCTSConfig",
"MCTSNode",
"train",
"networks"
]
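The docstring's claim about swapping the transport can be made concrete with the standard library alone. The sketch below is a hypothetical illustration, not part of this package: train() constructs a local Queue internally, so a multi-machine setup would instead construct the workers around a remotely served queue such as this one, and the network parameters would have to be distributed separately (here they are only placed in shared memory).

# Hypothetical sketch: serve the sample queue over TCP so SelfPlayWorker
# processes on other machines can push rollouts to the learner's machine.
# The address and authkey below are placeholder values.
from multiprocessing.managers import BaseManager
from queue import Queue

sample_queue = Queue(maxsize=2000)

class QueueManager(BaseManager):
    pass

# Remote clients register the same typeid, connect(), and call
# get_sample_queue(); the returned proxy's put()/get() behave like the local
# queue's, including raising queue.Full / queue.Empty on timeout, which is
# what the workers expect.
QueueManager.register("get_sample_queue", callable=lambda: sample_queue)

if __name__ == "__main__":
    manager = QueueManager(address=("0.0.0.0", 50000), authkey=b"alphazero")
    server = manager.get_server()
    server.serve_forever()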
Sub-modules
ai.rl.alpha_zero.networks
    Example implementations of networks for some simulators.
Functions
def mcts(state: numpy.ndarray, action_mask: numpy.ndarray, simulator: Base, network: torch.nn.modules.module.Module, config: MCTSConfig, root_node: MCTSNode = None) -> MCTSNode
    Runs the Monte Carlo Tree Search algorithm.
    Args
        state (np.ndarray): Start state.
        action_mask (np.ndarray): Start action mask.
        simulator (Simulator): Simulator.
        network (nn.Module): Network.
        config (MCTSConfig): Configuration, including the number of simulations to run.
        root_node (MCTSNode, optional): If not None, this node is used as the root. Useful when the tree has previously been traversed, i.e. previously computed children are maintained instead of the already computed tree being erased. Defaults to None.
    Returns
        MCTSNode: Root node.
Expand source code
def mcts(
    state: np.ndarray,
    action_mask: np.ndarray,
    simulator: simulators.Base,
    network: nn.Module,
    config: MCTSConfig,
    root_node: MCTSNode = None,
) -> MCTSNode:
    """Runs the Monte Carlo Tree Search algorithm.

    Args:
        state (np.ndarray): Start state.
        action_mask (np.ndarray): Start action mask.
        simulator (Simulator): Simulator.
        network (nn.Module): Network.
        config (MCTSConfig): Configuration, including the number of simulations
            to run.
        root_node (MCTSNode, optional): If not None, this node is used as root.
            Useful when the tree has previously been traversed, i.e. previously
            computed children are maintained instead of erasing the already
            computed tree. Defaults to None.

    Returns:
        MCTSNode: Root node.
    """
    if not simulator.deterministic:
        raise ValueError(
            "Cannot run Monte Carlo Tree Search using a stochastic simulator."
        )

    root = (
        MCTSNode(state, action_mask, simulator, network, config=config)
        if root_node is None
        else root_node
    )
    if not np.array_equal(state, root.state):
        raise ValueError("Given state and state of the root node differ.")
    if not np.array_equal(action_mask, root.action_mask):
        raise ValueError("Given action mask and action mask of the root node differ.")

    root.rootify()

    # One simulation = select a leaf, expand it, back the value up to the root.
    for _ in range(config.simulations):
        node = root
        while not node.is_leaf:
            node = node.select()
        node.expand()
        node.backup()

    return root
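A minimal evaluation loop built on mcts() might look like the sketch below. The simulator and network objects are hypothetical placeholders (any deterministic simulators.Base implementation together with a matching policy/value network), and the MCTSConfig fields are assumed to be settable as plain attributes; passing the previously chosen child back in as root_node reuses the already expanded subtree between moves.

import numpy as np

# Placeholders: a concrete simulators.Base instance and a matching nn.Module.
simulator = ...
network = ...

config = MCTSConfig()        # assumes a default-constructible config
config.simulations = 100     # number of search iterations per move

state = simulator.reset()
action_mask = simulator.action_space.as_discrete.action_mask(state)
root, terminal = None, False

while not terminal:
    root = mcts(state, action_mask, simulator, network, config, root_node=root)
    action = int(np.argmax(root.action_policy))   # greedy play at evaluation time
    state, reward, terminal, _ = simulator.step(state, action)
    action_mask = simulator.action_space.as_discrete.action_mask(state)
    root = root.children[action]                  # keep the computed subtree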
def train(simulator: Factory, self_play_workers: int, learner_config: LearnerConfig, self_play_config: SelfPlayConfig, network: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer, save_path: str = None, save_period: int = -1, train_time: int = -1)
    Starts training an AlphaZero model.
    Args
        simulator (simulators.Factory): Simulator factory spawning simulators on which to train the model.
        self_play_workers (int): Number of self play workers to spawn.
        learner_config (LearnerConfig): Configuration for the learner worker.
        self_play_config (SelfPlayConfig): Configuration for the self play worker.
        network (nn.Module): Network.
        optimizer (optim.Optimizer): Optimizer.
        save_path (str, optional): Path to where training checkpoints are stored. If None, no checkpoints are stored. Defaults to None.
        save_period (int, optional): Time (in seconds) between checkpoints. If less than zero, no saves are made. Defaults to -1.
        train_time (int, optional): Training time in seconds. If less than zero, training runs until the process is interrupted. Defaults to -1.
Expand source code
def train(
    simulator: simulators.Factory,
    self_play_workers: int,
    learner_config: LearnerConfig,
    self_play_config: SelfPlayConfig,
    network: nn.Module,
    optimizer: optim.Optimizer,
    save_path: str = None,
    save_period: int = -1,
    train_time: int = -1
):
    """Starts training an AlphaZero model.

    Args:
        simulator (simulators.Factory): Simulator factory spawning simulators
            on which to train the model.
        self_play_workers (int): Number of self play workers to spawn.
        learner_config (LearnerConfig): Configuration for the learner worker.
        self_play_config (SelfPlayConfig): Configuration for the self play
            worker.
        network (nn.Module): Network.
        optimizer (optim.Optimizer): Optimizer.
        save_path (str, optional): Path to where to store training checkpoints.
            If None, no checkpoints are stored. Defaults to None.
        save_period (int, optional): Time (in seconds) between checkpoints. If
            less than zero, no saves are made. Defaults to -1.
        train_time (int, optional): Training time in seconds. If less than
            zero, training is run until the process is interrupted.
            Defaults to -1.
    """
    # Network parameters are shared with the workers through shared memory.
    network.share_memory()

    logger = Logger()
    log_port = logger.start()
    log_client = logging.Client("127.0.0.1", log_port)

    sample_queue = Queue(maxsize=2000)
    self_play_workers = [
        SelfPlayWorker(
            simulator,
            network,
            self_play_config,
            sample_queue,
            log_client=log_client,
        )
        for _ in range(self_play_workers)
    ]
    learner_worker = LearnerWorker(
        network,
        optimizer,
        learner_config,
        sample_queue,
        log_client=log_client,
        save_path=save_path,
        save_period=save_period
    )

    learner_worker.start()
    for worker in self_play_workers:
        worker.start()

    start = perf_counter()
    while train_time < 0 or perf_counter() - start < train_time:
        sleep(10)

    learner_worker.terminate()
    for worker in self_play_workers:
        worker.terminate()

    learner_worker.join()
    for worker in self_play_workers:
        worker.join()
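For completeness, a minimal training invocation might look like the sketch below. The simulator factory and the network are hypothetical placeholders (their constructors are not part of this module), and LearnerConfig/SelfPlayConfig are assumed to be default-constructible; adjust the fields to your game.

import torch.optim as optim
from ai.rl.alpha_zero import train, LearnerConfig, SelfPlayConfig

# Placeholders: a concrete simulators.Factory and a matching policy/value net.
simulator_factory = ...
network = ...

optimizer = optim.Adam(network.parameters(), lr=1e-3)

train(
    simulator_factory,
    self_play_workers=4,
    learner_config=LearnerConfig(),
    self_play_config=SelfPlayConfig(),
    network=network,
    optimizer=optimizer,
    save_path="checkpoints",  # traced networks are written here periodically
    save_period=600,          # seconds between checkpoints
    train_time=3600,          # stop after an hour; negative runs until interrupted
)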
Classes
class LearnerConfig
    Configuration of the Learner process.
class LearnerWorker (network: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer, config: LearnerConfig, sample_queue: Queue, log_client: Client = None, save_path: str = None, save_period: int = -1)
    Process objects represent activity that is run in a separate process. The class is analogous to threading.Thread.
Ancestors
- multiprocessing.context.Process
- multiprocessing.process.BaseProcess
Methods
def run(self)
    Method to be run in sub-process; can be overridden in sub-class.
Expand source code
def run(self):
    batch_states, batch_masks, batch_policies, batch_z = [], [], [], []
    L = 0  # number of samples accumulated towards the next batch
    self.last_save_time = perf_counter()
    while True:
        try:
            states, masks, policies, z = self.sample_queue.get(timeout=5)
            N = states.shape[0]
            # Split the incoming episode into chunks so that every training
            # batch contains exactly config.batch_size samples.
            while N > 0:
                M = min(self.config.batch_size - L, N)
                batch_states.append(states[:M])
                states = states[M:]
                batch_masks.append(masks[:M])
                masks = masks[M:]
                batch_policies.append(policies[:M])
                policies = policies[M:]
                batch_z.append(z[:M])
                z = z[M:]
                N -= M
                L += M

                if L >= self.config.batch_size:
                    self.train_step(
                        torch.cat(batch_states),
                        torch.cat(batch_masks),
                        torch.cat(batch_policies),
                        torch.cat(batch_z),
                    )
                    batch_states, batch_masks, batch_policies, batch_z = (
                        [],
                        [],
                        [],
                        [],
                    )
                    L = 0
        except Empty:
            continue
def save(self, states: torch.Tensor, masks: torch.Tensor)
Expand source code
def save(self, states: Tensor, masks: Tensor):
    if self.save_path is None:
        return
    save_dir = os.path.join(self.save_path, str(int(time())))
    os.makedirs(save_dir, exist_ok=False)
    model = jit.trace(self.network, (states, masks))
    jit.save(model, os.path.join(save_dir, "network.pt"))
def train_step(self, states: torch.Tensor, masks: torch.Tensor, policies: torch.Tensor, z: torch.Tensor)
Expand source code
def train_step(self, states: Tensor, masks: Tensor, policies: Tensor, z: Tensor):
    p, v = self.network(states, masks)
    loggedp = torch.where(
        torch.isinf(p), torch.zeros_like(p), torch.log_softmax(p, dim=1)
    )
    loss = (z - v.view(-1)).square().mean() - (policies * loggedp).sum(dim=1).mean()

    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    if self.log_client is not None:
        self.log_client.log("Training/Loss", loss.detach().item())

    if (
        self.save_period > 0
        and perf_counter() - self.last_save_time > self.save_period
    ):
        self.save(states, masks)
        self.last_save_time = perf_counter()
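For reference, the loss computed in train_step is the AlphaZero objective over a batch of size B: mean squared value error plus the cross-entropy between the MCTS visit distribution pi and the network policy p,

    \mathcal{L} \;=\; \frac{1}{B}\sum_{i=1}^{B}\Big[\big(z_i - v_\theta(s_i)\big)^2 \;-\; \pi_i^{\top}\log p_\theta(\cdot \mid s_i)\Big]

The torch.where guard zeroes the log-probabilities of masked actions so that illegal moves (whose target probability is zero) do not produce NaNs. The L2 weight regularization of the original AlphaZero paper is not part of this loss; if desired, it would typically be added through the optimizer's weight_decay setting.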
class Logger
    Logging server for the LearnerWorker.
class MCTSConfig
    Configuration for the Monte Carlo Tree Search.
class MCTSNode (state: np.ndarray, action_mask: np.ndarray, simulator: simulators.Base, network: nn.Module, parent: MCTSNode = None, config: mcts.MCTSConfig = None, action: int = None, reward: float = None, terminal: bool = None)
    Node in the MCTS algorithm.
    Args
        state (np.ndarray): State of this node.
        action_mask (np.ndarray): Action mask of the state.
        simulator (Simulator): Simulator used in the rollout.
        network (nn.Module): Network used in the rollout.
        parent (MCTSNode, optional): Parent node. Defaults to None.
        config (MCTSConfig, optional): Configuration of the MCTS. Defaults to None.
        action (int, optional): Action that led to this node. Defaults to None.
        reward (float, optional): Reward obtained upon transitioning to this node. Defaults to None.
        terminal (bool, optional): Whether or not this node is in a terminal state. Defaults to None.
Instance variables
var action : Optional[int]
    The action that led to this node. If this node is the root, then this value is None.
Expand source code
@property
def action(self) -> Optional[int]:
    """The action that led to this node. If this node is the root, then this
    value is `None`."""
    return self._action
var action_mask : numpy.ndarray
    The action mask of this node.
Expand source code
@property
def action_mask(self) -> np.ndarray:
    """The action mask of this node."""
    return self._action_mask
var action_policy : numpy.ndarray
    Action policy calculated in this node.
Expand source code
@property
def action_policy(self) -> np.ndarray:
    """Action policy calculated in this node."""
    distribution = np.power(self._N, 1 / self._config.T)
    return distribution / np.sum(distribution)
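That is, the returned policy is the temperature-scaled visit-count distribution, with the temperature T taken from MCTSConfig:

    \pi(a) \;=\; \frac{N(a)^{1/T}}{\sum_{b} N(b)^{1/T}}

T = 1 reproduces the raw visit proportions, while smaller values of T sharpen the distribution towards the most visited action.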
var children : Optional[List[Optional[MCTSNode]]]
    The children of this node. If this node has not been expanded, then None is returned, otherwise the list of (possible) children is returned. If not None, the list consists of MCTSNodes on indices representing legal actions; illegal action indices are None. In other words, the child at index i corresponds to the node retrieved by action i.
Expand source code
@property
def children(self) -> Optional[List[Optional[MCTSNode]]]:
    """The children of this node. If this node has not been expanded, then
    `None` is returned, otherwise the list of (possible) children is returned.
    If not None, the list consists of `MCTSNode`s on indices representing
    legal actions. Illegal action indices are `None`. In other words, the
    child at index `i` corresponds to the node retrieved by action `i`.
    """
    return self._children
var is_leaf : bool
    True if this node is a leaf node.
Expand source code
@property
def is_leaf(self) -> bool:
    """True if this node is a leaf node."""
    return self._children is None
var is_root : bool
    True if this node is the root.
Expand source code
@property
def is_root(self) -> bool:
    """True if this node is the root."""
    return self._parent is None
var is_terminal : bool
    True if this node is a terminal state.
Expand source code
@property
def is_terminal(self) -> bool:
    """True if this node is a terminal state."""
    return self._terminal
var parent : Optional[MCTSNode]
    The parent of this node.
Expand source code
@property
def parent(self) -> Optional[MCTSNode]:
    """The parent of this node."""
    return self._parent
var reward : Optional[float]
    The reward obtained on the transition into this node. If this node is the root, then this value is None.
Expand source code
@property
def reward(self) -> Optional[float]:
    """The reward obtained on the transition into this node. If this node is
    the root, then this value is `None`."""
    return self._reward
var state : numpy.ndarray
    The state of this node.
Expand source code
@property
def state(self) -> np.ndarray:
    """The state of this node."""
    return self._state
Methods
def add_noise(self)
    Adds Dirichlet noise to the prior probability.
Expand source code
def add_noise(self):
    """Adds Dirichlet noise to the prior probability."""
    d = np.random.dirichlet(
        self._config.alpha * np.ones(self._action_mask.shape[0])[self._action_mask]
    )
    self._P[self._action_mask] = (1 - self._config.epsilon) * self._P[
        self._action_mask
    ] + self._config.epsilon * d
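Written out, the update above mixes the prior over the legal actions with Dirichlet noise, the standard AlphaZero root-exploration scheme, with alpha and epsilon taken from MCTSConfig:

    P(a) \;\leftarrow\; (1-\varepsilon)\,P(a) + \varepsilon\, d_a, \qquad d \sim \mathrm{Dir}(\alpha\,\mathbf{1})

Since rootify() re-applies this noise whenever a node is promoted to the root, exploration is injected at the root of each new search.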
def backup(self)
    Runs the backpropagation from this node up to the root.
Expand source code
def backup(self):
    """Runs the backpropagation from this node up to the root."""
    if self.is_root:
        return
    if self.is_terminal:
        self.parent._backpropagate(self._action, self._reward)
    elif self._config.zero_sum_game:
        self.parent._backpropagate(self._action, -self._V)
    else:
        self.parent._backpropagate(
            self.action, self._reward + self._config.discount_factor * self._V
        )
def expand(self)
    Expands the node. If the node has already been expanded, then this is a no-op.
Expand source code
def expand(self):
    """Expands the node. If the node has already been expanded, then this is a
    no-op.
    """
    if self._expanded:
        return

    self._init_pv()

    if not self.is_terminal:
        # Step all legal actions in one bulk call to the simulator.
        actions = np.arange(self._action_mask.shape[0])[self._action_mask]
        states = np.expand_dims(self._state, 0)
        states = np.repeat(states, actions.shape[0], axis=0)
        next_states, rewards, terminals, _ = self._simulator.step_bulk(
            states, actions
        )
        next_masks = (
            self._simulator.action_space
            .as_discrete
            .action_mask_bulk(next_states)
        )

        # Children are indexed by action; illegal actions stay None.
        self._children = [None] * self._action_mask.shape[0]
        for next_state, next_mask, reward, terminal, action in zip(
            next_states, next_masks, rewards, terminals, actions
        ):
            self._children[action] = MCTSNode(
                next_state,
                next_mask,
                self._simulator,
                self._network,
                parent=self,
                config=self._config,
                action=action,
                reward=reward,
                terminal=terminal,
            )

        self._N = np.zeros(self._action_mask.shape[0])
        self._W = np.zeros(self._action_mask.shape[0])

    self._expanded = True
def rootify(self)
    Converts this node to a root node, cutting ties with all parents, while maintaining its children.
Expand source code
def rootify(self):
    """Converts this node to a root node, cutting ties with all parents, while
    maintaining its children."""
    if self.is_terminal:
        raise ValueError("Cannot rootify a terminal state.")
    self._parent = None
    self._action = None
    self._reward = None
    self._terminal = False
    if self._P is not None:
        self.add_noise()
def select(self) -> MCTSNode
    Traverses one step in the tree from this node according to the selection policy.
    Raises
        ValueError: If this node is in a terminal state.
    Returns
        MCTSNode: The node selected.
Expand source code
def select(self) -> MCTSNode:
    """Traverses one step in the tree from this node according to the selection
    policy.

    Raises:
        ValueError: If this node is in a terminal state.

    Returns:
        MCTSNode: The node selected.
    """
    if self.is_terminal:
        raise ValueError("Cannot select action from terminal state.")

    Q = np.zeros_like(self._N)
    mask = self._N > 0
    Q[mask] = self._W[mask] / self._N[mask]

    if np.any(self._N > 0):
        U = self._P * np.sqrt(np.sum(self._N)) / (1 + self._N)
    else:
        U = self._P

    QU = Q + self._config.c * U
    QU[~self._action_mask] = -np.inf
    return self._children[np.argmax(QU)]
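The score maximized above is a PUCT-style selection rule, as used in AlphaZero, with the exploration constant c taken from MCTSConfig and illegal actions masked to -inf:

    a^{*} \;=\; \arg\max_{a}\Big[\, Q(a) \;+\; c\,P(a)\,\frac{\sqrt{\sum_{b} N(b)}}{1 + N(a)} \,\Big], \qquad Q(a) = \frac{W(a)}{N(a)}

where Q(a) is treated as 0 for unvisited actions, and before any child has been visited the exploration term falls back to the prior P alone (the else branch).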
class SelfPlayConfig
    Configuration for the self play worker.
class SelfPlayWorker (simulator: Factory, network: torch.nn.modules.module.Module, config: SelfPlayConfig, sample_queue: Queue, log_client: Client = None)
    Process objects represent activity that is run in a separate process. The class is analogous to threading.Thread.
Ancestors
- multiprocessing.context.Process
- multiprocessing.process.BaseProcess
Methods
def run(self) -> None
    Method to be run in sub-process; can be overridden in sub-class.
Expand source code
def run(self) -> None:
    while True:
        try:
            self.sample_queue.put(self.run_episode(), timeout=5)
        except Full:
            continue
def run_episode(self)
Expand source code
def run_episode(self):
    states, action_masks, action_policies = [], [], []
    terminal = False
    reward = 0

    state = self.simulator.reset()
    action_mask = self.simulator.action_space.as_discrete.action_mask(state)

    # Evaluate the raw network once at the start state, for logging only.
    with torch.no_grad():
        start_prior, start_value = self.network(
            torch.as_tensor(state, dtype=torch.float).unsqueeze_(0),
            torch.as_tensor(action_mask).unsqueeze_(0),
        )
        start_value = start_value.item()
        start_prior = start_prior.softmax(dim=1).squeeze_(0)

    first_action_policy = None
    first_action = None
    root = None
    while not terminal:
        # Rerun the search from the current state, reusing the subtree kept
        # when descending into root.children[action] below.
        root = mcts(
            state,
            action_mask,
            self.simulator,
            self.network,
            self.config,
            root_node=root,
        )
        action_policy = root.action_policy
        if first_action_policy is None:
            first_action_policy = action_policy

        states.append(state)
        action_masks.append(action_mask)
        action_policies.append(torch.as_tensor(action_policy, dtype=torch.float))

        # Sample the move from the visit-count distribution.
        action = np.random.choice(
            self.simulator.action_space.as_discrete.size, p=action_policy
        )
        if first_action is None:
            first_action = action

        state, reward, terminal, _ = self.simulator.step(state, action)
        action_mask = self.simulator.action_space.as_discrete.action_mask(state)
        root = root.children[action]

    if self.log_client is not None:
        kl_div = -(
            first_action_policy * torch.log_softmax(start_prior, dim=0).numpy()
        ).sum()
        self.log_client.log("Episode/Reward", reward)
        self.log_client.log("Episode/Start value", start_value)
        self.log_client.log("Episode/Start KL Div", kl_div)
        self.log_client.log("Episode/First action", first_action)

    states = torch.as_tensor(np.stack(states), dtype=torch.float)
    action_masks = torch.as_tensor(np.stack(action_masks))
    action_policies = torch.as_tensor(np.stack(action_policies), dtype=torch.float)

    # Value targets: the final reward, with the sign alternating between
    # consecutive stored states for the zero-sum setting.
    z = torch.ones(states.shape[0])
    i = torch.arange(1, states.shape[0] + 1, 2)
    j = torch.arange(2, states.shape[0] + 1, 2)
    z[-i] *= reward
    z[-j] *= -reward
    return states, action_masks, action_policies, z
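To make the target construction at the end of run_episode concrete, here is a tiny standalone reproduction of the z computation for a hypothetical five-move episode that ends with reward = 1:

import torch

reward = 1.0     # final reward of the episode (hypothetical value)
num_states = 5   # number of stored states in the episode

z = torch.ones(num_states)
i = torch.arange(1, num_states + 1, 2)   # offsets 1, 3, 5 from the end
j = torch.arange(2, num_states + 1, 2)   # offsets 2, 4 from the end
z[-i] *= reward
z[-j] *= -reward
print(z)  # 1., -1., 1., -1., 1.  (last stored state gets `reward`)

Because the two players alternate moves, the sign of the outcome alternates across consecutive stored states, so each state's value target reflects the final result for the side whose turn it was.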