Module ai.rl.utils.buffers
Replay buffers.
Source code
"""Replay buffers."""
from ._base import Base
from ._uniform import Uniform
from ._weighted import Weighted
__all__ = ["Base", "Uniform", "Weighted"]
Classes
class Base
Base buffer class. A buffer is a collection of experiences.
Ancestors
- abc.ABC
Subclasses
- Uniform
- Weighted
Instance variables
var capacity : int
Capacity of the buffer.
Source code
@property
@abc.abstractmethod
def capacity(self) -> int:
    """Capacity of the buffer."""
    raise NotImplementedError
var size : int
Size of the buffer.
Source code
@property
@abc.abstractmethod
def size(self) -> int:
    """Size of the buffer."""
    raise NotImplementedError
Methods
def add(self, data: Tuple[torch.Tensor], weights: torch.Tensor, batch: bool = True) ‑> torch.Tensor
Adds new data to the buffer.
Args
data : Tuple[torch.Tensor]
- Data to be added.
weights : torch.Tensor
- Weights of the new data.
batch : bool, optional
- If True, then the first dimension of each tensor is treated as a batch dimension, allowing batch inserts. Defaults to True.
Returns
torch.Tensor
- Identifier given to the new data.
Source code
@abc.abstractmethod
def add(self, data: Tuple[torch.Tensor], weights: torch.Tensor, batch: bool = True) -> torch.Tensor:
    """Adds new data to the buffer.

    Args:
        data (Tuple[torch.Tensor]): Data to be added.
        weights (torch.Tensor): Weights of the new data.
        batch (optional, bool): If `True`, then the first dimension of each tensor
            is treated as a batch dimension, allowing batch inserts. Defaults to
            `True`.

    Returns:
        torch.Tensor: Identifier given to the new data.
    """
    raise NotImplementedError
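The effect of the batch flag can be illustrated without the library. The sketch below uses nested Python lists in place of tensors, and normalize_batch is a hypothetical helper that is not part of ai.rl.utils.buffers. It only shows, under those assumptions, how a single experience and a batch of experiences could be brought to a common form.

```python
from typing import List, Sequence, Tuple


def normalize_batch(data: Tuple[Sequence, ...], batch: bool = True) -> List[tuple]:
    """Return one tuple per experience (hypothetical helper; lists stand in for tensors).

    With batch=True, the first dimension of every field indexes experiences.
    With batch=False, each field holds a single experience, so it is first
    wrapped in a batch of size one.
    """
    if not batch:
        data = tuple([field] for field in data)
    sizes = {len(field) for field in data}
    if len(sizes) != 1:
        raise ValueError("all fields must share the same batch size")
    return list(zip(*data))


# Two experiences, each made of a 2-element state and a scalar reward.
states = [[0.0, 1.0], [2.0, 3.0]]
rewards = [1.0, -1.0]
batched = normalize_batch((states, rewards), batch=True)

# A single experience passed without a batch dimension.
single = normalize_batch(([0.0, 1.0], 1.0), batch=False)
```

Both calls produce a list of per-experience tuples, so the rest of an insert routine would not need to know which form the caller used.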
def get_all(self) ‑> Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]
Collects and returns all data found in the buffer.
Returns
Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]
- Tuple of (data, weights, identifier).
Source code
@abc.abstractmethod
def get_all(self) -> Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]:
    """Collects and returns all data found in the buffer.

    Returns:
        Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]: Tuple of
            (data, weights, identifier).
    """
    raise NotImplementedError
def sample(self, n: int) ‑> Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]
Collects samples from the buffer.
Args
n : int
- Number of samples to collect.
Returns
Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]
- Tuple of (data, sample_probabilities, identifier).
Source code
@abc.abstractmethod
def sample(self, n: int) -> Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]:
    """Collects samples from the buffer.

    Args:
        n (int): Number of samples to collect.

    Returns:
        Tuple[Tuple[torch.Tensor], torch.Tensor, torch.Tensor]: Tuple of
            (data, sample_probabilities, identifier).
    """
    raise NotImplementedError
def update_weights(self, identifiers: torch.Tensor, weights: torch.Tensor)
Updates the weights of the given samples.
Args
identifiers : torch.Tensor
- Identifiers of the samples whose weights shall be updated.
weights : torch.Tensor
- New weights.
Source code
@abc.abstractmethod
def update_weights(self, identifiers: torch.Tensor, weights: torch.Tensor):
    """Updates the weights of the given samples.

    Args:
        identifiers (torch.Tensor): Identifiers of the samples whose weights
            shall be updated.
        weights (torch.Tensor): New weights.
    """
    raise NotImplementedError
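The identifiers returned by add (and by sample and get_all) are what update_weights expects later. The sketch below illustrates that round trip with a hypothetical ToyWeights class that tracks only identifiers and weights, using plain integers and floats instead of tensors. How the library itself assigns identifiers is not stated in this page, so the sequential integers here are an assumption.

```python
from typing import Dict, List


class ToyWeights:
    """Hypothetical stand-in that tracks only identifiers and weights."""

    def __init__(self) -> None:
        self._weights: Dict[int, float] = {}
        self._next_id = 0  # assumption: identifiers are sequential integers

    def add(self, weights: List[float]) -> List[int]:
        """Store one weight per new item and return the identifiers assigned."""
        identifiers = []
        for weight in weights:
            self._weights[self._next_id] = weight
            identifiers.append(self._next_id)
            self._next_id += 1
        return identifiers

    def update_weights(self, identifiers: List[int], weights: List[float]) -> None:
        """Overwrite the weights of items that were added earlier."""
        if len(identifiers) != len(weights):
            raise ValueError("identifiers and weights must have the same length")
        for identifier, weight in zip(identifiers, weights):
            if identifier not in self._weights:
                raise KeyError(f"unknown identifier: {identifier}")
            self._weights[identifier] = weight

    def weight_of(self, identifier: int) -> float:
        """Look up the current weight of one item."""
        return self._weights[identifier]


store = ToyWeights()
ids = store.add([1.0, 1.0, 1.0])
# Later, after recomputing the weights of the first and last item:
store.update_weights([ids[0], ids[2]], [0.1, 2.5])
```

Only the two named items change; the weight of the middle item is left as it was.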
class Uniform (capacity: int, shapes: Tuple[Tuple[int, ...]], dtypes: Tuple[torch.dtype], device: torch.device = device(type='cpu'))
Buffer from which samples are drawn uniformly.
Args
capacity : int
- Capacity of the buffer.
shapes : Tuple[Tuple[int, ...]]
- Shapes of the data to store.
dtypes : Tuple[torch.dtype]
- Data types of the data to store.
device : torch.device, optional
- Device on which to store the data. Defaults to CPU.
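As a rough illustration of a fixed-capacity buffer with uniform sampling, the sketch below defines a hypothetical ToyUniformBuffer built on Python lists and the random module. It is not the library's Uniform class: it ignores shapes, dtypes and device, and it assumes that the oldest entry is overwritten once the buffer is full and that samples are drawn with replacement, neither of which this page confirms.

```python
import random
from typing import Any, List, Tuple


class ToyUniformBuffer:
    """Hypothetical fixed-capacity buffer with uniform sampling (not the library class)."""

    def __init__(self, capacity: int, seed: int = 0) -> None:
        if capacity <= 0:
            raise ValueError("capacity must be positive")
        self._capacity = capacity
        self._items: List[Any] = []
        self._next = 0  # slot that the next insert writes to
        self._rng = random.Random(seed)

    @property
    def capacity(self) -> int:
        return self._capacity

    @property
    def size(self) -> int:
        return len(self._items)

    def add(self, item: Any) -> int:
        """Store one item and return its slot index as the identifier.

        Assumption: once the buffer is full, the oldest slot is overwritten.
        """
        slot = self._next
        if slot < len(self._items):
            self._items[slot] = item
        else:
            self._items.append(item)
        self._next = (slot + 1) % self._capacity
        return slot

    def sample(self, n: int) -> Tuple[List[Any], List[float], List[int]]:
        """Draw n slots with replacement; each slot has probability 1 / size."""
        if self.size == 0:
            raise ValueError("cannot sample from an empty buffer")
        slots = [self._rng.randrange(self.size) for _ in range(n)]
        data = [self._items[slot] for slot in slots]
        probabilities = [1.0 / self.size] * n
        return data, probabilities, slots


buffer = ToyUniformBuffer(capacity=3)
for step in range(5):
    buffer.add(f"experience-{step}")
data, probabilities, identifiers = buffer.sample(4)
```

After five inserts into a buffer of capacity three, only the three most recent experiences remain, and every sample reports the same probability.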
Ancestors
- Base
- abc.ABC
Inherited members
class Weighted (capacity: int, priority_exponent: float, shapes: Tuple[Tuple[int, ...]], dtypes: Tuple[torch.dtype], device: torch.device = device(type='cpu'))
Buffer where samples are drawn according to their sample weight.
Args
capacity : int
- Capacity of the buffer.
priority_exponent : float
- Exponent controlling how strongly sampling favors large weights. Larger values give samples with larger weights higher priority. Zero implies uniform sampling; one implies sampling proportional to the sample weight.
shapes : Tuple[Tuple[int, ...]]
- Shapes of the data to store.
dtypes : Tuple[torch.dtype]
- Data types of the data to store.
device : torch.device, optional
- Device on which to store the data. Defaults to CPU.
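The description of priority_exponent is consistent with sampling probabilities proportional to each weight raised to that exponent. The sketch below works through that rule on plain floats. The formula and the sampling_probabilities helper are assumptions made for illustration; this page does not show how Weighted computes its distribution.

```python
from typing import List


def sampling_probabilities(weights: List[float], priority_exponent: float) -> List[float]:
    """Assumed rule: p_i = w_i ** a / sum_j (w_j ** a), with a = priority_exponent."""
    if any(w < 0 for w in weights):
        raise ValueError("weights must be non-negative")
    scaled = [w ** priority_exponent for w in weights]
    total = sum(scaled)
    if total == 0:
        raise ValueError("at least one weight must be positive")
    return [s / total for s in scaled]


weights = [1.0, 3.0]
uniform = sampling_probabilities(weights, 0.0)       # [0.5, 0.5]
proportional = sampling_probabilities(weights, 1.0)  # [0.25, 0.75]
sharper = sampling_probabilities(weights, 2.0)       # [0.1, 0.9]
```

With an exponent of zero the weights are ignored, with one the probabilities match the weight ratio of 1:3, and with two the heavier sample is favored even more strongly.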
Ancestors
- Base
- abc.ABC
Inherited members