提交 3cdcfde4 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix Blockwise and RandomVariable in Numba with repeated arguments

上级 a920c09f
...@@ -443,6 +443,13 @@ _vectorize_node.register(Blockwise, _vectorize_not_needed) ...@@ -443,6 +443,13 @@ _vectorize_node.register(Blockwise, _vectorize_not_needed)
class OpWithCoreShape(OpFromGraph): class OpWithCoreShape(OpFromGraph):
"""Generalizes an `Op` to include core shape as an additional input.""" """Generalizes an `Op` to include core shape as an additional input."""
def __init__(self, *args, on_unused_input="ignore", **kwargs):
# We set on_unused_inputs="ignore" so that we can easily wrap nodes with repeated inputs
# In this case the subsequent appearance of repeated inputs get disconnected in the inner graph
# I can't think of a scenario where this will backfire, but if there's one
# I bet on inplacing operations (time will tell)
return super().__init__(*args, on_unused_input=on_unused_input, **kwargs)
class BlockwiseWithCoreShape(OpWithCoreShape): class BlockwiseWithCoreShape(OpWithCoreShape):
"""Generalizes a Blockwise `Op` to include a core shape parameter.""" """Generalizes a Blockwise `Op` to include a core shape parameter."""
......
...@@ -2,9 +2,9 @@ import numpy as np ...@@ -2,9 +2,9 @@ import numpy as np
import pytest import pytest
from pytensor import function from pytensor import function
from pytensor.tensor import tensor from pytensor.tensor import tensor, tensor3
from pytensor.tensor.basic import ARange from pytensor.tensor.basic import ARange
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
from pytensor.tensor.nlinalg import SVD, Det from pytensor.tensor.nlinalg import SVD, Det
from pytensor.tensor.slinalg import Cholesky, cholesky from pytensor.tensor.slinalg import Cholesky, cholesky
from tests.link.numba.test_basic import compare_numba_and_py, numba_mode from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
...@@ -58,3 +58,15 @@ def test_blockwise_benchmark(benchmark): ...@@ -58,3 +58,15 @@ def test_blockwise_benchmark(benchmark):
x_test = np.eye(3) * np.arange(1, 6)[:, None, None] x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
fn(x_test) # JIT compile fn(x_test) # JIT compile
benchmark(fn, x_test) benchmark(fn, x_test)
def test_repeated_args():
x = tensor3("x")
x_test = np.full((1, 1, 1), 2.0, dtype=x.type.dtype)
out = x @ x
fn, _ = compare_numba_and_py([x], [out], [x_test], eval_obj_mode=False)
# Confirm we are testing a Blockwise with repeated inputs
final_node = fn.maker.fgraph.outputs[0].owner
assert isinstance(final_node.op, BlockwiseWithCoreShape)
assert final_node.inputs[0] is final_node.inputs[1]
...@@ -10,6 +10,7 @@ import pytensor.tensor.random.basic as ptr ...@@ -10,6 +10,7 @@ import pytensor.tensor.random.basic as ptr
from pytensor import shared from pytensor import shared
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.tensor.random.op import RandomVariableWithCoreShape
from tests.link.numba.test_basic import ( from tests.link.numba.test_basic import (
compare_numba_and_py, compare_numba_and_py,
numba_mode, numba_mode,
...@@ -693,3 +694,14 @@ def test_rv_inside_ofg(): ...@@ -693,3 +694,14 @@ def test_rv_inside_ofg():
def test_unnatural_batched_dims(batch_dims_tester): def test_unnatural_batched_dims(batch_dims_tester):
"""Tests for RVs that don't have natural batch dims in Numba API.""" """Tests for RVs that don't have natural batch dims in Numba API."""
batch_dims_tester(mode="NUMBA") batch_dims_tester(mode="NUMBA")
def test_repeated_args():
v = pt.scalar()
x = ptr.beta(v, v)
fn, _ = compare_numba_and_py([v], [x], [0.5 * 1e6], eval_obj_mode=False)
# Confirm we are testing a RandomVariable with repeated inputs
final_node = fn.maker.fgraph.outputs[0].owner
assert isinstance(final_node.op, RandomVariableWithCoreShape)
assert final_node.inputs[-2] is final_node.inputs[-1]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论