Module ai.rl.dqn.rainbow

Rainbow DQN.

Expand source code
"""Rainbow DQN.

Re-exports the RainbowDQN `Agent` and its `AgentConfig`, together with the
`networks` and `trainers` sub-modules.
"""


from ._agent import Agent
from ._agent_config import AgentConfig
from . import networks, trainers


# Public API of this package (what `from ai.rl.dqn.rainbow import *` exposes).
__all__ = ["Agent", "AgentConfig", "networks", "trainers"]

Sub-modules

ai.rl.dqn.rainbow.networks

Predefined networks for RainbowDQN.

ai.rl.dqn.rainbow.trainers

RainbowDQN trainers.

Classes

class Agent (config: AgentConfig, network: Factory[torch.nn.modules.module.Module], optimizer: Factory[torch.optim.optimizer.Optimizer] = None, inference_mode: bool = False, replay_init_lazily: bool = True)

RainbowDQN Agent.

Args

config : AgentConfig
Agent configuration.
network : Factory[nn.Module]
Network wrapped in a Factory.
optimizer : Factory[optim.Optimizer]
Optimizer, wrapped in a Factory. Model parameters are passed to the optimizer when it is instantiated.
inference_mode : bool, optional
If True, the agent can only be used for acting. Saves memory by not initializing a replay buffer. Defaults to False.
replay_init_lazily : bool, optional
If True, the replay buffer is initialized lazily, i.e. when the first observation is added. Defaults to True.

Instance variables

var config : AgentConfig

Agent configuration in use.

Expand source code
@property
def config(self) -> AgentConfig:
    """The configuration this agent is using.

    Returns:
        AgentConfig: The stored configuration object itself (not a copy).
    """
    current_config = self._config
    return current_config
var model_factory : Factory[torch.nn.modules.module.Module]

Model factory used by the agent.

Expand source code
@property
def model_factory(self) -> Factory[nn.Module]:
    """Factory from which the agent's network is built.

    Returns:
        Factory[nn.Module]: The stored network factory.
    """
    factory = self._network_factory
    return factory
var model_instance : torch.nn.modules.module.Module

Model instance.

Expand source code
@property
def model_instance(self) -> nn.Module:
    """The network instance currently held by the agent.

    Returns:
        nn.Module: The stored network object.
    """
    network = self._network
    return network

Methods

def act(self, states: Union[numpy.ndarray, torch.Tensor], action_masks: Union[numpy.ndarray, torch.Tensor]) ‑> torch.Tensor

Returns the greedy action for the given states and action masks.

Args

states : Union[Tensor, ndarray]
States.
action_masks : Union[Tensor, ndarray]
Action masks.

Returns

Tensor
Tensor of dtype torch.long.
Expand source code
def act(
    self, states: Union[Tensor, ndarray], action_masks: Union[Tensor, ndarray]
) -> Tensor:
    """Selects the greedy action for every state in a batch.

    Args:
        states (Union[Tensor, ndarray]): Batch of states; converted to float32
            on the network device.
        action_masks (Union[Tensor, ndarray]): Batch of action masks; converted
            to bool on the network device.

    Returns:
        Tensor: Tensor of dtype `torch.long`, moved to the CPU.
    """
    device = self._config.network_device
    with torch.no_grad():
        mask_tensor = torch.as_tensor(action_masks, dtype=torch.bool, device=device)
        state_tensor = torch.as_tensor(states, dtype=torch.float32, device=device)
        network_output = self._network(state_tensor)
        greedy_actions = _get_actions(
            mask_tensor,
            network_output,
            self._config.use_distributional,
            self._z,
        )
        return greedy_actions.cpu()
def act_single(self, state: Union[numpy.ndarray, torch.Tensor], action_mask: Union[numpy.ndarray, torch.Tensor]) ‑> int

Returns the greedy action for one state-action mask pair.

Args

state : Union[Tensor, ndarray]
State.
action_mask : Union[Tensor, ndarray]
Action mask.

Returns

