Module ai.rl.alpha_zero.networks

Example implementations of networks for some simulators.

"""Example implementations of networks for some simulators."""


from ._tictactoe import TicTacToeNetwork
from ._connect_four import ConnectFourNetwork


__all__ = ["TicTacToeNetwork", "ConnectFourNetwork"]

Classes

class ConnectFourNetwork

Example network for the Connect Four simulator.

Ancestors

  • torch.nn.modules.module.Module

Class variables

var dump_patches : bool
var training : bool

Methods

def forward(self, states, action_masks) ‑> Callable[..., Any]

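Defines the computation performed at every call.

Returns a pair (p, v): p holds the policy logits, one row per state, with entries where action_masks is False set to -inf, and v holds the value estimates with shape (N, 1).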

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of calling this function directly, since the former takes care of running the registered hooks while the latter silently ignores them.
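To illustrate the note, here is a minimal pure-Python sketch of the call/hook mechanism. ToyModule is a hypothetical, heavily simplified stand-in written for this example; it is not PyTorch's actual implementation, and it only mirrors the general idea that calling the instance runs forward and then any registered forward hooks.

```python
class ToyModule:
    """Hypothetical, simplified stand-in for the call/hook mechanism."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # Store a callable that is invoked as hook(module, inputs, output).
        self._forward_hooks.append(hook)

    def forward(self, x):
        # The "recipe" of the computation.
        return 2 * x

    def __call__(self, *args):
        # Calling the instance runs forward and then every registered hook.
        output = self.forward(*args)
        for hook in self._forward_hooks:
            hook(self, args, output)
        return output


seen = []
module = ToyModule()
module.register_forward_hook(lambda mod, inputs, output: seen.append(output))

via_call = module(3)             # runs forward, then the hook
via_forward = module.forward(3)  # runs forward only; the hook is skipped
print(via_call, via_forward, seen)  # 6 6 [6]
```

Both calls return the same value, but only the first one is recorded by the hook, which is why the instance should be called rather than forward.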

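def forward(self, states, action_masks):
    N = states.shape[0]  # batch size
    # The last entry of each state is the player indicator.
    player = states[:, -1].view(N, 1, 1, 1)
    # The remaining 42 entries form the 6x7 board, scaled by the player indicator.
    states = states[:, :-1].view(N, 1, 6, 7) * player
    x = self.body(states)
    v = self.value(x).view(N, 1)
    p = self.policy(x).view(N, -1)
    # Set the logits to -inf wherever action_masks is False.
    p[~action_masks] = -float("inf")
    return p, v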
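
The forward pass above returns raw logits in which masked-out actions are set to -inf. The source does not show what happens to these logits afterwards; a plausible reason for the masking is that a subsequent softmax then assigns those actions exactly zero probability. The following pure-Python sketch illustrates that effect. The function masked_softmax and the sample values are hypothetical and are not part of this module.

```python
import math


def masked_softmax(logits, action_mask):
    """Softmax over logits, with entries where action_mask is False set to -inf."""
    masked = [x if legal else -math.inf for x, legal in zip(logits, action_mask)]
    peak = max(masked)  # assumes at least one action is legal
    exps = [math.exp(x - peak) for x in masked]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


# Seven logits, one per Connect Four column (illustrative values only).
logits = [0.5, 1.5, -0.3, 2.0, 0.0, 0.1, 1.0]
action_mask = [True, True, False, True, True, False, True]

probs = masked_softmax(logits, action_mask)
print(probs)
```

The masked entries (indices 2 and 5 here) come out as exactly 0.0, and the remaining probabilities sum to 1.
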

class TicTacToeNetwork

Example implementation of a network for the TicTacToe simulator.

Ancestors

  • torch.nn.modules.module.Module

Class variables

var dump_patches : bool
var training : bool

Methods

def forward(self, states, action_masks) ‑> Callable[..., Any]

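Defines the computation performed at every call.

Returns a pair (p, v): p holds the policy logits, one row per state, with entries where action_masks is False set to -inf, and v holds the value estimates.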

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of calling this function directly, since the former takes care of running the registered hooks while the latter silently ignores them.

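def forward(self, states, action_masks):
    N = states.shape[0]  # batch size
    # The last entry of each state is the player indicator.
    player = states[:, -1].view(N, 1, 1, 1)
    # The remaining 9 entries form the 3x3 board, scaled by the player indicator.
    states = states[:, :-1].view(N, 1, 3, 3) * player
    x = self.body(states).view(N, -1)
    v = self.value(x)
    p = self.policy(x)
    # Set the logits to -inf wherever action_masks is False.
    p[~action_masks] = -float("inf")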
    return p, v
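
The state layout used by the forward pass (board cells followed by a single player entry) can be illustrated without tensors. This is a minimal sketch under assumptions the source does not confirm: cells are taken to be encoded as +1, -1, or 0, and the player entry as +1 or -1, so that multiplying by the player presents the board from the perspective of the player to move. The helper to_player_perspective is hypothetical.

```python
def to_player_perspective(state, rows=3, cols=3):
    """Split a flat state into cells and player, scale, and reshape to rows x cols."""
    *cells, player = state  # the last entry is the player indicator
    if len(cells) != rows * cols:
        raise ValueError("state must hold rows * cols cells plus one player entry")
    scaled = [cell * player for cell in cells]
    return [scaled[r * cols:(r + 1) * cols] for r in range(rows)]


# Assumed encoding: +1 and -1 mark the two players, 0 marks an empty cell.
state = [1, -1, 0,
         0, 1, 0,
         -1, 0, 0,
         -1]  # player to move is -1

board = to_player_perspective(state)
print(board)  # [[-1, 1, 0], [0, -1, 0], [1, 0, 0]]
```

With rows=6 and cols=7 the same helper mirrors the reshaping in the Connect Four forward pass.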