Module ai.rl.alpha_zero.networks
Example implementations of networks for some simulators.
Expand source code
"""Example implementations of networks for some simulators.

This package re-exports two example networks:

- ``TicTacToeNetwork``: example network for the TicTacToe simulator.
- ``ConnectFourNetwork``: example network for the Connect Four simulator.
"""
from ._tictactoe import TicTacToeNetwork
from ._connect_four import ConnectFourNetwork
# Public API of ``ai.rl.alpha_zero.networks``: the two classes imported above.
__all__ = ["TicTacToeNetwork", "ConnectFourNetwork"]
Classes
class ConnectFourNetwork
-
Example network for the Connect Four simulator.
Ancestors
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
Methods
def forward(self, states, action_masks)
Returns a `(p, v)` tuple: the policy head's outputs (illegal actions set to `-inf`) and the value head's outputs.
-
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
Expand source code
def forward(self, states, action_masks):
    """Compute masked policy outputs and value outputs for a batch of states.

    Args:
        states: Batch of flat states. The last column is a per-sample player
            indicator; the remaining 42 columns are reshaped to a 1x6x7 board.
        action_masks: Legal-action mask; positions where it is false get
            ``-inf`` in the policy output. Presumably shaped ``(N, 7)`` (it
            must broadcast against the flattened policy output). Integer 0/1
            masks are accepted and converted to bool.

    Returns:
        A tuple ``(p, v)``: ``p`` is the policy head's output flattened to
        ``(N, -1)`` with illegal actions set to ``-inf``; ``v`` is the value
        head's output reshaped to ``(N, 1)``.
    """
    N = states.shape[0]
    # Player indicator, shaped to broadcast over the (N, 1, 6, 7) board.
    player = states[:, -1].view(N, 1, 1, 1)
    # Scale the board by the player value (presumably +/-1, which would put
    # the board in the mover's perspective -- confirm against the simulator).
    board = states[:, :-1].view(N, 1, 6, 7) * player
    x = self.body(board)
    v = self.value(x).view(N, 1)
    p = self.policy(x).view(N, -1)
    # Out-of-place masking: never mutates the policy head's output (which
    # autograd may have saved, or which may alias another tensor), and
    # ``.bool()`` keeps ``~`` a logical NOT even for 0/1 integer masks.
    p = p.masked_fill(~action_masks.bool(), float("-inf"))
    return p, v
class TicTacToeNetwork
-
Example implementation of a network for the TicTacToe simulator.
Ancestors
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
Methods
def forward(self, states, action_masks)
Returns a `(p, v)` tuple: the policy head's outputs (illegal actions set to `-inf`) and the value head's outputs.
-
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
Expand source code
def forward(self, states, action_masks):
    """Compute masked policy outputs and value outputs for a batch of states.

    Args:
        states: Batch of flat states. The last column is a per-sample player
            indicator; the remaining 9 columns are reshaped to a 1x3x3 board.
        action_masks: Legal-action mask; positions where it is false get
            ``-inf`` in the policy output. Presumably shaped ``(N, 9)`` (it
            must broadcast against the policy output). Integer 0/1 masks are
            accepted and converted to bool.

    Returns:
        A tuple ``(p, v)`` of the policy head's output (illegal actions set
        to ``-inf``) and the value head's output, both computed from the
        body's features flattened to ``(N, -1)``.
    """
    N = states.shape[0]
    # Player indicator, shaped to broadcast over the (N, 1, 3, 3) board.
    player = states[:, -1].view(N, 1, 1, 1)
    # Scale the board by the player value (presumably +/-1, which would put
    # the board in the mover's perspective -- confirm against the simulator).
    board = states[:, :-1].view(N, 1, 3, 3) * player
    x = self.body(board).view(N, -1)
    v = self.value(x)
    p = self.policy(x)
    # Out-of-place masking: never mutates the policy head's output (which
    # autograd may have saved, or which may alias another tensor), and
    # ``.bool()`` keeps ``~`` a logical NOT even for 0/1 integer masks.
    p = p.masked_fill(~action_masks.bool(), float("-inf"))
    return p, v