提交 14da898c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add support for RandomVariable with Generators in Numba backend and drop support for RandomState

上级 47874eb9
...@@ -27,7 +27,6 @@ from pytensor.graph.op import HasInnerGraph, Op ...@@ -27,7 +27,6 @@ from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import in2out, node_rewriter from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.utils import MissingInputError from pytensor.graph.utils import MissingInputError
from pytensor.tensor.rewriting.shape import ShapeFeature
def infer_shape(outs, inputs, input_shapes): def infer_shape(outs, inputs, input_shapes):
...@@ -43,6 +42,10 @@ def infer_shape(outs, inputs, input_shapes): ...@@ -43,6 +42,10 @@ def infer_shape(outs, inputs, input_shapes):
# inside. We don't use the full ShapeFeature interface, but we # inside. We don't use the full ShapeFeature interface, but we
# let it initialize itself with an empty fgraph, otherwise we will # let it initialize itself with an empty fgraph, otherwise we will
# need to do it manually # need to do it manually
# TODO: ShapeFeature should live elsewhere
from pytensor.tensor.rewriting.shape import ShapeFeature
for inp, inp_shp in zip(inputs, input_shapes): for inp, inp_shp in zip(inputs, input_shapes):
if inp_shp is not None and len(inp_shp) != inp.type.ndim: if inp_shp is not None and len(inp_shp) != inp.type.ndim:
assert len(inp_shp) == inp.type.ndim assert len(inp_shp) == inp.type.ndim
...@@ -307,6 +310,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -307,6 +310,7 @@ class OpFromGraph(Op, HasInnerGraph):
connection_pattern: list[list[bool]] | None = None, connection_pattern: list[list[bool]] | None = None,
strict: bool = False, strict: bool = False,
name: str | None = None, name: str | None = None,
destroy_map: dict[int, tuple[int, ...]] | None = None,
**kwargs, **kwargs,
): ):
""" """
...@@ -464,6 +468,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -464,6 +468,7 @@ class OpFromGraph(Op, HasInnerGraph):
if name is not None: if name is not None:
assert isinstance(name, str), "name must be None or string object" assert isinstance(name, str), "name must be None or string object"
self.name = name self.name = name
self.destroy_map = destroy_map if destroy_map is not None else {}
def __eq__(self, other): def __eq__(self, other):
# TODO: recognize a copy # TODO: recognize a copy
...@@ -862,6 +867,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -862,6 +867,7 @@ class OpFromGraph(Op, HasInnerGraph):
rop_overrides=self.rop_overrides, rop_overrides=self.rop_overrides,
connection_pattern=self._connection_pattern, connection_pattern=self._connection_pattern,
name=self.name, name=self.name,
destroy_map=self.destroy_map,
**self.kwargs, **self.kwargs,
) )
new_inputs = ( new_inputs = (
......
...@@ -463,7 +463,7 @@ JAX = Mode( ...@@ -463,7 +463,7 @@ JAX = Mode(
NUMBA = Mode( NUMBA = Mode(
NumbaLinker(), NumbaLinker(),
RewriteDatabaseQuery( RewriteDatabaseQuery(
include=["fast_run"], include=["fast_run", "numba"],
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"], exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
), ),
) )
......
...@@ -18,6 +18,7 @@ from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 ...@@ -18,6 +18,7 @@ from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from numba.extending import box, overload from numba.extending import box, overload
from pytensor import config from pytensor import config
from pytensor.compile import NUMBA
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
...@@ -440,6 +441,11 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs): ...@@ -440,6 +441,11 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
def numba_funcify_OpFromGraph(op, node=None, **kwargs): def numba_funcify_OpFromGraph(op, node=None, **kwargs):
_ = kwargs.pop("storage_map", None) _ = kwargs.pop("storage_map", None)
# Apply inner rewrites
# TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of OpFromGraph?
# The C-code defers it to the make_thunk phase
NUMBA.optimizer(op.fgraph)
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs)) fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
if len(op.fgraph.outputs) == 1: if len(op.fgraph.outputs) == 1:
......
...@@ -58,7 +58,11 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -58,7 +58,11 @@ def numba_funcify_Scan(op, node, **kwargs):
# TODO: Not sure this is the right place to do this, should we have a rewrite that # TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of Scan? # explicitly triggers the optimization of the inner graphs of Scan?
# The C-code defers it to the make_thunk phase # The C-code defers it to the make_thunk phase
rewriter = op.mode_instance.excluding(*NUMBA._optimizer.exclude).optimizer rewriter = (
op.mode_instance.including("numba")
.excluding(*NUMBA._optimizer.exclude)
.optimizer
)
rewriter(op.fgraph) rewriter(op.fgraph)
scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph)) scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))
......
...@@ -5,6 +5,7 @@ from typing import Any, cast ...@@ -5,6 +5,7 @@ from typing import Any, cast
import numpy as np import numpy as np
from pytensor import config from pytensor import config
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Constant from pytensor.graph.basic import Apply, Constant
from pytensor.graph.null_type import NullType from pytensor.graph.null_type import NullType
...@@ -377,3 +378,7 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply: ...@@ -377,3 +378,7 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
_vectorize_node.register(Blockwise, _vectorize_not_needed) _vectorize_node.register(Blockwise, _vectorize_not_needed)
class OpWithCoreShape(OpFromGraph):
    """Generalizes an `Op` to include core shape as an additional input.

    Subclasses wrap an inner graph whose first input(s) carry the core
    (non-batched) output shape, so backends (e.g. Numba) can pre-allocate
    output buffers without re-deriving the shape at runtime.
    """
...@@ -2082,10 +2082,7 @@ def choice(a, size=None, replace=True, p=None, rng=None): ...@@ -2082,10 +2082,7 @@ def choice(a, size=None, replace=True, p=None, rng=None):
# This is equivalent to the numpy implementation: # This is equivalent to the numpy implementation:
# https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/random/_generator.pyx#L905-L914 # https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/random/_generator.pyx#L905-L914
if p is None: if p is None:
if rng is not None and isinstance(rng.type, RandomStateType): idxs = integers(0, a_size, size=size, rng=rng)
idxs = randint(0, a_size, size=size, rng=rng)
else:
idxs = integers(0, a_size, size=size, rng=rng)
else: else:
idxs = categorical(p, size=size, rng=rng) idxs = categorical(p, size=size, rng=rng)
......
...@@ -19,6 +19,7 @@ from pytensor.tensor.basic import ( ...@@ -19,6 +19,7 @@ from pytensor.tensor.basic import (
get_vector_length, get_vector_length,
infer_static_shape, infer_static_shape,
) )
from pytensor.tensor.blockwise import OpWithCoreShape
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from pytensor.tensor.random.utils import ( from pytensor.tensor.random.utils import (
compute_batch_shape, compute_batch_shape,
...@@ -476,3 +477,11 @@ def vectorize_random_variable( ...@@ -476,3 +477,11 @@ def vectorize_random_variable(
size = concatenate([new_size_dims, size]) size = concatenate([new_size_dims, size])
return op.make_node(rng, size, *dist_params) return op.make_node(rng, size, *dist_params)
class RandomVariableWithCoreShape(OpWithCoreShape):
    """Generalizes a random variable `Op` to include a core shape parameter."""

    def __str__(self):
        # The inner graph holds exactly one apply node: the wrapped RV.
        (core_rv_node,) = self.fgraph.apply_nodes
        return f"[{core_rv_node.op!s}]"
...@@ -4,7 +4,8 @@ from pytensor.tensor.random.rewriting.basic import * ...@@ -4,7 +4,8 @@ from pytensor.tensor.random.rewriting.basic import *
# isort: off # isort: off
# Register JAX specializations # Register Numba and JAX specializations
import pytensor.tensor.random.rewriting.numba
import pytensor.tensor.random.rewriting.jax import pytensor.tensor.random.rewriting.jax
# isort: on # isort: on
from pytensor.compile import optdb
from pytensor.graph import node_rewriter
from pytensor.graph.rewriting.basic import out2in
from pytensor.tensor import as_tensor, constant
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
from pytensor.tensor.rewriting.shape import ShapeFeature
@node_rewriter([RandomVariable])
def introduce_explicit_core_shape_rv(fgraph, node):
    """Introduce the core shape of a RandomVariable.

    We wrap RandomVariable graphs into a RandomVariableWithCoreShape OpFromGraph
    that has an extra "non-functional" input that represents the core shape of
    the random variable. This core_shape is used by the numba backend to
    pre-allocate the output array.

    If available, the core shape is extracted from the shape feature of the
    graph, which has a higher chance of having been simplified, optimized, and
    constant-folded. If missing, we fall back to the
    ``op._supp_shape_from_params`` method.

    This rewrite is required for the numba backend implementation of
    RandomVariable.

    Example
    -------
    .. code-block:: python

        import pytensor
        import pytensor.tensor as pt

        x = pt.random.dirichlet(alphas=[1, 2, 3], size=(5,))
        pytensor.dprint(x, print_type=True)
        # dirichlet_rv{"(a)->(a)"}.1 [id A] <Matrix(float64, shape=(5, 3))>
        # ├─ RNG(<Generator(PCG64) at 0x7F09E59C18C0>) [id B] <RandomGeneratorType>
        # ├─ [5] [id C] <Vector(int64, shape=(1,))>
        # └─ ExpandDims{axis=0} [id D] <Matrix(int64, shape=(1, 3))>
        #    └─ [1 2 3] [id E] <Vector(int64, shape=(3,))>

        # After the rewrite, note the new core shape input [3] [id B]
        fn = pytensor.function([], x, mode="NUMBA")
        pytensor.dprint(fn.maker.fgraph)
        # [dirichlet_rv{"(a)->(a)"}].1 [id A] 0
        # ├─ [3] [id B]
        # ├─ RNG(<Generator(PCG64) at 0x7F15B8E844A0>) [id C]
        # ├─ [5] [id D]
        # └─ [[1 2 3]] [id E]

        # Inner graphs:
        # [dirichlet_rv{"(a)->(a)"}] [id A]
        # ← dirichlet_rv{"(a)->(a)"}.0 [id F]
        #   ├─ *1-<RandomGeneratorType> [id G]
        #   ├─ *2-<Vector(int64, shape=(1,))> [id H]
        #   └─ *3-<Matrix(int64, shape=(1, 3))> [id I]
        # ← dirichlet_rv{"(a)->(a)"}.1 [id F]
        #   └─ ···
    """
    op: RandomVariable = node.op  # type: ignore[annotation-unchecked]
    next_rng, rv = node.outputs

    shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)  # type: ignore[annotation-unchecked]
    if shape_feature:
        # Prefer shapes tracked by the ShapeFeature: they have likely been
        # simplified/constant-folded already. Core dims are the trailing
        # `ndim_supp` dims of the RV output.
        core_shape = [
            shape_feature.get_shape(rv, -i - 1) for i in reversed(range(op.ndim_supp))
        ]
    else:
        core_shape = op._supp_shape_from_params(op.dist_params(node))

    if len(core_shape) == 0:
        # Scalar RV: encode an empty core shape as an empty int64 vector.
        core_shape = constant([], dtype="int64")
    else:
        core_shape = as_tensor(core_shape)

    return (
        RandomVariableWithCoreShape(
            [core_shape, *node.inputs],
            node.outputs,
            # Inplace RVs destroy the rng input, which after wrapping sits at
            # index 1 (index 0 is the new core_shape input).
            destroy_map={0: [1]} if op.inplace else None,
        )
        .make_node(core_shape, *node.inputs)
        .outputs
    )
# Register under the "numba" tag so the rewrite only runs in Numba mode; a late
# position (100) lets shape simplifications happen first.
optdb.register(
    introduce_explicit_core_shape_rv.__name__,
    out2in(introduce_explicit_core_shape_rv),
    "numba",
    position=100,
)
...@@ -740,13 +740,13 @@ class UnShapeOptimizer(GraphRewriter): ...@@ -740,13 +740,13 @@ class UnShapeOptimizer(GraphRewriter):
# Register it after merge1 optimization at 0. We don't want to track # Register it after merge1 optimization at 0. We don't want to track
# the shape of merged node. # the shape of merged node.
pytensor.compile.mode.optdb.register( # type: ignore pytensor.compile.mode.optdb.register(
"ShapeOpt", ShapeOptimizer(), "fast_run", "fast_compile", position=0.1 "ShapeOpt", ShapeOptimizer(), "fast_run", "fast_compile", position=0.1
) )
# Not enabled by default for now. Some crossentropy opt use the # Not enabled by default for now. Some crossentropy opt use the
# shape_feature. They are at step 2.01. uncanonicalize is at step # shape_feature. They are at step 2.01. uncanonicalize is at step
# 3. After it goes to 48.5 that move to the gpu. So 10 seems reasonable. # 3. After it goes to 48.5 that move to the gpu. So 10 seems reasonable.
pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10) # type: ignore pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)
def local_reshape_chain(op): def local_reshape_chain(op):
......
...@@ -8,7 +8,6 @@ pytensor/graph/rewriting/basic.py ...@@ -8,7 +8,6 @@ pytensor/graph/rewriting/basic.py
pytensor/ifelse.py pytensor/ifelse.py
pytensor/link/basic.py pytensor/link/basic.py
pytensor/link/numba/dispatch/elemwise.py pytensor/link/numba/dispatch/elemwise.py
pytensor/link/numba/dispatch/random.py
pytensor/link/numba/dispatch/scan.py pytensor/link/numba/dispatch/scan.py
pytensor/printing.py pytensor/printing.py
pytensor/raise_op.py pytensor/raise_op.py
......
...@@ -29,7 +29,6 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery ...@@ -29,7 +29,6 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.ifelse import ifelse from pytensor.ifelse import ifelse
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_typify
from pytensor.link.numba.linker import NumbaLinker from pytensor.link.numba.linker import NumbaLinker
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.scalar.basic import ScalarOp, as_scalar from pytensor.scalar.basic import ScalarOp, as_scalar
...@@ -120,7 +119,7 @@ my_multi_out.ufunc = MyMultiOut.impl ...@@ -120,7 +119,7 @@ my_multi_out.ufunc = MyMultiOut.impl
my_multi_out.ufunc.nin = 2 my_multi_out.ufunc.nin = 2
my_multi_out.ufunc.nout = 2 my_multi_out.ufunc.nout = 2
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts) numba_mode = Mode(NumbaLinker(), opts.including("numba"))
py_mode = Mode("py", opts) py_mode = Mode("py", opts)
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
...@@ -229,6 +228,7 @@ def compare_numba_and_py( ...@@ -229,6 +228,7 @@ def compare_numba_and_py(
numba_mode=numba_mode, numba_mode=numba_mode,
py_mode=py_mode, py_mode=py_mode,
updates=None, updates=None,
eval_obj_mode: bool = True,
) -> tuple[Callable, Any]: ) -> tuple[Callable, Any]:
"""Function to compare python graph output and Numba compiled output for testing equality """Function to compare python graph output and Numba compiled output for testing equality
...@@ -247,6 +247,8 @@ def compare_numba_and_py( ...@@ -247,6 +247,8 @@ def compare_numba_and_py(
provided uses `np.testing.assert_allclose`. provided uses `np.testing.assert_allclose`.
updates updates
Updates to be passed to `pytensor.function`. Updates to be passed to `pytensor.function`.
eval_obj_mode : bool, default True
Whether to do an isolated call in object mode. Used for test coverage.
Returns Returns
------- -------
...@@ -283,7 +285,8 @@ def compare_numba_and_py( ...@@ -283,7 +285,8 @@ def compare_numba_and_py(
numba_res = pytensor_numba_fn(*inputs) numba_res = pytensor_numba_fn(*inputs)
# Get some coverage # Get some coverage
eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode) if eval_obj_mode:
eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
if len(fn_outputs) > 1: if len(fn_outputs) > 1:
for j, p in zip(numba_res, py_res): for j, p in zip(numba_res, py_res):
...@@ -359,26 +362,6 @@ def test_create_numba_signature(v, expected, force_scalar): ...@@ -359,26 +362,6 @@ def test_create_numba_signature(v, expected, force_scalar):
assert res == expected assert res == expected
@pytest.mark.parametrize(
"input, wrapper_fn, check_fn",
[
(
np.random.RandomState(1),
numba_typify,
lambda x, y: np.all(x.get_state()[1] == y.get_state()[1]),
)
],
)
def test_box_unbox(input, wrapper_fn, check_fn):
input = wrapper_fn(input)
pass_through = numba.njit(lambda x: x)
res = pass_through(input)
assert isinstance(res, type(input))
assert check_fn(res, input)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, indices", "x, indices",
[ [
......
...@@ -77,15 +77,13 @@ from tests.link.numba.test_basic import compare_numba_and_py ...@@ -77,15 +77,13 @@ from tests.link.numba.test_basic import compare_numba_and_py
), ),
# nit-sot, shared input/output # nit-sot, shared input/output
( (
lambda: RandomStream(seed=1930, rng_ctor=np.random.RandomState).normal( lambda: RandomStream(seed=1930).normal(0, 1, name="a"),
0, 1, name="a"
),
[], [],
[{}], [{}],
[], [],
3, 3,
[], [],
[np.array([-1.63408257, 0.18046406, 2.43265803])], [np.array([0.50100236, 2.16822932, 1.36326596])],
lambda op: op.info.n_shared_outs > 0, lambda op: op.info.n_shared_outs > 0,
), ),
# mit-sot (that's also a type of sit-sot) # mit-sot (that's also a type of sit-sot)
......
...@@ -1452,9 +1452,7 @@ def test_permutation_shape(): ...@@ -1452,9 +1452,7 @@ def test_permutation_shape():
assert tuple(permutation(np.arange(5), size=(2, 3)).shape.eval()) == (2, 3, 5) assert tuple(permutation(np.arange(5), size=(2, 3)).shape.eval()) == (2, 3, 5)
def batched_unweighted_choice_without_replacement_tester( def batched_unweighted_choice_without_replacement_tester(mode="FAST_RUN"):
mode="FAST_RUN", rng_ctor=np.random.default_rng
):
"""Test unweighted choice without replacement with batched ndims. """Test unweighted choice without replacement with batched ndims.
This has no corresponding in numpy, but is supported for consistency within the This has no corresponding in numpy, but is supported for consistency within the
...@@ -1462,7 +1460,7 @@ def batched_unweighted_choice_without_replacement_tester( ...@@ -1462,7 +1460,7 @@ def batched_unweighted_choice_without_replacement_tester(
It can be triggered by manually building the Op or during automatic vectorization. It can be triggered by manually building the Op or during automatic vectorization.
""" """
rng = shared(rng_ctor()) rng = shared(np.random.default_rng())
# Batched a implicit size # Batched a implicit size
rv_op = ChoiceWithoutReplacement( rv_op = ChoiceWithoutReplacement(
...@@ -1499,9 +1497,7 @@ def batched_unweighted_choice_without_replacement_tester( ...@@ -1499,9 +1497,7 @@ def batched_unweighted_choice_without_replacement_tester(
assert np.all((draw >= i * 10) & (draw < (i + 1) * 10)) assert np.all((draw >= i * 10) & (draw < (i + 1) * 10))
def batched_weighted_choice_without_replacement_tester( def batched_weighted_choice_without_replacement_tester(mode="FAST_RUN"):
mode="FAST_RUN", rng_ctor=np.random.default_rng
):
"""Test weighted choice without replacement with batched ndims. """Test weighted choice without replacement with batched ndims.
This has no corresponding in numpy, but is supported for consistency within the This has no corresponding in numpy, but is supported for consistency within the
...@@ -1509,7 +1505,7 @@ def batched_weighted_choice_without_replacement_tester( ...@@ -1509,7 +1505,7 @@ def batched_weighted_choice_without_replacement_tester(
It can be triggered by manually building the Op or during automatic vectorization. It can be triggered by manually building the Op or during automatic vectorization.
""" """
rng = shared(rng_ctor()) rng = shared(np.random.default_rng())
rv_op = ChoiceWithoutReplacement( rv_op = ChoiceWithoutReplacement(
signature="(a0,a1),(a0),(1)->(s0,a1)", signature="(a0,a1),(a0),(1)->(s0,a1)",
...@@ -1574,7 +1570,7 @@ def batched_weighted_choice_without_replacement_tester( ...@@ -1574,7 +1570,7 @@ def batched_weighted_choice_without_replacement_tester(
assert np.all((draw >= i * 10 + 2) & (draw < (i + 1) * 10)) assert np.all((draw >= i * 10 + 2) & (draw < (i + 1) * 10))
def batched_permutation_tester(mode="FAST_RUN", rng_ctor=np.random.default_rng): def batched_permutation_tester(mode="FAST_RUN"):
"""Test permutation with batched ndims. """Test permutation with batched ndims.
This has no corresponding in numpy, but is supported for consistency within the This has no corresponding in numpy, but is supported for consistency within the
...@@ -1583,7 +1579,7 @@ def batched_permutation_tester(mode="FAST_RUN", rng_ctor=np.random.default_rng): ...@@ -1583,7 +1579,7 @@ def batched_permutation_tester(mode="FAST_RUN", rng_ctor=np.random.default_rng):
It can be triggered by manually building the Op or during automatic vectorization. It can be triggered by manually building the Op or during automatic vectorization.
""" """
rng = shared(rng_ctor()) rng = shared(np.random.default_rng())
rv_op = PermutationRV(ndim_supp=2, ndims_params=[2], dtype="int64") rv_op = PermutationRV(ndim_supp=2, ndims_params=[2], dtype="int64")
x = np.arange(5 * 3 * 2).reshape((5, 3, 2)) x = np.arange(5 * 3 * 2).reshape((5, 3, 2))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论