PyTensor has native support for `pseudo random number generation (PRNG) <https://en.wikipedia.org/wiki/Pseudorandom_number_generator>`_.
This document describes how PRNGs are implemented in PyTensor, via the RandomVariable Operator.
We also discuss how initial seeding and seeding updates are implemented, and some harder cases such as using RandomVariables inside Scan, or with other backends like JAX.
We will use PRNG and RNG interchangeably, keeping in mind we are always talking about PRNGs.
The basics
==========
NumPy
-----
To start off, let's recall how PRNGs work in NumPy.
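As a refresher, here is a minimal sketch of NumPy's stateful Generator API (the seed 123 below is just for illustration):

```python
import numpy as np

# A Generator pairs a BitGenerator (PCG64 by default) with sampling methods
rng = np.random.default_rng(seed=123)

first = rng.uniform(size=2)   # drawing advances the internal state
second = rng.uniform(size=2)  # so the next call yields different values

# Re-creating a Generator from the same seed reproduces the stream
rng_again = np.random.default_rng(seed=123)
assert np.allclose(rng_again.uniform(size=2), first)
```

Every draw mutates the Generator's internal state in place; reproducibility comes only from re-seeding.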
We can see that the single node with [id A] has two outputs, which we named next_rng and x. By default only the second output, x, is returned directly to the user; the other is "hidden".
We can compile a function that returns the next_rng explicitly, so that we can use it as the input of the function in subsequent calls.
>>> f = pytensor.function([rng], [next_rng, x])
>>> rng_val = np.random.default_rng(123)
>>> next_rng_val, x = f(rng_val)
>>> print(x)
[0.68235186 0.05382102]
>>> next_rng_val, x = f(next_rng_val)
>>> print(x)
[0.22035987 0.18437181]
>>> next_rng_val, x = f(next_rng_val)
>>> print(x)
[0.1759059 0.81209451]
Shared variables
================
At this point we can make use of PyTensor shared variables.
Shared variables are global variables that don't need to (and can't) be passed as explicit inputs to the functions where they are used.
>>> rng = pytensor.shared(np.random.default_rng(123), name="rng")
>>> next_rng, x = pt.random.uniform(rng=rng).owner.outputs
>>>
>>> f = pytensor.function([], [next_rng, x])
>>>
>>> next_rng_val, x = f()
>>> print(x)
0.6823518632481435
We can update the value of shared variables across calls.
>>> rng.set_value(next_rng_val)
>>> next_rng_val, x = f()
>>> print(x)
0.053821018802222675
>>> rng.set_value(next_rng_val)
>>> next_rng_val, x = f()
>>> print(x)
0.22035987277261138
The real benefit of using shared variables is that we can automate this updating via the aptly named `updates` kwarg of PyTensor functions.
In this case it makes sense to simply replace the original value with next_rng_val (there is not really any other operation we can do with PyTensor RNGs).
├─ RNG(<Generator(PCG64) at 0x7FA45ED81540>) [id B]
├─ NoneConst{None} [id C]
├─ 0.0 [id D]
└─ 1.0 [id E]
The destroy map annotation tells us that the first output of the node is allowed to modify (destroy) the first input, the RNG.
>>> %timeit inplace_f() # doctest: +SKIP
35.5 µs ± 1.87 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Performance is now much closer to calling numpy directly, with only a small overhead introduced by the PyTensor function.
The `random_make_inplace <https://github.com/pymc-devs/pytensor/blob/3fcf6369d013c597a9c964b2400a3c5e20aa8dce/pytensor/tensor/random/rewriting/basic.py#L42-L52>`_
rewrite automatically replaces RandomVariable Ops by their inplace counterparts, when such operation is deemed safe. This happens when:
#. An input RNG is flagged as `mutable` and is not used anywhere else.
#. An RNG is created as an intermediate output and is not used anywhere else.
The first case applies when a user passes the `mutable` kwarg directly or, much more commonly,
when a shared RNG is used and a (default or manual) update expression is given.
In this case, a RandomVariable is allowed to modify the RNG because the shared variable holding it will be rewritten anyway.
The second case is not very common, because RNGs are not usually chained across multiple RandomVariable Ops.
It works, but such a graph is slightly unorthodox in PyTensor.
One practical reason is that it is more difficult to define the correct update expression for the shared RNG variable.
One technical reason is that it makes rewrites more challenging in cases where RandomVariables could otherwise be manipulated independently.
Creating multiple RNG variables
-------------------------------
RandomStreams generate high-quality seeds for multiple variables, following the `NumPy best practices <https://numpy.org/doc/stable/reference/random/parallel.html#parallel-random-number-generation>`_.
Users who create their own RNGs should follow the same practice!
Random variables in inner graphs
================================
Scan
----
Scan works very similarly to a function (one that is called repeatedly inside an outer scope).
This means that random variables will always return the same output unless updates are specified.
JAX uses a different type of PRNG than NumPy's. This means that the standard shared RNGs cannot be used directly in graphs transpiled to JAX.
Instead, a copy of the shared RNG variable is made, and its bit generator state is given a `jax_state` entry that is actually used by the JAX random variables.
In general, update rules are still respected, but they are applied to the copy used in the transpiled function, not to the original shared variable.