Module ai.rl.dqn.rainbow.networks

Predefined networks for RainbowDQN.

"""Predefined networks for RainbowDQN."""


from ._cart_pole import CartPole


__all__ = ["CartPole"]

Classes

class CartPole (use_distributional: bool = False, n_atoms: int = None, std_init: float = 0.5)

Example implementation of a network for the CartPole environment.

Args

use_distributional : bool, optional
If True, the network outputs one probability distribution over n_atoms atoms per action, as required by distributional DQN. Defaults to False.
n_atoms : int, optional
Number of atoms to use in distributional DQN. Required if use_distributional is True; has no effect otherwise.
std_init : float, optional
Initial standard deviation of noise in NoisyNets. Defaults to 0.5. Set to 0.0 to disable NoisyNets.
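To illustrate what std_init typically controls, here is a minimal sketch of a factorised-Gaussian NoisyNet linear layer. NoisyLinearSketch is a hypothetical class, not part of this module, and the noisy layers CartPole actually uses may initialise or resample their noise differently. Following the documented behaviour, the sketch treats std_init=0.0 as "no noise", in which case it acts as an ordinary linear layer.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class NoisyLinearSketch(nn.Module):
    """Hypothetical factorised-Gaussian noisy linear layer (NoisyNets).

    With std_init == 0.0 no noise parameters are created and the layer
    behaves like an ordinary linear layer.
    """

    def __init__(self, in_features: int, out_features: int, std_init: float = 0.5):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.noisy = std_init > 0.0
        bound = 1.0 / math.sqrt(in_features)
        # Learnable mean weights and biases.
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features).uniform_(-bound, bound))
        self.bias_mu = nn.Parameter(torch.empty(out_features).uniform_(-bound, bound))
        if self.noisy:
            # Learnable noise scales, initialised to std_init / sqrt(fan_in).
            sigma = std_init / math.sqrt(in_features)
            self.weight_sigma = nn.Parameter(torch.full((out_features, in_features), sigma))
            self.bias_sigma = nn.Parameter(torch.full((out_features,), sigma))

    def _scaled_noise(self, size: int) -> torch.Tensor:
        # f(e) = sign(e) * sqrt(|e|), applied to standard normal samples.
        e = torch.randn(size, device=self.weight_mu.device)
        return e.sign() * e.abs().sqrt()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.noisy:
            return F.linear(x, self.weight_mu, self.bias_mu)
        # Factorised noise: outer product of output-side and input-side vectors.
        eps_in = self._scaled_noise(self.in_features)
        eps_out = self._scaled_noise(self.out_features)
        weight = self.weight_mu + self.weight_sigma * torch.outer(eps_out, eps_in)
        bias = self.bias_mu + self.bias_sigma * eps_out
        return F.linear(x, weight, bias)


torch.manual_seed(0)
layer = NoisyLinearSketch(4, 2, std_init=0.5)
out = layer(torch.ones(3, 4))  # shape (3, 2); changes from call to call
```

For brevity the sketch draws fresh noise on every forward call; NoisyNet implementations often resample noise explicitly instead, for example once per training step.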

Ancestors

  • torch.nn.modules.module.Module

Class variables

var dump_patches : bool
var training : bool

Methods

def forward(self, x) -> torch.Tensor

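Computes the network output for a batch of observations x.

The input is passed through the network body. If use_distributional is True, the result is reshaped to (batch, 2, n_atoms), one row per CartPole action, and normalised with a softmax over the atom dimension; otherwise the body's output is returned unchanged.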

Note

Although the forward pass is defined in this method, call the module instance itself (e.g. model(x)) rather than model.forward(x): calling the instance runs any registered hooks, whereas calling forward directly silently skips them.
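The difference can be seen with any Module. The sketch below uses a plain nn.Linear as a stand-in for the CartPole network and counts how often a forward hook fires.

```python
import torch
import torch.nn as nn

# Stand-in for the CartPole network: every nn.Module behaves the same way here.
net = nn.Linear(4, 2)
calls = []

# A forward hook runs after forward() when the module instance is called.
net.register_forward_hook(lambda module, inputs, output: calls.append(tuple(output.shape)))

x = torch.zeros(1, 4)
net(x)          # runs forward() and then the hook
net.forward(x)  # runs forward() only; the hook is skipped
print(len(calls))  # 1
```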

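def forward(self, x):
    x = self._body(x)
    if self._use_distributional:
        # One softmax over the atoms for each of CartPole's two actions;
        # output shape is (batch, 2, n_atoms).
        x = torch.softmax(x.view(-1, 2, self._n_atoms), dim=-1)
    return x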
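With use_distributional=True, forward returns per-action atom probabilities rather than Q-values. A common way to recover Q-values, as in C51, is the expectation Q(s, a) = sum_i z_i * p_i(s, a) over an evenly spaced support z. The sketch below assumes such a support with hypothetical bounds v_min and v_max; neither is a parameter of CartPole, and the support actually used by this package's agent is not documented here.

```python
import torch


def expected_q_values(probs: torch.Tensor, v_min: float, v_max: float) -> torch.Tensor:
    """Collapse per-action atom probabilities into scalar Q-values.

    probs has shape (batch, n_actions, n_atoms), like the output of forward()
    with use_distributional=True. v_min and v_max are assumed support bounds.
    """
    n_atoms = probs.shape[-1]
    # Evenly spaced atom values z_1, ..., z_N on [v_min, v_max].
    support = torch.linspace(v_min, v_max, n_atoms, dtype=probs.dtype, device=probs.device)
    # Q(s, a) = sum_i z_i * p_i(s, a)
    return (probs * support).sum(dim=-1)


# Stand-in for the body's raw output: batch of 1, 2 actions, 3 atoms.
logits = torch.tensor([[[0.0, 0.0, 0.0], [0.0, 0.0, 10.0]]])
probs = torch.softmax(logits, dim=-1)  # same normalisation as forward()
q = expected_q_values(probs, v_min=-1.0, v_max=1.0)
action = q.argmax(dim=-1)  # greedy action for each batch element
```

Here action 0 has a uniform distribution over {-1, 0, 1}, so its Q-value is 0, while action 1 puts almost all mass on the atom at 1, so the greedy action is 1.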