Commit 14da898c authored by Ricardo Vieira, committed by Ricardo Vieira

Add support for RandomVariable with Generators in Numba backend and drop support for RandomState

Parent 47874eb9
......@@ -27,7 +27,6 @@ from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.utils import MissingInputError
from pytensor.tensor.rewriting.shape import ShapeFeature
def infer_shape(outs, inputs, input_shapes):
......@@ -43,6 +42,10 @@ def infer_shape(outs, inputs, input_shapes):
# inside. We don't use the full ShapeFeature interface, but we
# let it initialize itself with an empty fgraph, otherwise we will
# need to do it manually
# TODO: ShapeFeature should live elsewhere
from pytensor.tensor.rewriting.shape import ShapeFeature
for inp, inp_shp in zip(inputs, input_shapes):
if inp_shp is not None and len(inp_shp) != inp.type.ndim:
assert len(inp_shp) == inp.type.ndim
......@@ -307,6 +310,7 @@ class OpFromGraph(Op, HasInnerGraph):
connection_pattern: list[list[bool]] | None = None,
strict: bool = False,
name: str | None = None,
destroy_map: dict[int, tuple[int, ...]] | None = None,
**kwargs,
):
"""
......@@ -464,6 +468,7 @@ class OpFromGraph(Op, HasInnerGraph):
if name is not None:
assert isinstance(name, str), "name must be None or string object"
self.name = name
self.destroy_map = destroy_map if destroy_map is not None else {}
def __eq__(self, other):
# TODO: recognize a copy
......@@ -862,6 +867,7 @@ class OpFromGraph(Op, HasInnerGraph):
rop_overrides=self.rop_overrides,
connection_pattern=self._connection_pattern,
name=self.name,
destroy_map=self.destroy_map,
**self.kwargs,
)
new_inputs = (
......
......@@ -463,7 +463,7 @@ JAX = Mode(
NUMBA = Mode(
NumbaLinker(),
RewriteDatabaseQuery(
include=["fast_run"],
include=["fast_run", "numba"],
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
),
)
......
......@@ -18,6 +18,7 @@ from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from numba.extending import box, overload
from pytensor import config
from pytensor.compile import NUMBA
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.basic import Apply
......@@ -440,6 +441,11 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
_ = kwargs.pop("storage_map", None)
# Apply inner rewrites
# TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of OpFromGraph?
# The C-code defers it to the make_thunk phase
NUMBA.optimizer(op.fgraph)
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
if len(op.fgraph.outputs) == 1:
......
......@@ -58,7 +58,11 @@ def numba_funcify_Scan(op, node, **kwargs):
# TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of Scan?
# The C-code defers it to the make_thunk phase
rewriter = op.mode_instance.excluding(*NUMBA._optimizer.exclude).optimizer
rewriter = (
op.mode_instance.including("numba")
.excluding(*NUMBA._optimizer.exclude)
.optimizer
)
rewriter(op.fgraph)
scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))
......
......@@ -5,6 +5,7 @@ from typing import Any, cast
import numpy as np
from pytensor import config
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.null_type import NullType
......@@ -377,3 +378,7 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
_vectorize_node.register(Blockwise, _vectorize_not_needed)
class OpWithCoreShape(OpFromGraph):
    """Generalizes an `Op` to include core shape as an additional input.

    Subclasses wrap an inner graph (via `OpFromGraph`) whose first input(s)
    carry the core shape of the wrapped operation's output.
    """
......@@ -2082,10 +2082,7 @@ def choice(a, size=None, replace=True, p=None, rng=None):
# This is equivalent to the numpy implementation:
# https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/random/_generator.pyx#L905-L914
if p is None:
if rng is not None and isinstance(rng.type, RandomStateType):
idxs = randint(0, a_size, size=size, rng=rng)
else:
idxs = integers(0, a_size, size=size, rng=rng)
idxs = integers(0, a_size, size=size, rng=rng)
else:
idxs = categorical(p, size=size, rng=rng)
......
......@@ -19,6 +19,7 @@ from pytensor.tensor.basic import (
get_vector_length,
infer_static_shape,
)
from pytensor.tensor.blockwise import OpWithCoreShape
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from pytensor.tensor.random.utils import (
compute_batch_shape,
......@@ -476,3 +477,11 @@ def vectorize_random_variable(
size = concatenate([new_size_dims, size])
return op.make_node(rng, size, *dist_params)
class RandomVariableWithCoreShape(OpWithCoreShape):
    """Generalizes a random variable `Op` to include a core shape parameter."""

    def __str__(self):
        # The inner graph holds exactly one apply node: the wrapped RV.
        (rv_node,) = self.fgraph.apply_nodes
        return "[{}]".format(rv_node.op)
......@@ -4,7 +4,8 @@ from pytensor.tensor.random.rewriting.basic import *
# isort: off
# Register JAX specializations
# Register Numba and JAX specializations
import pytensor.tensor.random.rewriting.numba
import pytensor.tensor.random.rewriting.jax
# isort: on
from pytensor.compile import optdb
from pytensor.graph import node_rewriter
from pytensor.graph.rewriting.basic import out2in
from pytensor.tensor import as_tensor, constant
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
from pytensor.tensor.rewriting.shape import ShapeFeature
@node_rewriter([RandomVariable])
def introduce_explicit_core_shape_rv(fgraph, node):
    """Introduce the core shape of a RandomVariable.

    We wrap RandomVariable graphs into a RandomVariableWithCoreShape OpFromGraph
    that has an extra "non-functional" input that represents the core shape of
    the random variable. This core_shape is used by the numba backend to
    pre-allocate the output array.

    If available, the core shape is extracted from the shape feature of the
    graph, which has a higher chance of having been simplified, optimized,
    constant-folded. If missing, we fall back to the op._supp_shape_from_params
    method.

    This rewrite is required for the numba backend implementation of
    RandomVariable.

    Example
    -------

    .. code-block:: python

        import pytensor
        import pytensor.tensor as pt

        x = pt.random.dirichlet(alphas=[1, 2, 3], size=(5,))
        pytensor.dprint(x, print_type=True)
        # dirichlet_rv{"(a)->(a)"}.1 [id A] <Matrix(float64, shape=(5, 3))>
        # ├─ RNG(<Generator(PCG64) at 0x7F09E59C18C0>) [id B] <RandomGeneratorType>
        # ├─ [5] [id C] <Vector(int64, shape=(1,))>
        # └─ ExpandDims{axis=0} [id D] <Matrix(int64, shape=(1, 3))>
        #    └─ [1 2 3] [id E] <Vector(int64, shape=(3,))>

        # After the rewrite, note the new core shape input [3] [id B]
        fn = pytensor.function([], x, mode="NUMBA")
        pytensor.dprint(fn.maker.fgraph)
        # [dirichlet_rv{"(a)->(a)"}].1 [id A] 0
        # ├─ [3] [id B]
        # ├─ RNG(<Generator(PCG64) at 0x7F15B8E844A0>) [id C]
        # ├─ [5] [id D]
        # └─ [[1 2 3]] [id E]
        # Inner graphs:
        # [dirichlet_rv{"(a)->(a)"}] [id A]
        # ← dirichlet_rv{"(a)->(a)"}.0 [id F]
        #    ├─ *1-<RandomGeneratorType> [id G]
        #    ├─ *2-<Vector(int64, shape=(1,))> [id H]
        #    └─ *3-<Matrix(int64, shape=(1, 3))> [id I]
        # ← dirichlet_rv{"(a)->(a)"}.1 [id F]
        #    └─ ···
    """
    op: RandomVariable = node.op  # type: ignore[annotation-unchecked]
    next_rng, rv = node.outputs

    shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)  # type: ignore[annotation-unchecked]
    if shape_feature:
        # Take the trailing `ndim_supp` dims of the rv's tracked shape
        # (negative indices, innermost last) as the core shape.
        core_shape = [
            shape_feature.get_shape(rv, -i - 1) for i in reversed(range(op.ndim_supp))
        ]
    else:
        # No shape feature on this fgraph: derive the support shape
        # directly from the distribution parameters.
        core_shape = op._supp_shape_from_params(op.dist_params(node))

    if len(core_shape) == 0:
        # Scalar support: represent the core shape as an empty int64 vector.
        core_shape = constant([], dtype="int64")
    else:
        core_shape = as_tensor(core_shape)

    return (
        RandomVariableWithCoreShape(
            [core_shape, *node.inputs],
            node.outputs,
            # If the wrapped RV destroys its rng input (input 1 after the
            # prepended core_shape), propagate that to the wrapper.
            destroy_map={0: [1]} if op.inplace else None,
        )
        .make_node(core_shape, *node.inputs)
        .outputs
    )
# Register the rewrite under the "numba" tag so it only runs for the Numba
# backend; position=100 places it late, after the main optimization passes.
optdb.register(
    introduce_explicit_core_shape_rv.__name__,
    out2in(introduce_explicit_core_shape_rv),
    "numba",
    position=100,
)
......@@ -740,13 +740,13 @@ class UnShapeOptimizer(GraphRewriter):
# Register it after merge1 optimization at 0. We don't want to track
# the shape of merged node.
pytensor.compile.mode.optdb.register( # type: ignore
pytensor.compile.mode.optdb.register(
"ShapeOpt", ShapeOptimizer(), "fast_run", "fast_compile", position=0.1
)
# Not enabled by default for now. Some crossentropy opt use the
# shape_feature. They are at step 2.01. uncanonicalize is at step
# 3. After it goes to 48.5 that move to the gpu. So 10 seems reasonable.
pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10) # type: ignore
pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)
def local_reshape_chain(op):
......
......@@ -8,7 +8,6 @@ pytensor/graph/rewriting/basic.py
pytensor/ifelse.py
pytensor/link/basic.py
pytensor/link/numba/dispatch/elemwise.py
pytensor/link/numba/dispatch/random.py
pytensor/link/numba/dispatch/scan.py
pytensor/printing.py
pytensor/raise_op.py
......
......@@ -29,7 +29,6 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.type import Type
from pytensor.ifelse import ifelse
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_typify
from pytensor.link.numba.linker import NumbaLinker
from pytensor.raise_op import assert_op
from pytensor.scalar.basic import ScalarOp, as_scalar
......@@ -120,7 +119,7 @@ my_multi_out.ufunc = MyMultiOut.impl
my_multi_out.ufunc.nin = 2
my_multi_out.ufunc.nout = 2
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts)
numba_mode = Mode(NumbaLinker(), opts.including("numba"))
py_mode = Mode("py", opts)
rng = np.random.default_rng(42849)
......@@ -229,6 +228,7 @@ def compare_numba_and_py(
numba_mode=numba_mode,
py_mode=py_mode,
updates=None,
eval_obj_mode: bool = True,
) -> tuple[Callable, Any]:
"""Function to compare python graph output and Numba compiled output for testing equality
......@@ -247,6 +247,8 @@ def compare_numba_and_py(
provided uses `np.testing.assert_allclose`.
updates
Updates to be passed to `pytensor.function`.
eval_obj_mode : bool, default True
Whether to do an isolated call in object mode. Used for test coverage
Returns
-------
......@@ -283,7 +285,8 @@ def compare_numba_and_py(
numba_res = pytensor_numba_fn(*inputs)
# Get some coverage
eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
if eval_obj_mode:
eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
if len(fn_outputs) > 1:
for j, p in zip(numba_res, py_res):
......@@ -359,26 +362,6 @@ def test_create_numba_signature(v, expected, force_scalar):
assert res == expected
# NOTE(review): this test exercises boxing/unboxing of np.random.RandomState
# through numba_typify; the surrounding commit drops RandomState support in
# the Numba backend, so this hunk appears to be removed — confirm intent.
@pytest.mark.parametrize(
    "input, wrapper_fn, check_fn",
    [
        (
            np.random.RandomState(1),
            numba_typify,
            # Compare the raw state arrays (element [1] of get_state()).
            lambda x, y: np.all(x.get_state()[1] == y.get_state()[1]),
        )
    ],
)
def test_box_unbox(input, wrapper_fn, check_fn):
    # Wrap the value into the backend representation before passing it
    # through a compiled identity function.
    input = wrapper_fn(input)

    pass_through = numba.njit(lambda x: x)
    res = pass_through(input)

    # Round-tripping must preserve both the type and the underlying state.
    assert isinstance(res, type(input))
    assert check_fn(res, input)
@pytest.mark.parametrize(
"x, indices",
[
......
......@@ -77,15 +77,13 @@ from tests.link.numba.test_basic import compare_numba_and_py
),
# nit-sot, shared input/output
(
lambda: RandomStream(seed=1930, rng_ctor=np.random.RandomState).normal(
0, 1, name="a"
),
lambda: RandomStream(seed=1930).normal(0, 1, name="a"),
[],
[{}],
[],
3,
[],
[np.array([-1.63408257, 0.18046406, 2.43265803])],
[np.array([0.50100236, 2.16822932, 1.36326596])],
lambda op: op.info.n_shared_outs > 0,
),
# mit-sot (that's also a type of sit-sot)
......
......@@ -1452,9 +1452,7 @@ def test_permutation_shape():
assert tuple(permutation(np.arange(5), size=(2, 3)).shape.eval()) == (2, 3, 5)
def batched_unweighted_choice_without_replacement_tester(
mode="FAST_RUN", rng_ctor=np.random.default_rng
):
def batched_unweighted_choice_without_replacement_tester(mode="FAST_RUN"):
"""Test unweighted choice without replacement with batched ndims.
This has no corresponding in numpy, but is supported for consistency within the
......@@ -1462,7 +1460,7 @@ def batched_unweighted_choice_without_replacement_tester(
It can be triggered by manual buiding the Op or during automatic vectorization.
"""
rng = shared(rng_ctor())
rng = shared(np.random.default_rng())
# Batched a implicit size
rv_op = ChoiceWithoutReplacement(
......@@ -1499,9 +1497,7 @@ def batched_unweighted_choice_without_replacement_tester(
assert np.all((draw >= i * 10) & (draw < (i + 1) * 10))
def batched_weighted_choice_without_replacement_tester(
mode="FAST_RUN", rng_ctor=np.random.default_rng
):
def batched_weighted_choice_without_replacement_tester(mode="FAST_RUN"):
"""Test weighted choice without replacement with batched ndims.
This has no corresponding in numpy, but is supported for consistency within the
......@@ -1509,7 +1505,7 @@ def batched_weighted_choice_without_replacement_tester(
It can be triggered by manual buiding the Op or during automatic vectorization.
"""
rng = shared(rng_ctor())
rng = shared(np.random.default_rng())
rv_op = ChoiceWithoutReplacement(
signature="(a0,a1),(a0),(1)->(s0,a1)",
......@@ -1574,7 +1570,7 @@ def batched_weighted_choice_without_replacement_tester(
assert np.all((draw >= i * 10 + 2) & (draw < (i + 1) * 10))
def batched_permutation_tester(mode="FAST_RUN", rng_ctor=np.random.default_rng):
def batched_permutation_tester(mode="FAST_RUN"):
"""Test permutation with batched ndims.
This has no corresponding in numpy, but is supported for consistency within the
......@@ -1583,7 +1579,7 @@ def batched_permutation_tester(mode="FAST_RUN", rng_ctor=np.random.default_rng):
It can be triggered by manual buiding the Op or during automatic vectorization.
"""
rng = shared(rng_ctor())
rng = shared(np.random.default_rng())
rv_op = PermutationRV(ndim_supp=2, ndims_params=[2], dtype="int64")
x = np.arange(5 * 3 * 2).reshape((5, 3, 2))
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment