Module ai.rl.utils

Agent utility methods.

"""Agent utility methods."""

from ._n_step_reward_collector import NStepRewardCollector
from . import buffers


__all__ = ["buffers", "NStepRewardCollector"]

Sub-modules

ai.rl.utils.buffers

Replay buffers.

Classes

class NStepRewardCollector (n_step: int, discount_factor: float, state_data_shapes: Sequence[Tuple[int, ...]], state_data_dtypes: Sequence[torch.dtype], device: torch.device = device(type='cpu'))

Utility object for collecting n-step rewards.

Args

n_step : int
N-step to apply.
discount_factor : float
Discount factor.
state_data_shapes : Sequence[Tuple[int, ...]]
Shapes of the data tensors stored for each state. The stored tensors are later paired with the correct state and next state when n-step transitions are emitted.
state_data_dtypes : Sequence[torch.dtype]
Data types of the tensors stored for each state.
device : torch.device, optional
Device on which data is stored. Defaults to CPU.
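
A minimal construction sketch; the step count, discount factor, shapes, and dtypes below are illustrative assumptions, not values prescribed by the module:

import torch
from ai.rl.utils import NStepRewardCollector

# Illustrative values (assumptions): 3-step returns, discount 0.99, and a
# single float32 state tensor of shape (4,) stored per transition.
collector = NStepRewardCollector(
    n_step=3,
    discount_factor=0.99,
    state_data_shapes=[(4,)],
    state_data_dtypes=[torch.float32],
    device=torch.device("cpu"),
)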

Methods

def clear(self)

Clears the content of the collector.

def clear(self):
    """Clears the content of the collector."""
    # Detach buffered state tensors from any autograd graph; resetting the
    # counters below marks the buffer as empty so its slots can be reused.
    self._state_data_buffer = tuple(
        state.detach() for state in self._state_data_buffer
    )
    self._rewards = torch.zeros_like(self._rewards)
    self._i = 0
    self._looped = False
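
As a brief usage sketch (the scenario is an assumption, reusing the collector instance from the example above), clear() can be called to drop buffered partial transitions, for instance when the environment is reset without a terminal flag being observed:

# Hypothetical: the environment was reset out-of-band, so the buffered
# rewards and states no longer belong to the current episode.
collector.clear()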
def step(self, reward: float, terminal: bool, state_data: Sequence[Union[torch.Tensor, numpy.ndarray]]) -> Optional[Tuple[Sequence[torch.Tensor], torch.Tensor, torch.Tensor, Sequence[torch.Tensor]]]

Observes one state transition.

Args

reward : float
Reward observed in the 1-step transition.
terminal : bool
If the state transition resulted in a terminal state.
state_data : Sequence[Union[Tensor, ndarray]]
State data of the state that was transitioned from.

Returns

Optional[Tuple[Sequence[Tensor], Tensor, Tensor, Sequence[Tensor]]]
If available: Tuple of state data, rewards, terminals, next state data. Otherwise, None.
def step(
    self,
    reward: float,
    terminal: bool,
    state_data: Sequence[Union[Tensor, ndarray]],
) -> Optional[Tuple[Sequence[Tensor], Tensor, Tensor, Sequence[Tensor]]]:
    """Observes one state transition

    Args:
        reward (float): Reward observed in the 1-step transition.
        terminal (bool): If the state transition resulted in a terminal state.
        state_data (Sequence[Union[Tensor, ndarray]]): State data of the state that
            was transitioned _from_.

    Returns:
        Optional[Tuple[Sequence[Tensor], Tensor, Tensor, Sequence[Tensor]]]: If
            available: Tuple of state data, rewards, terminals, next state data.
            Otherwise, None.
    """

    state_data = [
        torch.as_tensor(data, dtype=dtype, device=self._device)
        for data, dtype in zip(state_data, self._state_data_dtypes)
    ]

    return_state_data = [[] for _ in self._state_data_shapes]
    return_rewards = []
    return_terminals = []
    return_next_state_data = [[] for _ in self._state_data_shapes]

    # The buffer has wrapped, so the state stored n_step calls ago now has a
    # complete n-step return. Emit it, paired with the incoming state as its
    # n-step successor, and free its reward row for reuse.
    if self._looped:
        rewards = self._rewards[self._i] * self._discount_vector
        return_rewards.append(rewards.sum().clone().unsqueeze_(0))
        return_terminals.append(torch.tensor([False], device=self._device))
        for x, y in zip(self._state_data_buffer, return_state_data):
            y.append(x[self._i].clone().unsqueeze_(0))
        for x, y in zip(state_data, return_next_state_data):
            y.append(x.clone().unsqueeze_(0))
        self._rewards[self._i] = 0.0

    # Record the new reward against each state currently held in the buffer, at
    # the column matching how many steps ago that state was stored. Only
    # occupied rows are written, so rewards observed before a state was stored
    # never leak into that state's return.
    occupied = self._index_vector if self._looped else self._index_vector[: self._i + 1]
    self._rewards[occupied, (self._i - occupied) % self._n_step] = reward
    for x, y in zip(state_data, self._state_data_buffer):
        y[self._i] = x

    # Advance the circular write index, wrapping once the buffer is full.
    self._i += 1
    if self._i >= self._n_step:
        self._i = 0
        self._looped = True

    # Terminal transition: flush every buffered state with its (possibly
    # truncated) discounted return, mark all of them terminal, emit zero
    # placeholders as next-state data, then reset the collector.
    if terminal:
        n_return = self._n_step if self._looped else self._i
        rewards = self._rewards[:n_return].clone() * self._discount_vector.view(1, -1)
        return_rewards.append(rewards.sum(1))
        return_terminals.append(torch.ones(n_return, dtype=torch.bool, device=self._device))
        for x, y, z in zip(self._state_data_buffer, return_state_data, return_next_state_data):
            y.append(x[:n_return].clone())
            z.append(torch.zeros_like(x[:n_return]))

        self._looped = False
        self._i = 0
        self._rewards = torch.zeros_like(self._rewards)

    if len(return_rewards) == 0:
        return None

    return (
        tuple(torch.cat(x, dim=0) for x in return_state_data),
        torch.cat(return_rewards, dim=0),
        torch.cat(return_terminals, dim=0),
        tuple(torch.cat(x, dim=0) for x in return_next_state_data)
    )
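
To show how the pieces fit together, here is a hedged end-to-end sketch: a toy loop that feeds per-step rewards and states into the collector and gathers whatever n-step transitions it emits. The episode length, reward values, and state shapes are assumptions made for the example, not part of the module.

import torch
from ai.rl.utils import NStepRewardCollector

collector = NStepRewardCollector(
    n_step=3,
    discount_factor=0.99,
    state_data_shapes=[(4,)],
    state_data_dtypes=[torch.float32],
)

transitions = []
for t in range(10):
    state = torch.randn(4)      # stand-in for a real observation
    reward = 1.0                # stand-in for an environment reward
    terminal = t == 9           # episode ends on the last step
    result = collector.step(reward, terminal, [state])
    if result is not None:
        states, rewards, terminals, next_states = result
        transitions.append(result)  # e.g. push into a replay buffer

Each element of the returned tuple batches every transition emitted by that call; on the terminal step all remaining buffered states are flushed at once, with zero placeholders as their next-state data.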