Module ai.rl.utils.buffers

Replay buffers.

Expand source code
"""Replay buffers."""

from ._base import Base
from ._uniform import Uniform
from ._weighted import Weighted


__all__ = ["Base", "Uniform", "Weighted"]

Classes

class Base

Base buffer class. A buffer is a collection of experiences.

Ancestors

  • abc.ABC

Subclasses

Instance variables

var capacity : int

Capacity of the buffer.

Expand source code
@property
@abc.abstractmethod
def capacity(self) -> int:
    """Capacity of the buffer.

    Abstract: every concrete buffer must override this property.

    Raises:
        NotImplementedError: Always, when not overridden.
    """
    raise NotImplementedError()
var size : int

Size of the buffer.

Expand source code
@property
@abc.abstractmethod
def size(self) -> int:
    """Size of the buffer.

    Abstract: every concrete buffer must override this property.

    Raises:
        NotImplementedError: Always, when not overridden.
    """
    raise NotImplementedError()

Methods

def add(self, data: Tuple[torch.Tensor], weights: torch.Tensor, batch: bool = True) ‑> torch.Tensor

Adds new data to the buffer.

Args

data : Tuple[torch.Tensor]
Data to be added.
weights : torch.Tensor
Weights of the new data.
batch : bool, optional
If True, the first dimension of each tensor is treated as a batch dimension, allowing batch inserts. Defaults to True.

Returns

torch.Tensor
Identifier given to the new data.
Expand source code
@abc.abstractmethod
def add(self, data: Tuple[torch.Tensor, ...], weights: torch.Tensor, batch: bool=True) -> torch.Tensor:
    """Adds new data to the buffer.

    Abstract: every concrete buffer must override this method.

    Args:
        data (Tuple[torch.Tensor, ...]): Data to be added, given as a tuple of
            tensors.
        weights (torch.Tensor): Weights of the new data.
        batch (bool, optional): If `True`, then the first dimension of each
            tensor is treated as a batch dimension, allowing batch inserts.
            Defaults to `True`.

    Returns:
        torch.Tensor: Identifiers given to the new data.

    Raises:
        NotImplementedError: Always, when not overridden.
    """
    # `Tuple[torch.Tensor, ...]` (not `Tuple[torch.Tensor]`, which means a
    # 1-tuple) because a buffer may store any number of data fields.
    raise NotImplementedError
def get_all(self) ‑> Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]

Collects and returns all data found in the buffer.

Returns

Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]
Tuple of (data, weights, identifier).
Expand source code
@abc.abstractmethod
def get_all(self) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]:
    """Collects and returns all data found in the buffer.

    Abstract: every concrete buffer must override this method.

    Returns:
        Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]: Tuple of
            (data, weights, identifiers).

    Raises:
        NotImplementedError: Always, when not overridden.
    """
    # The data element is a variable-length tuple of tensors, hence `...`.
    raise NotImplementedError
def sample(self, n: int) ‑> Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]

Collects samples from the buffer.

Args

n : int
Number of samples to collect.

Returns

Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]
Tuple of (data, sample_probabilities, identifier).
Expand source code
@abc.abstractmethod
def sample(self, n: int) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]:
    """Collects samples from the buffer.

    Abstract: every concrete buffer must override this method.

    Args:
        n (int): Number of samples to collect.

    Returns:
        Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]: Tuple of
            (data, sample_probabilities, identifiers).

    Raises:
        NotImplementedError: Always, when not overridden.
    """
    # The data element is a variable-length tuple of tensors, hence `...`.
    raise NotImplementedError
def update_weights(self, identifiers: torch.Tensor, weights: torch.Tensor)

Updates the weights of the given samples.

Args

identifiers : torch.Tensor
Identifiers of the samples whose weights shall be updated.
weights : torch.Tensor
New weights.
Expand source code
@abc.abstractmethod
def update_weights(self, identifiers: torch.Tensor, weights: torch.Tensor) -> None:
    """Updates the weights of the given samples.

    Abstract: every concrete buffer must override this method.

    Args:
        identifiers (torch.Tensor): Identifiers of the samples whose weights
            are to be updated.
        weights (torch.Tensor): New weights.

    Raises:
        NotImplementedError: Always, when not overridden.
    """
    # Explicit `-> None` keeps the signature consistent with the other
    # abstract methods, which all annotate their return type.
    raise NotImplementedError
class Uniform (capacity: int, shapes: Tuple[Tuple[int, ...]], dtypes: Tuple[torch.dtype], device: torch.device = device(type='cpu'))

Buffer from which samples are drawn uniformly.

Args

capacity : int
Capacity of the buffer.
shapes : Tuple[Tuple[int, …]]
Shapes of the data to store.
dtypes : Tuple[torch.dtype]
Data types of the data to store.
device : torch.device, optional
Device on which to store the data. Defaults to CPU.

Ancestors

Inherited members

class Weighted (capacity: int, priority_exponent: float, shapes: Tuple[Tuple[int, ...]], dtypes: Tuple[torch.dtype], device: torch.device = device(type='cpu'))

Buffer where samples are drawn according to their sample weight.

Args

capacity : int
Capacity of the buffer.
priority_exponent : float
Exponent controlling the strictness of the distribution. A larger value gives samples with larger weights higher priority. Zero yields uniform sampling; one yields sampling proportional to sample weight.
shapes : Tuple[Tuple[int, …]]
Shapes of the data to store.
dtypes : Tuple[torch.dtype]
Data types of the data to store.
device : torch.device, optional
Device on which to store the data. Defaults to CPU.

Ancestors

Inherited members