int
Action index.
Expand source code
def act_single(
    self, state: Union[Tensor, ndarray], action_mask: Union[Tensor, ndarray]
) -> int:
    """Returns the greedy action for one state-action mask pair.

    Args:
        state (Union[Tensor, ndarray]): State, without a batch dimension.
        action_mask (Union[Tensor, ndarray]): Action mask, without a batch
            dimension.

    Returns:
        int: Action index.
    """
    device = self._config.network_device
    # Add a leading batch dimension of size one so the batched `act` can be
    # reused. Out-of-place `unsqueeze` is required: `torch.as_tensor` returns
    # the caller's own tensor when no dtype/device conversion is needed, and the
    # in-place `unsqueeze_` would then silently reshape the caller's data.
    batched_state = torch.as_tensor(
        state, dtype=torch.float32, device=device
    ).unsqueeze(0)
    batched_mask = torch.as_tensor(
        action_mask, dtype=torch.bool, device=device
    ).unsqueeze(0)
    # `act` returns a batch of one action; convert that single element to a
    # Python int so the result matches the annotated return type.
    return int(self.act(batched_state, batched_mask)[0])
def buffer_size(self) ‑> int

Returns

int
The current size of the replay buffer.
Expand source code
def buffer_size(self) -> int:
    """Number of experiences currently held in the replay buffer.

    Returns:
        int: The current size of the replay buffer, or 0 if no buffer has
        been created yet.
    """
    buffer = self._buffer
    if buffer is None:
        return 0
    return buffer.size
def inference_mode(self) ‑> Agent

Returns a copy of this agent, but in inference mode.

Returns

Agent
A shallow copy of the agent, capable of inference only.
Expand source code
def inference_mode(self) -> "Agent":
    """Builds an inference-only counterpart of this agent.

    The new agent reuses this agent's configuration and network object.

    Returns:
        Agent: A shallow copy of the agent, capable of inference only.
    """
    # NOTE(review): the live network instance is passed where the constructor's
    # `network` parameter is documented as a Factory; presumably the constructor
    # accepts this so the copy shares weights with this agent — confirm.
    return Agent(
        config=self._config,
        network=self._network,
        inference_mode=True,
    )
def observe(self, states: Union[numpy.ndarray, torch.Tensor], actions: Union[numpy.ndarray, torch.Tensor], rewards: Union[numpy.ndarray, torch.Tensor], terminals: Union[numpy.ndarray, torch.Tensor], next_states: Union[numpy.ndarray, torch.Tensor], next_action_masks: Union[numpy.ndarray, torch.Tensor], errors: Union[numpy.ndarray, torch.Tensor])

Adds a batch of experiences to the replay buffer.

Args

states : Union[Tensor, ndarray]
States
actions : Union[Tensor, ndarray]
Actions
rewards : Union[Tensor, ndarray]
Rewards
terminals : Union[Tensor, ndarray]
Terminal flags
next_states : Union[Tensor, ndarray]
Next states
next_action_masks : Union[Tensor, ndarray]
Next action masks
errors : Union[Tensor, ndarray]
TD errors. NaN values are replaced by appropriate initialization value.
Expand source code
def observe(
    self,
    states: Union[Tensor, ndarray],
    actions: Union[Tensor, ndarray],
    rewards: Union[Tensor, ndarray],
    terminals: Union[Tensor, ndarray],
    next_states: Union[Tensor, ndarray],
    next_action_masks: Union[Tensor, ndarray],
    errors: Union[Tensor, ndarray],
):
    """Adds a batch of experiences to the replay buffer.

    If no replay buffer exists yet, `_initialize_replay` is called first. All
    experience fields are converted to tensors on the configured replay device.
    The caller's `errors` array/tensor is never modified.

    Args:
        states (Union[Tensor, ndarray]): States (stored as float32).
        actions (Union[Tensor, ndarray]): Actions (stored as long).
        rewards (Union[Tensor, ndarray]): Rewards (stored as float32).
        terminals (Union[Tensor, ndarray]): Terminal flags (stored as bool).
        next_states (Union[Tensor, ndarray]): Next states (stored as float32).
        next_action_masks (Union[Tensor, ndarray]): Next action masks (stored
            as bool).
        errors (Union[Tensor, ndarray]): TD errors. NaN values are replaced by
            the agent's running maximum error (`self._max_error`).
    """
    if self._buffer is None:
        self._initialize_replay(self._config)

    # `torch.as_tensor` returns the input itself (or a view sharing its memory)
    # when no dtype conversion is needed, e.g. for float32 tensors or arrays.
    # Clone before filling NaNs so the caller's data is not overwritten.
    errors = torch.as_tensor(errors, dtype=torch.float32).clone()
    errors[errors.isnan()] = self._max_error

    device = self._config.replay_device
    self._buffer.add(
        (
            torch.as_tensor(states, dtype=torch.float32, device=device),
            torch.as_tensor(actions, dtype=torch.long, device=device),
            torch.as_tensor(rewards, dtype=torch.float32, device=device),
            torch.as_tensor(terminals, dtype=torch.bool, device=device),
            torch.as_tensor(next_states, dtype=torch.float32, device=device),
            torch.as_tensor(next_action_masks, dtype=torch.bool, device=device),
        ),
        errors,
    )
