Module ai.rl.dqn.rainbow.networks
Predefined networks for RainbowDQN.
```python
"""Pre defined networks for RainbowDQN."""
from ._cart_pole import CartPole

__all__ = ["CartPole"]
```
Classes
class CartPole (use_distributional: bool = False, n_atoms: int = None, std_init: float = 0.5)
Example implementation of a network for the CartPole environment.
Args
use_distributional : bool, optional
    If True, the network produces output in the format required by distributional DQN. Defaults to False.
n_atoms : int, optional
    Number of atoms to use in distributional DQN. Required if distributional DQN is used; otherwise it has no effect.
std_init : float, optional
    Initial standard deviation of the noise in NoisyNets. Defaults to 0.5. Set to 0.0 to disable NoisyNets.
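The std_init parameter sets the initial noise scale of the NoisyNets layers. The following is a framework-free sketch, not this package's implementation: the class name NoisyLinearSketch, the seed argument, and scaling every sigma by 1/sqrt(in_features) are assumptions that follow the factorized Gaussian noise scheme from the NoisyNets paper. It illustrates why std_init = 0.0 disables the noise: every sigma starts at zero, so the layer reduces to an ordinary linear layer.

```python
import math
import random


def _scale_noise(x: float) -> float:
    """Noise transform f(x) = sgn(x) * sqrt(|x|) used for factorized noise."""
    return math.copysign(math.sqrt(abs(x)), x)


class NoisyLinearSketch:
    """Hypothetical noisy linear layer with factorized Gaussian noise (pure Python)."""

    def __init__(self, in_features: int, out_features: int,
                 std_init: float = 0.5, seed: int = 0):
        self.in_features = in_features
        self.out_features = out_features
        self.rng = random.Random(seed)
        bound = 1.0 / math.sqrt(in_features)
        # Means of weights and biases, initialized uniformly in [-1/sqrt(p), 1/sqrt(p)].
        self.w_mu = [[self.rng.uniform(-bound, bound) for _ in range(in_features)]
                     for _ in range(out_features)]
        self.b_mu = [self.rng.uniform(-bound, bound) for _ in range(out_features)]
        # Noise scales (assumed scaling); std_init = 0.0 makes every sigma zero.
        sigma = std_init / math.sqrt(in_features)
        self.w_sigma = [[sigma] * in_features for _ in range(out_features)]
        self.b_sigma = [sigma] * out_features
        self.reset_noise()

    def reset_noise(self) -> None:
        # Factorized noise: one Gaussian draw per input unit and per output unit.
        eps_in = [_scale_noise(self.rng.gauss(0.0, 1.0)) for _ in range(self.in_features)]
        eps_out = [_scale_noise(self.rng.gauss(0.0, 1.0)) for _ in range(self.out_features)]
        self.w_eps = [[eo * ei for ei in eps_in] for eo in eps_out]
        self.b_eps = eps_out

    def __call__(self, x):
        # y = (w_mu + w_sigma * w_eps) x + b_mu + b_sigma * b_eps
        out = []
        for j in range(self.out_features):
            acc = self.b_mu[j] + self.b_sigma[j] * self.b_eps[j]
            for i in range(self.in_features):
                w = self.w_mu[j][i] + self.w_sigma[j][i] * self.w_eps[j][i]
                acc += w * x[i]
            out.append(acc)
        return out

    def mean_output(self, x):
        # Noise-free output that uses only the means.
        return [self.b_mu[j] + sum(self.w_mu[j][i] * x[i] for i in range(self.in_features))
                for j in range(self.out_features)]


layer = NoisyLinearSketch(in_features=4, out_features=2, std_init=0.0)
x = [0.1, -0.2, 0.3, 0.05]
print(layer(x), layer.mean_output(x))
```

In a training loop, reset_noise() would typically be called once per step, so exploration comes from the sampled weights rather than from epsilon-greedy action selection.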
Ancestors
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
Methods
def forward(self, x) -> Callable[..., Any]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
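```python
def forward(self, x):
    # Pass the observation through the shared network body.
    x = self._body(x)
    if self._use_distributional:
        # Reshape to (batch, 2, n_atoms); the 2 is CartPole's number of actions.
        # Softmax over the last dimension yields a distribution over atoms per action.
        x = torch.softmax(x.view(-1, 2, self._n_atoms), dim=-1)
    return x
```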
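With use_distributional enabled, forward returns, for each sample, one probability distribution over n_atoms atoms per action. A consumer would typically turn these into Q-values by taking the expectation over a fixed support of return values. The following framework-free sketch mimics the reshape and softmax for a single sample and then computes that expectation. The helper names, the evenly spaced support, and the bounds v_min = -10 and v_max = 10 (the values used in the C51 paper) are assumptions, not taken from this package.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def distributional_head(flat_logits, n_actions, n_atoms):
    """Mimic x.view(-1, n_actions, n_atoms) followed by softmax(dim=-1) for one sample."""
    assert len(flat_logits) == n_actions * n_atoms
    # Row-major reshape: action a owns elements [a * n_atoms, (a + 1) * n_atoms).
    rows = [flat_logits[a * n_atoms:(a + 1) * n_atoms] for a in range(n_actions)]
    return [softmax(row) for row in rows]


def support(v_min, v_max, n_atoms):
    """Evenly spaced atom values z_0 < ... < z_{n_atoms - 1} (requires n_atoms >= 2)."""
    step = (v_max - v_min) / (n_atoms - 1)
    return [v_min + i * step for i in range(n_atoms)]


def q_values(per_action_probs, v_min=-10.0, v_max=10.0):
    """Expected return per action: Q(a) = sum_i z_i * p_i(a)."""
    z = support(v_min, v_max, len(per_action_probs[0]))
    return [sum(zi * pi for zi, pi in zip(z, probs)) for probs in per_action_probs]


def greedy_action(per_action_probs, v_min=-10.0, v_max=10.0):
    """Index of the action with the highest expected return."""
    qs = q_values(per_action_probs, v_min, v_max)
    return max(range(len(qs)), key=qs.__getitem__)


# Two actions, five atoms: action 1 puts most of its mass on the highest atom.
flat = [0.0] * 5 + [0.0, 0.0, 0.0, 0.0, 5.0]
probs = distributional_head(flat, n_actions=2, n_atoms=5)
print(q_values(probs), greedy_action(probs))
```

Because forward already applies the softmax, code consuming its output would start from the probabilities (the result of distributional_head here) and not normalize them a second time.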