Module ai.utils.torch.random

Utility methods dealing with random numbers in PyTorch.

Expand source code
"""Utility methods dealing with random numbers in PyTorch."""

from ._choice import choice


__all__ = ["choice"]

Functions

def choice(probabilities: torch.Tensor) ‑> torch.Tensor

Selects elements with probability proportional to the given weights.

Args

probabilities : torch.Tensor
Tensor of shape (…, N). Indices are sampled along the last dimension, with probability proportional to the weights.

Returns

torch.Tensor
Tensor with the shape of the input minus its last dimension. Data type is torch.long.

Example usage:

x = torch.tensor([[1.0, 1.0, 1.0],
                  [0.0, 1.0, 2.0]])
y = choice(x)   # y[0] is any of (0, 1, 2) with equal probability. y[1] is any of
                # (1, 2), with 2 occurring with twice the probability of 1.
Expand source code
def choice(probabilities: torch.Tensor) -> torch.Tensor:
    """Selects indices with probability proportional to the given weights.

    The weights do not need to be normalized; each slice along the last dimension is
    divided by its own sum before sampling.

    Args:
        probabilities (torch.Tensor): Tensor of shape `(..., N)`. One index is sampled
            along the last dimension for every leading position, with probability
            proportional to the weights (expected to be non-negative).

    Returns:
        torch.Tensor: Tensor with the shape of the input minus its last dimension,
            located on the same device as the input. Data type is `torch.long`.

    Raises:
        ValueError: If `probabilities` contains NaN values.

    Example usage:
    ```python
    x = torch.tensor([[1.0, 1.0, 1.0],
                      [0.0, 1.0, 2.0]])
    y = choice(x)   # y[0] is any of (0, 1, 2) with equal probability. y[1] is any of
                    # (1, 2), with 2 occurring with twice the probability of 1.
    ```
    """
    if probabilities.isnan().any():
        raise ValueError("`choice` received NaN values.")

    cumulative = (probabilities / probabilities.sum(-1, keepdim=True)).cumsum(-1)

    # Draw on the input's device: a CPU-side draw cannot be compared against
    # tensors that live on an accelerator.
    r = torch.rand((*probabilities.shape[:-1], 1), device=probabilities.device)

    # Rounding can leave the final cumulative value slightly below 1. Scaling the
    # draw by that value keeps it strictly below the top of the range, so the
    # count below can never reach the out-of-range index N.
    r = r * cumulative[..., -1:]

    # The sampled index is the number of cumulative bounds at or below the draw.
    # `>=` (rather than `>`) makes zero-weight entries unreachable even when the
    # draw is exactly 0, since `torch.rand` samples from [0, 1).
    return (r >= cumulative).sum(-1)