Module ai.rl.a3c
Asynchronous Advantage Actor Critic (A3C).
Source code
"""Asynchronous Advantage Actor Critic (A3C)."""
from ._agent import Agent
from . import trainer
__all__ = ["trainer", "Agent"]
Sub-modules
ai.rl.a3c.trainer
    Trainer module for A3C.
Classes
class Agent(network: torch.nn.modules.module.Module, state_dtype: torch.dtype = torch.float32)

A3C agent. This agent wraps a network to be used for inference.
Args
    network : nn.Module
        Network with two output heads, policy logits and state value.
    state_dtype : torch.dtype, optional
        Data type of states to be fed into the network. Defaults to torch.float32.
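For illustration, a minimal sketch of constructing an agent. The TwoHeadedNet below and its layer sizes are hypothetical; the only assumption taken from this page is that the network's forward pass returns policy logits and a state value.

import torch
from torch import nn

from ai.rl.a3c import Agent


class TwoHeadedNet(nn.Module):
    # Hypothetical network: one shared body, a policy-logit head and a value head.
    def __init__(self, n_states: int, n_actions: int):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(n_states, 64), nn.ReLU())
        self.policy_head = nn.Linear(64, n_actions)  # policy logits
        self.value_head = nn.Linear(64, 1)           # state value

    def forward(self, states):
        x = self.body(states)
        return self.policy_head(x), self.value_head(x)


agent = Agent(TwoHeadedNet(n_states=4, n_actions=2), state_dtype=torch.float32)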
Methods
def act(self, state: Union[numpy.ndarray, torch.Tensor], action_mask: Union[numpy.ndarray, torch.Tensor]) -> int

Returns an action, given the state and action mask.
Args
    state : Union[np.ndarray, torch.Tensor]
        State.
    action_mask : Union[np.ndarray, torch.Tensor]
        Action mask, indicating legal actions.

Returns
    int
        Action index.
Source code
def act(self, state: Union[np.ndarray, torch.Tensor], action_mask: Union[np.ndarray, torch.Tensor]) -> int:
    """Returns an action, given the state and action mask.

    Args:
        state (Union[np.ndarray, torch.Tensor]): State.
        action_mask (Union[np.ndarray, torch.Tensor]): Action mask, indicating legal actions.

    Returns:
        int: Action index.
    """
    state = torch.as_tensor(state, dtype=self._state_dtype)
    action_mask = torch.as_tensor(action_mask, dtype=torch.bool)
    return self.act_bulk(state.unsqueeze(0), action_mask.unsqueeze(0)).item()
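As a usage sketch, a single unbatched call might look like the following; the state shape and mask values are made up and assume the hypothetical agent constructed above.

import numpy as np

state = np.zeros(4, dtype=np.float32)      # one unbatched state, matching the network input size
action_mask = np.array([True, False])      # only action 0 is legal
action = agent.act(state, action_mask)     # masked sampling; here this can only return 0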
def act_bulk(self, states: Union[numpy.ndarray, torch.Tensor], action_masks: Union[numpy.ndarray, torch.Tensor]) -> torch.Tensor

Returns a set of actions for the given states and action masks.
Args
    states : Union[np.ndarray, torch.Tensor]
        States.
    action_masks : Union[np.ndarray, torch.Tensor]
        Action masks, indicating legal actions.

Returns
    torch.Tensor
        Tensor containing action indices; data type is torch.long.
Source code
def act_bulk(self, states: Union[np.ndarray, torch.Tensor], action_masks: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
    """Returns a set of actions for the given states and action masks.

    Args:
        states (Union[np.ndarray, torch.Tensor]): States.
        action_masks (Union[np.ndarray, torch.Tensor]): Action masks, indicating legal actions.

    Returns:
        torch.Tensor: Tensor containing action indices, data type is `torch.long`.
    """
    states = torch.as_tensor(states, dtype=self._state_dtype)
    action_masks = torch.as_tensor(action_masks, dtype=torch.bool)
    with torch.no_grad():
        p, _ = self._network(states)
    p[~action_masks] = -float("inf")
    p = torch.softmax(p, dim=1)
    return ai.utils.torch.random.choice(p)
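And a batched sketch under the same assumptions, where an action is sampled independently for each state in the batch.

import numpy as np

states = np.zeros((3, 4), dtype=np.float32)     # batch of three states
action_masks = np.ones((3, 2), dtype=bool)      # every action legal in every state
actions = agent.act_bulk(states, action_masks)  # torch.long tensor, one sampled action per state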