Module ai.rl.dqn.rainbow
Rainbow DQN.
Expand source code
"""Rainbow DQN."""
from ._agent import Agent
from ._agent_config import AgentConfig
from . import networks, trainers
__all__ = ["Agent", "AgentConfig", "networks", "trainers"]
Sub-modules
ai.rl.dqn.rainbow.networks
    Predefined networks for RainbowDQN.
ai.rl.dqn.rainbow.trainers
    RainbowDQN trainers.
Classes
class Agent (config: AgentConfig, network: Factory[torch.nn.modules.module.Module], optimizer: Factory[torch.optim.optimizer.Optimizer] = None, inference_mode: bool = False, replay_init_lazily: bool = True)
    RainbowDQN Agent. A construction sketch follows the argument list.
Args
config : AgentConfig
    Agent configuration.
network : Factory[nn.Module]
    Network wrapped in a Factory.
optimizer : Factory[optim.Optimizer]
    Optimizer, wrapped in a Factory. Model parameters are passed to the optimizer when it is instantiated.
inference_mode : bool, optional
    If True, the agent can only be used for acting. Saves memory by not initializing a replay buffer. Defaults to False.
replay_init_lazily : bool, optional
    If True, the replay buffer is initialized lazily, i.e. when the first observation is added. Defaults to True.
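A hypothetical construction sketch. The exact Factory call pattern is not documented on this page; the calls below assume a Factory that stores a constructor plus arguments and instantiates it on demand. MyQNetwork, state_dim, action_dim and the no-argument AgentConfig constructor are placeholders, not part of this library.

import torch.optim as optim

# Assumed Factory usage: constructor plus arguments, built when instanced.
network_factory = Factory(MyQNetwork, state_dim, action_dim)
# Per the docs above, model parameters are passed to the optimizer when it is instanced.
optimizer_factory = Factory(optim.Adam, lr=1e-4)

config = AgentConfig()  # assumed default constructor; configure fields as needed
agent = Agent(config, network_factory, optimizer_factory)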
Instance variables
var config : AgentConfig
    Agent configuration in use.
Expand source code
@property
def config(self) -> AgentConfig:
    """Agent configuration in use."""
    return self._config
var model_factory : Factory[torch.nn.modules.module.Module]
    Model factory used by the agent.
Expand source code
@property
def model_factory(self) -> Factory[nn.Module]:
    """Model factory used by the agent."""
    return self._network_factory
var model_instance : torch.nn.modules.module.Module
    Model instance.
Expand source code
@property
def model_instance(self) -> nn.Module:
    """Model instance."""
    return self._network
Methods
def act(self, states: Union[numpy.ndarray, torch.Tensor], action_masks: Union[numpy.ndarray, torch.Tensor]) ‑> torch.Tensor
    Returns the greedy action for the given states and action masks.
Args
states : Union[Tensor, ndarray]
    States.
action_masks : Union[Tensor, ndarray]
    Action masks.
Returns
Tensor
    Tensor of dtype torch.long.
Expand source code
def act(
    self, states: Union[Tensor, ndarray], action_masks: Union[Tensor, ndarray]
) -> Tensor:
    """Returns the greedy action for the given states and action masks.

    Args:
        states (Union[Tensor, ndarray]): States.
        action_masks (Union[Tensor, ndarray]): Action masks.

    Returns:
        Tensor: Tensor of dtype `torch.long`.
    """
    with torch.no_grad():
        return _get_actions(
            torch.as_tensor(
                action_masks, dtype=torch.bool, device=self._config.network_device
            ),
            self._network(
                torch.as_tensor(
                    states, dtype=torch.float32, device=self._config.network_device
                )
            ),
            self._config.use_distributional,
            self._z,
        ).cpu()
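A usage sketch for batched greedy action selection. Here agent is assumed to be a constructed Agent, and the state and action dimensions are placeholders:

import numpy as np

states = np.random.randn(4, 8).astype(np.float32)  # batch of 4 states, 8 features each
action_masks = np.ones((4, 3), dtype=bool)         # 3 actions, all legal in this sketch

actions = agent.act(states, action_masks)          # torch.long tensor of shape (4,)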
def act_single(self, state: Union[numpy.ndarray, torch.Tensor], action_mask: Union[numpy.ndarray, torch.Tensor]) ‑> int
    Returns the greedy action for one state-action mask pair.
Args
state : Union[Tensor, ndarray]
    State.
action_mask : Union[Tensor, ndarray]
    Action mask.
Returns
int
    Action index.
Expand source code
def act_single(
    self, state: Union[Tensor, ndarray], action_mask: Union[Tensor, ndarray]
) -> int:
    """Returns the greedy action for one state-action mask pair.

    Args:
        state (Union[Tensor, ndarray]): State.
        action_mask (Union[Tensor, ndarray]): Action mask.

    Returns:
        int: Action index.
    """
    return self.act(
        torch.as_tensor(
            state, dtype=torch.float32, device=self._config.network_device
        ).unsqueeze_(0),
        torch.as_tensor(
            action_mask, dtype=torch.bool, device=self._config.network_device
        ).unsqueeze_(0),
    )[0]
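A single-step sketch using a hypothetical gym-style environment; env, n_actions and agent are placeholders:

import numpy as np

state = env.reset()
action_mask = np.ones(n_actions, dtype=bool)   # all actions legal in this sketch
action = agent.act_single(state, action_mask)  # plain Python int
next_state, reward, done, _ = env.step(action)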
def buffer_size(self) ‑> int
Returns
int
    The current size of the replay buffer.
Expand source code
def buffer_size(self) -> int:
    """
    Returns:
        int: The current size of the replay buffer.
    """
    return 0 if self._buffer is None else self._buffer.size
def inference_mode(self) ‑> Agent
    Returns a copy of this agent, but in inference mode.
Returns
Agent
    A shallow copy of the agent, capable of inference only.
Expand source code
def inference_mode(self) -> "Agent":
    """Returns a copy of this agent, but in inference mode.

    Returns:
        Agent: A shallow copy of the agent, capable of inference only."""
    return Agent(self._config, self._network, inference_mode=True)
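A sketch of the intended use: hand out inference-only copies, e.g. to actor workers, so they do not allocate a replay buffer. state and action_mask are as in the act_single sketch above:

actor = agent.inference_mode()                 # shares the network, no replay buffer
action = actor.act_single(state, action_mask)  # acting works as before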
def observe(self, states: Union[numpy.ndarray, torch.Tensor], actions: Union[numpy.ndarray, torch.Tensor], rewards: Union[numpy.ndarray, torch.Tensor], terminals: Union[numpy.ndarray, torch.Tensor], next_states: Union[numpy.ndarray, torch.Tensor], next_action_masks: Union[numpy.ndarray, torch.Tensor], errors: Union[numpy.ndarray, torch.Tensor])
    Adds a batch of experiences to the replay buffer.
Args
states : Union[Tensor, ndarray]
    States.
actions : Union[Tensor, ndarray]
    Actions.
rewards : Union[Tensor, ndarray]
    Rewards.
terminals : Union[Tensor, ndarray]
    Terminal flags.
next_states : Union[Tensor, ndarray]
    Next states.
next_action_masks : Union[Tensor, ndarray]
    Next action masks.
errors : Union[Tensor, ndarray]
    TD errors. NaN values are replaced by an appropriate initialization value.
Expand source code
def observe(
    self,
    states: Union[Tensor, ndarray],
    actions: Union[Tensor, ndarray],
    rewards: Union[Tensor, ndarray],
    terminals: Union[Tensor, ndarray],
    next_states: Union[Tensor, ndarray],
    next_action_masks: Union[Tensor, ndarray],
    errors: Union[Tensor, ndarray],
):
    """Adds a batch of experiences to the replay

    Args:
        states (Union[Tensor, ndarray]): States
        actions (Union[Tensor, ndarray]): Actions
        rewards (Union[Tensor, ndarray]): Rewards
        terminals (Union[Tensor, ndarray]): Terminal flags
        next_states (Union[Tensor, ndarray]): Next states
        next_action_masks (Union[Tensor, ndarray]): Next action masks
        errors (Union[Tensor, ndarray]): TD errors. NaN values are replaced by
            appropriate initialization value.
    """
    if self._buffer is None:
        self._initialize_replay(self._config)

    errors = torch.as_tensor(errors, dtype=torch.float32)
    errors[errors.isnan()] = self._max_error

    self._buffer.add(
        (
            torch.as_tensor(
                states, dtype=torch.float32, device=self._config.replay_device
            ),
            torch.as_tensor(
                actions, dtype=torch.long, device=self._config.replay_device
            ),
            torch.as_tensor(
                rewards, dtype=torch.float32, device=self._config.replay_device
            ),
            torch.as_tensor(
                terminals, dtype=torch.bool, device=self._config.replay_device
            ),
            torch.as_tensor(
                next_states, dtype=torch.float32, device=self._config.replay_device
            ),
            torch.as_tensor(
                next_action_masks,
                dtype=torch.bool,
                device=self._config.replay_device,
            ),
        ),
        errors,
    )
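A batched insertion sketch; shapes and the number of actions are placeholders, and agent is assumed to be a constructed Agent. Passing NaN errors lets the agent substitute its current maximum error as the initial priority, per the source above:

import numpy as np

batch = 32
agent.observe(
    np.random.randn(batch, 8).astype(np.float32),  # states (placeholder shape)
    np.random.randint(0, 3, size=batch),           # actions
    np.random.randn(batch).astype(np.float32),     # rewards
    np.zeros(batch, dtype=bool),                   # terminals
    np.random.randn(batch, 8).astype(np.float32),  # next states
    np.ones((batch, 3), dtype=bool),               # next action masks
    np.full(batch, np.nan, dtype=np.float32),      # errors: NaN -> agent's max error
)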
def observe_single(self, state: Union[numpy.ndarray, torch.Tensor], action: int, reward: float, terminal: bool, next_state: Union[numpy.ndarray, torch.Tensor], next_action_mask: Union[numpy.ndarray, torch.Tensor], error: float)
    Adds a single experience to the replay buffer.
Args
state : Union[Tensor, ndarray]
    State.
action : int
    Action.
reward : float
    Reward.
terminal : bool
    True if next_state is a terminal state.
next_state : Union[Tensor, ndarray]
    Next state.
next_action_mask : Union[Tensor, ndarray]
    Next action mask.
error : float
    TD error. NaN values are replaced by an appropriate initialization value.
Expand source code
def observe_single(
    self,
    state: Union[Tensor, ndarray],
    action: int,
    reward: float,
    terminal: bool,
    next_state: Union[Tensor, ndarray],
    next_action_mask: Union[Tensor, ndarray],
    error: float,
):
    """Adds a single experience to the replay buffer.

    Args:
        state (Union[Tensor, ndarray]): State
        action (int): Action
        reward (float): Reward
        terminal (bool): True if `next_state` is a terminal state
        next_state (Union[Tensor, ndarray]): Next state
        next_action_mask (Union[Tensor, ndarray]): Next action mask
        error (float): TD error. NaN values are replaced by appropriate
            initialization value.
    """
    self.observe(
        torch.as_tensor(
            state, dtype=torch.float32, device=self._config.replay_device
        ).unsqueeze_(0),
        torch.tensor(
            [action], dtype=torch.long, device=self._config.replay_device
        ),
        torch.tensor(
            [reward], dtype=torch.float32, device=self._config.replay_device
        ),
        torch.tensor(
            [terminal], dtype=torch.bool, device=self._config.replay_device
        ),
        torch.as_tensor(
            next_state, dtype=torch.float32, device=self._config.replay_device
        ).unsqueeze_(0),
        torch.as_tensor(
            next_action_mask, dtype=torch.bool, device=self._config.replay_device
        ).unsqueeze_(0),
        torch.tensor(
            [error], dtype=torch.float32, device=self._config.replay_device
        ),
    )
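Continuing the act_single sketch above, one transition can be stored like this; passing NaN again defers the initial priority to the agent:

agent.observe_single(
    state,
    action,
    reward,
    done,              # True if next_state is terminal
    next_state,
    action_mask,       # mask valid in next_state (unchanged in this sketch)
    float("nan"),      # let the agent choose the initial priority
)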
def q_values(self, states: Union[numpy.ndarray, torch.Tensor], action_masks: Union[numpy.ndarray, torch.Tensor]) ‑> torch.Tensor
    Computes the Q-values for each state-action pair. Illegal actions are given a value of -inf.
Args
states : Union[Tensor, ndarray]
    States.
action_masks : Union[Tensor, ndarray]
    Action masks.
Returns
Tensor
    Q-values.
Expand source code
def q_values(
    self, states: Union[Tensor, ndarray], action_masks: Union[Tensor, ndarray]
) -> Tensor:
    """Computes the Q-values for each state-action pair. Illegal actions are
    given a value of -inf.

    Args:
        states (Union[Tensor, ndarray]): States.
        action_masks (Union[Tensor, ndarray]): Action masks.

    Returns:
        Tensor: Q-values.
    """
    states = torch.as_tensor(
        states, dtype=torch.float32, device=self._config.network_device
    )
    action_masks = torch.as_tensor(
        action_masks, dtype=torch.bool, device=self._config.network_device
    )
    with torch.no_grad():
        if self.config.use_distributional:
            values = (self._z.view(1, 1, -1) * self._network(states)).sum(2)
        else:
            values = self._network(states)
        return _apply_masks(values, action_masks).cpu()
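A short sketch showing how the -inf masking composes with a greedy argmax; states and action_masks are as in the act sketch above:

q = agent.q_values(states, action_masks)  # shape (batch, n_actions); masked entries are -inf
greedy_actions = q.argmax(dim=1)          # -inf entries can never be selected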
def q_values_single(self, state: Union[numpy.ndarray, torch.Tensor], action_mask: Union[numpy.ndarray, torch.Tensor]) ‑> torch.Tensor
    Computes the Q-values of all state-action pairs, given a single state.
Args
state : Union[Tensor, ndarray]
    State.
action_mask : Union[Tensor, ndarray]
    Action mask.
Returns
Tensor
    Q-values; illegal actions have a value of -inf.
Expand source code
def q_values_single(
    self, state: Union[Tensor, ndarray], action_mask: Union[Tensor, ndarray]
) -> Tensor:
    """Computes the Q-values of all state-action pairs, given a single state.

    Args:
        state (Union[Tensor, ndarray]): State.
        action_mask (Union[Tensor, ndarray]): Action mask.

    Returns:
        Tensor: Q-values, illegal actions have value -inf.
    """
    return self.q_values(
        torch.as_tensor(
            state, dtype=torch.float32, device=self._config.network_device
        ).unsqueeze_(0),
        torch.as_tensor(
            action_mask, dtype=torch.bool, device=self._config.network_device
        ).unsqueeze_(0),
    )[0]
def set_logging_client(self, client: Client)
    Sets (and overrides any previously set) logging client used by the agent. The agent outputs TensorBoard logs through this client.
Args
client : logging.Client
    Client.
Expand source code
def set_logging_client(self, client: logging.Client):
    """Sets (and overrides previously set) logging client used by the agent. The
    agent outputs tensorboard logs through this client.

    Args:
        client (logging.Client): Client.
    """
    self._logging_client = client
def train_step(self)
    Executes one training step.
Expand source code
def train_step(self):
    """Executes one training step."""
    (
        data,
        sample_probs,
        sample_ids,
    ) = self._buffer.sample(self._config.batch_size)
    data = (x.to(self._config.network_device) for x in data)

    if self._config.use_distributional:
        loss = self._get_distributional_loss(*data)
    else:
        loss = self._get_td_loss(*data)

    self._optimizer.zero_grad()

    if self._config.use_prioritized_experience_replay:
        beta = min(
            max(
                self._beta_coeff * (self._train_steps - self._config.beta_t_start)
                + self._config.beta_start,
                self._config.beta_start,
            ),
            self._config.beta_end,
        )
        w = (1.0 / self._buffer.size / sample_probs) ** beta
        w /= w.max()
        if self._config.use_distributional:
            updated_weights = loss.detach()
        else:
            updated_weights = loss.detach().pow(0.5)
        loss = (w * loss).mean()
        self._buffer.update_weights(sample_ids, updated_weights)
        self._max_error += 0.05 * (updated_weights.max() - self._max_error)
    else:
        loss = loss.mean()

    loss.backward()
    grad_norm = None
    if self.config.gradient_norm > 0:
        grad_norm = nn.utils.clip_grad_norm_(
            self._network.parameters(), self.config.gradient_norm
        )
    self._optimizer.step()

    self._train_steps += 1
    if self._train_steps % self._config.target_update_steps == 0:
        self._target_update()

    if self._logging_client is not None:
        self._logging_client.log("RainbowAgent/Loss", loss.detach().item())
        self._logging_client.log("RainbowAgent/Max error", self._max_error.item())
        if grad_norm is not None:
            self._logging_client.log(
                "RainbowAgent/Gradient norm", grad_norm.item()
            )
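A minimal hand-rolled loop sketch tying the methods together; the trainers sub-module presumably provides ready-made alternatives. env, n_actions, agent and the warm-up threshold are placeholders:

import numpy as np

state = env.reset()
mask = np.ones(n_actions, dtype=bool)
for step in range(100_000):
    action = agent.act_single(state, mask)
    next_state, reward, done, _ = env.step(action)
    agent.observe_single(state, action, reward, done, next_state, mask, float("nan"))
    state = env.reset() if done else next_state

    if agent.buffer_size() >= 1_000:  # warm-up size is a placeholder
        agent.train_step()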
class AgentConfig
    RainbowDQN agent configuration.