Commit f3601225 — authored by ricardoV94, committed by Ricardo Vieira

Numba Blockwise: Force scalar inner inputs to be arrays

Parent 6f8fc3b6
......@@ -86,6 +86,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
output_bc_patterns,
output_dtypes,
inplace_pattern,
False, # allow_core_scalar
(), # constant_inputs
inputs,
tuple_core_shapes,
......@@ -98,6 +99,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
# If the core op cannot be cached, the Blockwise wrapper cannot be cached either
blockwise_key = None
else:
blockwise_cache_version = 1
blockwise_key = "_".join(
map(
str,
......@@ -108,6 +110,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
blockwise_op.signature,
input_bc_patterns,
core_op_key,
blockwise_cache_version,
),
)
)
......
......@@ -365,6 +365,7 @@ def numba_funcify_Elemwise(op, node, **kwargs):
output_bc_patterns_enc,
output_dtypes_enc,
inplace_pattern_enc,
True, # allow_core_scalar
(), # constant_inputs
inputs,
core_output_shapes, # core_shapes
......
......@@ -470,6 +470,7 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
output_bc_patterns,
output_dtypes,
inplace_pattern,
True, # allow_core_scalar
(rng,),
dist_params,
(numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),),
......
......@@ -82,6 +82,7 @@ def _vectorized(
output_bc_patterns,
output_dtypes,
inplace_pattern,
allow_core_scalar,
constant_inputs_types,
input_types,
output_core_shape_types,
......@@ -93,6 +94,7 @@ def _vectorized(
output_bc_patterns,
output_dtypes,
inplace_pattern,
allow_core_scalar,
constant_inputs_types,
input_types,
output_core_shape_types,
......@@ -119,6 +121,10 @@ def _vectorized(
inplace_pattern = inplace_pattern.literal_value
inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode()))
if not isinstance(allow_core_scalar, types.Literal):
raise TypeError("allow_core_scalar must be literal.")
allow_core_scalar = allow_core_scalar.literal_value
batch_ndim = len(input_bc_patterns[0])
nin = len(constant_inputs_types) + len(input_types)
nout = len(output_bc_patterns)
......@@ -142,8 +148,7 @@ def _vectorized(
core_input_types = []
for input_type, bc_pattern in zip(input_types, input_bc_patterns, strict=True):
core_ndim = input_type.ndim - len(bc_pattern)
# TODO: Reconsider this
if core_ndim == 0:
if allow_core_scalar and core_ndim == 0:
core_input_type = input_type.dtype
else:
core_input_type = types.Array(
......@@ -196,7 +201,7 @@ def _vectorized(
sig,
args,
):
[_, _, _, _, _, constant_inputs, inputs, output_core_shapes, size] = args
[_, _, _, _, _, _, constant_inputs, inputs, output_core_shapes, size] = args
constant_inputs = cgutils.unpack_tuple(builder, constant_inputs)
inputs = cgutils.unpack_tuple(builder, inputs)
......@@ -256,6 +261,7 @@ def _vectorized(
output_bc_patterns_val,
input_types,
output_types,
core_scalar=allow_core_scalar,
)
if len(outputs) == 1:
......@@ -429,6 +435,7 @@ def make_loop_call(
output_bc: tuple[tuple[bool, ...], ...],
input_types: tuple[Any, ...],
output_types: tuple[Any, ...],
core_scalar: bool = True,
):
safe = (False, False)
......@@ -486,7 +493,7 @@ def make_loop_call(
idxs_bc,
*safe,
)
if core_ndim == 0:
if core_scalar and core_ndim == 0:
# Retrive scalar item at index
val = builder.load(ptr)
# val.set_metadata("alias.scope", input_scope_set)
......@@ -499,15 +506,19 @@ def make_loop_call(
dtype=input_type.dtype, ndim=core_ndim, layout=input_type.layout
)
core_array = context.make_array(core_arry_type)(context, builder)
core_shape = cgutils.unpack_tuple(builder, input.shape)[-core_ndim:]
core_strides = cgutils.unpack_tuple(builder, input.strides)[-core_ndim:]
core_shape = cgutils.unpack_tuple(builder, input.shape)[
input_type.ndim - core_ndim :
]
core_strides = cgutils.unpack_tuple(builder, input.strides)[
input_type.ndim - core_ndim :
]
itemsize = context.get_abi_sizeof(context.get_data_type(input_type.dtype))
context.populate_array(
core_array,
# TODO whey do we need to bitcast?
data=builder.bitcast(ptr, core_array.data.type),
shape=cgutils.pack_array(builder, core_shape),
strides=cgutils.pack_array(builder, core_strides),
shape=core_shape,
strides=core_strides,
itemsize=context.get_constant(types.intp, itemsize),
# TODO what is meminfo about?
meminfo=None,
......
......@@ -2,10 +2,12 @@ import numpy as np
import pytest
from pytensor import function
from pytensor.tensor import lvector, tensor, tensor3
from pytensor.graph import Apply
from pytensor.scalar import ScalarOp
from pytensor.tensor import TensorVariable, lvector, tensor, tensor3, vector
from pytensor.tensor.basic import Alloc, ARange, constant
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.nlinalg import SVD, Det
from pytensor.tensor.slinalg import Cholesky, cholesky
from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
......@@ -90,3 +92,39 @@ def test_blockwise_scalar_dimshuffle():
)
out = blockwise_scalar_ds(x)
compare_numba_and_py([x], [out], [np.arange(9)], eval_obj_mode=False)
def test_blockwise_vs_elemwise_scalar_op():
    """Blockwise must feed its core op ndarray inputs, Elemwise raw scalars.

    Regression test for https://github.com/pymc-devs/pytensor/issues/1760
    """

    class TestScalarOp(ScalarOp):
        def make_node(self, x):
            return Apply(self, [x], [x.type()])

        def perform(self, node, inputs, outputs):
            [value] = inputs
            # The wrapper determines what the inner perform receives:
            # tensor inputs (Blockwise core) arrive as ndarrays, scalar
            # inputs (Elemwise core) as plain/numpy scalars.
            if isinstance(node.inputs[0], TensorVariable):
                assert isinstance(value, np.ndarray)
            else:
                assert isinstance(value, np.number | float)
            result = value + 1
            if isinstance(node.outputs[0], TensorVariable):
                result = np.asarray(result)
            outputs[0][0] = result

    inp = vector("x")
    graphs = [
        Elemwise(TestScalarOp())(inp),
        Blockwise(TestScalarOp(), signature="()->()")(inp),
    ]
    # Both wrappers fall back to object mode for the python-only perform,
    # and both must produce x + 1 without tripping the perform assertions.
    for out in graphs:
        with pytest.warns(
            UserWarning,
            match="Numba will use object mode to run TestScalarOp's perform method",
        ):
            fn = function([inp], out, mode="NUMBA")
        np.testing.assert_allclose(fn(np.zeros((3,))), [1, 1, 1])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论