def observe_single(self, state: Union[numpy.ndarray, torch.Tensor], action: int, reward: float, terminal: bool, next_state: Union[numpy.ndarray, torch.Tensor], next_action_mask: Union[numpy.ndarray, torch.Tensor], error: float)

Adds a single experience to the replay buffer.

Args

state : Union[Tensor, ndarray]
State
action : int
Action
reward : float
Reward
terminal : bool
True if next_state is a terminal state
next_state : Union[Tensor, ndarray]
Next state
next_action_mask : Union[Tensor, ndarray]
Next action mask
error : float
TD error. NaN values are replaced by appropriate initialization value.
Expand source code
def observe_single(
    self,
    state: Union[Tensor, ndarray],
    action: int,
    reward: float,
    terminal: bool,
    next_state: Union[Tensor, ndarray],
    next_action_mask: Union[Tensor, ndarray],
    error: float,
):
    """Adds a single experience to the replay buffer.

    The experience is wrapped into a batch of size one and forwarded to
    `observe`. All tensors are created on the configured replay device.

    Args:
        state (Union[Tensor, ndarray]): State, without a batch dimension.
        action (int): Action
        reward (float): Reward
        terminal (bool): True if `next_state` is a terminal state
        next_state (Union[Tensor, ndarray]): Next state, without a batch
            dimension.
        next_action_mask (Union[Tensor, ndarray]): Next action mask, without a
            batch dimension.
        error (float): TD error. NaN values are replaced by appropriate
            initialization value.
    """
    device = self._config.replay_device
    # Out-of-place `unsqueeze`: `torch.as_tensor` may return the caller's own
    # tensor, and the in-place `unsqueeze_` would then reshape the caller's data.
    self.observe(
        torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0),
        torch.tensor([action], dtype=torch.long, device=device),
        torch.tensor([reward], dtype=torch.float32, device=device),
        torch.tensor([terminal], dtype=torch.bool, device=device),
        torch.as_tensor(next_state, dtype=torch.float32, device=device).unsqueeze(0),
        torch.as_tensor(
            next_action_mask, dtype=torch.bool, device=device
        ).unsqueeze(0),
        torch.tensor([error], dtype=torch.float32, device=device),
    )
def q_values(self, states: Union[numpy.ndarray, torch.Tensor], action_masks: Union[numpy.ndarray, torch.Tensor]) ‑> torch.Tensor

Computes the Q-values for each state-action pair. Illegal actions are given the value -inf.

Args

states : Union[Tensor, ndarray]
States.
action_masks : Union[Tensor, ndarray]
Action masks.

Returns

Tensor
Q-values.
Expand source code
def q_values(
    self, states: Union[Tensor, ndarray], action_masks: Union[Tensor, ndarray]
) -> Tensor:
    """Computes Q-values for every action in a batch of states.

    Illegal actions are given the value -inf.

    Args:
        states (Union[Tensor, ndarray]): States.
        action_masks (Union[Tensor, ndarray]): Action masks.

    Returns:
        Tensor: Q-values, moved to the CPU.
    """
    device = self._config.network_device
    state_tensor = torch.as_tensor(states, dtype=torch.float32, device=device)
    mask_tensor = torch.as_tensor(action_masks, dtype=torch.bool, device=device)

    with torch.no_grad():
        network_output = self._network(state_tensor)
        if not self.config.use_distributional:
            values = network_output
        else:
            # Distributional head: weight the output along dim 2 by `self._z`
            # and sum that axis away to get one scalar per action.
            values = (self._z.view(1, 1, -1) * network_output).sum(2)
        return _apply_masks(values, mask_tensor).cpu()
def q_values_single(self, state: Union[numpy.ndarray, torch.Tensor], action_mask: Union[numpy.ndarray, torch.Tensor]) ‑> torch.Tensor

Computes the Q-values of all state-action pairs, given a single state.

Args

state : Union[Tensor, ndarray]
State.
action_mask : Union[Tensor, ndarray]
Action mask.

Returns

Tensor
Q-values, illegal actions have value -inf.
Expand source code
def q_values_single(
    self, state: Union[Tensor, ndarray], action_mask: Union[Tensor, ndarray]
) -> Tensor:
    """Computes the Q-values of all state-action pairs, given a single state.

    Args:
        state (Union[Tensor, ndarray]): State, without a batch dimension.
        action_mask (Union[Tensor, ndarray]): Action mask, without a batch
            dimension.

    Returns:
        Tensor: Q-values, illegal actions have value -inf.
    """
    device = self._config.network_device
    # Add a batch dimension of one. Out-of-place `unsqueeze` is required:
    # `torch.as_tensor` may return the caller's own tensor, and the in-place
    # `unsqueeze_` would then silently reshape the caller's data.
    batched_state = torch.as_tensor(
        state, dtype=torch.float32, device=device
    ).unsqueeze(0)
    batched_mask = torch.as_tensor(
        action_mask, dtype=torch.bool, device=device
    ).unsqueeze(0)
    return self.q_values(batched_state, batched_mask)[0]
def set_logging_client(self, client: Client)

Sets (and overrides previously set) logging client used by the agent. The agent outputs tensorboard logs through this client.

Args

client : logging.Client
Client.
Expand source code
def set_logging_client(self, client: logging.Client):
    """Installs the logging client the agent reports to.

    Any previously set client is replaced. The agent outputs tensorboard logs
    through this client.

    Args:
        client (logging.Client): Client.
    """
    new_client = client
    self._logging_client = new_client
def train_step(self)

Executes one training step.

Expand source code
def train_step(self):
    """Executes one training step.

    The step proceeds as follows:

    1. Sample `batch_size` experiences from the replay buffer.
    2. Compute the distributional or TD loss on the network device.
    3. With prioritized experience replay, apply importance-sampling weights
       and update the buffer's sample weights.
    4. Take one optimizer step, with gradient-norm clipping if
       `gradient_norm > 0`.
    5. Every `target_update_steps` steps, call `_target_update`.
    6. If a logging client is set, log the loss, max error and gradient norm.
    """
    config = self._config

    data, sample_probs, sample_ids = self._buffer.sample(config.batch_size)
    data = [x.to(config.network_device) for x in data]

    if config.use_distributional:
        loss = self._get_distributional_loss(*data)
    else:
        loss = self._get_td_loss(*data)

    self._optimizer.zero_grad()
    if config.use_prioritized_experience_replay:
        # Importance-sampling exponent: linear in the number of training steps,
        # clamped to [beta_start, beta_end].
        beta = min(
            max(
                self._beta_coeff * (self._train_steps - config.beta_t_start)
                + config.beta_start,
                config.beta_start,
            ),
            config.beta_end,
        )
        # Importance-sampling weights, normalized so the largest weight is 1.
        # `sample_probs` comes from the replay buffer, which may live on a
        # different device (replay_device) than the loss (network_device), so
        # move the weights next to the loss before combining them.
        w = torch.as_tensor(
            (1.0 / self._buffer.size / sample_probs) ** beta, device=loss.device
        )
        w /= w.max()
        if config.use_distributional:
            updated_weights = loss.detach()
        else:
            updated_weights = loss.detach().pow(0.5)
        loss = (w * loss).mean()
        self._buffer.update_weights(sample_ids, updated_weights)
        # Exponential moving average (rate 0.05) of the largest new weight;
        # `observe` uses `_max_error` to fill NaN errors.
        self._max_error += 0.05 * (updated_weights.max() - self._max_error)
    else:
        loss = loss.mean()
    loss.backward()

    grad_norm = None
    if config.gradient_norm > 0:
        grad_norm = nn.utils.clip_grad_norm_(
            self._network.parameters(), config.gradient_norm
        )
    self._optimizer.step()

    self._train_steps += 1
    if self._train_steps % config.target_update_steps == 0:
        self._target_update()

    if self._logging_client is not None:
        self._logging_client.log("RainbowAgent/Loss", loss.detach().item())
        self._logging_client.log("RainbowAgent/Max error", self._max_error.item())

        if grad_norm is not None:
            self._logging_client.log("RainbowAgent/Gradient norm", grad_norm.item())
class AgentConfig

RainbowDQN agent configuration.