Module ai.rl.utils.buffers
Replay buffers.
"""Replay buffers."""
from ._base import Base
from ._uniform import Uniform
from ._weighted import Weighted
__all__ = ["Base", "Uniform", "Weighted"]
Classes
class Base
Base buffer class. A buffer is a collection of experiences.
Ancestors
- abc.ABC
Subclasses
- Uniform
- Weighted
Instance variables
var capacity : int
Capacity of the buffer.
@property
@abc.abstractmethod
def capacity(self) -> int:
    """Capacity of the buffer."""
    raise NotImplementedError
var size : int
Size of the buffer.
@property
@abc.abstractmethod
def size(self) -> int:
    """Size of the buffer."""
    raise NotImplementedError
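The difference between `capacity` and `size` can be illustrated with a fixed-size ring store. This is a minimal pure-Python sketch, not the library's implementation: the class name `RingStore` is hypothetical, and the overwrite-oldest-when-full policy is an assumption, since the documentation only states that both properties exist.

```python
from typing import Any, List, Optional


class RingStore:
    """Hypothetical fixed-capacity store (not part of ai.rl.utils.buffers)."""

    def __init__(self, capacity: int) -> None:
        if capacity <= 0:
            raise ValueError("capacity must be positive")
        self._capacity = capacity
        self._items: List[Optional[Any]] = [None] * capacity
        self._next = 0   # slot the next insert will write to
        self._count = 0  # number of occupied slots

    @property
    def capacity(self) -> int:
        """Maximum number of items the store can hold."""
        return self._capacity

    @property
    def size(self) -> int:
        """Number of items currently stored (never exceeds capacity)."""
        return self._count

    def add(self, item: Any) -> int:
        """Store `item`; assumed policy: once full, overwrite the oldest slot."""
        slot = self._next
        self._items[slot] = item
        self._next = (slot + 1) % self._capacity
        self._count = min(self._count + 1, self._capacity)
        return slot


store = RingStore(capacity=3)
for value in range(5):
    store.add(value)
print(store.size, store.capacity)  # 3 3
```

Under this assumed policy, `size` grows with each insert until it reaches `capacity`, after which new data replaces the oldest entries and `size` stays fixed.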
Methods
def add(self, data: Tuple[torch.Tensor], weights: torch.Tensor, batch: bool = True) -> torch.Tensor
Adds new data to the buffer.
Args
data (Tuple[torch.Tensor]): Data to be added.
weights (torch.Tensor): Weights of the new data.
batch (bool, optional): If True, the first dimension of each tensor is treated as a batch dimension, allowing batch inserts. Defaults to True.
Returns
torch.Tensor: Identifier given to the new data.
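The `batch` flag decides whether the leading dimension of each field is read as a batch of samples. The pure-Python sketch below mirrors that contract with lists in place of tensors; the function name `add_to_store`, the dict-based storage, and the sequential integer identifiers are assumptions for illustration, not the library's behavior.

```python
from typing import Dict, List, Sequence, Tuple, Union

Row = Tuple[object, ...]


def add_to_store(
    store: Dict[int, Tuple[Row, float]],
    data: Tuple[Sequence, ...],
    weights: Union[Sequence[float], float],
    batch: bool = True,
) -> List[int]:
    """Hypothetical stand-in for Base.add using lists instead of tensors."""
    if not batch:
        # A single sample: wrap every field (and the weight) as a batch of one.
        data = tuple([field] for field in data)
        weights = [weights]
    n = len(weights)
    if any(len(field) != n for field in data):
        raise ValueError("every field must share the batch size of weights")
    identifiers = []
    # zip(*data) turns per-field columns into per-sample rows.
    for row, weight in zip(zip(*data), weights):
        ident = len(store)  # assumed: identifiers are assigned sequentially
        store[ident] = (row, float(weight))
        identifiers.append(ident)
    return identifiers


store: Dict[int, Tuple[Row, float]] = {}
print(add_to_store(store, ([1, 2], [10, 20]), [0.5, 1.0]))  # [0, 1]
print(add_to_store(store, (3, 30), 2.0, batch=False))        # [2]
```

With `batch=True`, two fields of length two become two stored samples; with `batch=False`, the same call shape stores exactly one sample.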
@abc.abstractmethod
def add(self, data: Tuple[torch.Tensor], weights: torch.Tensor, batch: bool=True) -> torch.Tensor:
    """Adds new data to the buffer.

    Args:
        data (Tuple[torch.Tensor]): Data to be added.
        weights (torch.Tensor): Weights of the new data.
        batch (optional, bool): If `True`, then the first dimension of each tensor
            is treated as a batch dimension, allowing batch inserts. Defaults to `True`.

    Returns:
        torch.Tensor: Identifier given to the new data.
    """
    raise NotImplementedError
def get_all(self) -> Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]
Collects and returns all data found in the buffer.
Returns
Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]: Tuple of (data, weights, identifier).
@abc.abstractmethod
def get_all(self) -> Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]:
    """Collects and returns all data found in the buffer.

    Returns:
        Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]: Tuple of
            (data, weights, identifier).
    """
    raise NotImplementedError
def sample(self, n: int) -> Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]
Collects samples from the buffer.
Args
n (int): Number of samples to collect.
Returns
Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]: Tuple of (data, sample_probabilities, identifier).
@abc.abstractmethod
def sample(self, n: int) -> Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]:
    """Collects samples from the buffer.

    Args:
        n (int): Number of samples to collect.

    Returns:
        Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]: Tuple of
            (data, sample_probabilities, identifier).
    """
    raise NotImplementedError
def update_weights(self, identifiers: torch.Tensor, weights: torch.Tensor)
Updates the weights of the given samples.
Args
identifiers (torch.Tensor): Identifiers of the samples whose weights shall be updated.
weights (torch.Tensor): New weights.
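Because `update_weights` takes identifiers rather than positions, the identifiers returned by `sample` can be passed straight back once new weights are known. A minimal sketch, assuming weights live in a dict keyed by identifier (a hypothetical layout, as is the function below); it validates every identifier before changing anything so a bad call leaves the weights untouched:

```python
from typing import Dict, Sequence


def update_weights(
    weights: Dict[int, float],
    identifiers: Sequence[int],
    new_weights: Sequence[float],
) -> None:
    """Hypothetical stand-in for Base.update_weights on a dict of weights."""
    if len(identifiers) != len(new_weights):
        raise ValueError("identifiers and new_weights must have the same length")
    # Check all identifiers first so a failing call does not half-apply.
    missing = [i for i in identifiers if i not in weights]
    if missing:
        raise KeyError(f"unknown identifiers: {missing}")
    for ident, weight in zip(identifiers, new_weights):
        weights[ident] = float(weight)


weights = {0: 1.0, 1: 1.0, 2: 1.0}
update_weights(weights, [2, 0], [0.3, 4.0])
print(weights)  # {0: 4.0, 1: 1.0, 2: 0.3}
```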
@abc.abstractmethod
def update_weights(self, identifiers: torch.Tensor, weights: torch.Tensor):
    """Updates the weights of the given samples.

    Args:
        identifiers (torch.Tensor): Identifiers of the samples whose weights
            shall be updated.
        weights (torch.Tensor): New weights.
    """
    raise NotImplementedError
class Uniform (capacity: int, shapes: Tuple[Tuple[int, ...]], dtypes: Tuple[torch.dtype], device: torch.device = device(type='cpu'))
Buffer from which samples are drawn uniformly.
Args
capacity (int): Capacity of the buffer.
shapes (Tuple[Tuple[int, ...]]): Shapes of the data to store.
dtypes (Tuple[torch.dtype]): Data types of the data to store.
device (torch.device, optional): Device on which to store the data. Defaults to CPU.
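For `Uniform`, every stored item is equally likely to be drawn, so under the assumptions below each entry of the `sample_probabilities` returned by `sample` would equal `1 / size`. This pure-Python sketch assumes sampling with replacement (the documentation does not say either way) and uses list positions as identifiers; `sample_uniform` is a hypothetical name.

```python
import random
from typing import List, Sequence, Tuple


def sample_uniform(
    rows: Sequence[tuple], n: int, rng: random.Random
) -> Tuple[List[tuple], List[float], List[int]]:
    """Return (data, sample_probabilities, identifiers) drawn uniformly."""
    size = len(rows)
    if size == 0:
        raise ValueError("cannot sample from an empty buffer")
    identifiers = [rng.randrange(size) for _ in range(n)]  # with replacement
    probabilities = [1.0 / size] * n
    data = [rows[i] for i in identifiers]
    return data, probabilities, identifiers


rows = [("s0",), ("s1",), ("s2",), ("s3",)]
data, probs, ids = sample_uniform(rows, n=5, rng=random.Random(0))
```

Passing an explicit `random.Random` keeps draws reproducible in tests without touching the global random state.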
Ancestors
- Base
- abc.ABC
Inherited members
- Base: add, capacity, get_all, sample, size, update_weights
class Weighted (capacity: int, priority_exponent: float, shapes: Tuple[Tuple[int, ...]], dtypes: Tuple[torch.dtype], device: torch.device = device(type='cpu'))
Buffer where samples are drawn according to their sample weight.
Args
capacity (int): Capacity of the buffer.
priority_exponent (float): Exponent controlling how strongly sampling favors samples with larger weights. Larger values give samples with larger weights higher priority; zero gives uniform sampling, and one gives sampling proportional to sample weight.
shapes (Tuple[Tuple[int, ...]]): Shapes of the data to store.
dtypes (Tuple[torch.dtype]): Data types of the data to store.
device (torch.device, optional): Device on which to store the data. Defaults to CPU.
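The description of `priority_exponent` (zero gives uniform sampling, one gives sampling proportional to weight) matches the common prioritized-replay rule P(i) = w_i^α / Σ_j w_j^α, with α = `priority_exponent`. Treat that formula as an inference from the description rather than a confirmed detail of `Weighted`. The sketch below implements it in pure Python; `priority_probabilities` and `sample_weighted` are hypothetical names, and sampling with replacement is assumed.

```python
import random
from typing import List, Sequence, Tuple


def priority_probabilities(
    weights: Sequence[float], priority_exponent: float
) -> List[float]:
    """P(i) = w_i**alpha / sum_j w_j**alpha (assumed rule, see lead-in)."""
    if any(w < 0 for w in weights):
        raise ValueError("weights must be non-negative")
    scaled = [w ** priority_exponent for w in weights]
    total = sum(scaled)
    if total <= 0:
        raise ValueError("at least one weight must be positive")
    return [s / total for s in scaled]


def sample_weighted(
    weights: Sequence[float], priority_exponent: float, n: int, rng: random.Random
) -> Tuple[List[int], List[float]]:
    """Draw n identifiers (with replacement) and return their probabilities."""
    probs = priority_probabilities(weights, priority_exponent)
    identifiers = rng.choices(range(len(weights)), weights=probs, k=n)
    return identifiers, [probs[i] for i in identifiers]


print(priority_probabilities([1.0, 3.0], 0.0))  # [0.5, 0.5]
print(priority_probabilities([1.0, 3.0], 1.0))  # [0.25, 0.75]
```

Larger exponents push probability mass toward the largest weights: with weights [1, 3], an exponent of 2 gives probabilities [0.1, 0.9].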
Ancestors
- Base
- abc.ABC
Inherited members
- Base: add, capacity, get_all, sample, size, update_weights