Commit f3601225 authored by ricardoV94, committed by Ricardo Vieira

Numba Blockwise: Force scalar inner inputs to be arrays

Parent 6f8fc3b6
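Editor's note on what the title means in practice: for an input whose core has zero dimensions, the Numba loop used to hand the inner function the scalar item itself. For Blockwise the inner call is an arbitrary Op's perform, which, as in the Python backend, should receive a 0-d array instead. A NumPy-only sketch of the distinction (an analogy, not the generated code):

    import numpy as np

    x = np.arange(3.0)

    # Per batch element with core_ndim == 0, the two conventions are:
    scalar_item = x[0]       # bare np.float64 (Elemwise keeps this behavior)
    zero_d_view = x[0, ...]  # 0-d ndarray view (what Blockwise core ops now receive)

    assert isinstance(scalar_item, np.floating)
    assert isinstance(zero_d_view, np.ndarray) and zero_d_view.ndim == 0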
@@ -86,6 +86,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
         output_bc_patterns,
         output_dtypes,
         inplace_pattern,
+        False,  # allow_core_scalar
         (),  # constant_inputs
         inputs,
         tuple_core_shapes,
@@ -98,6 +99,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
         # If the core op cannot be cached, the Blockwise wrapper cannot be cached either
         blockwise_key = None
     else:
+        blockwise_cache_version = 1
         blockwise_key = "_".join(
             map(
                 str,
@@ -108,6 +110,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
                     blockwise_op.signature,
                     input_bc_patterns,
                     core_op_key,
+                    blockwise_cache_version,
                 ),
             )
         )
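Editor's note: `blockwise_cache_version` is appended to the string key so that Blockwise kernels compiled before this behavioral change cannot be reused from Numba's on-disk cache. A toy illustration of the scheme (names are hypothetical, not the commit's code):

    def make_key(signature, core_op_key, version):
        return "_".join(map(str, (signature, core_op_key, version)))

    # The pre-commit key had no version component; adding one changes every key,
    # and future bumps of the literal `1` will invalidate stale entries again.
    assert make_key("(m,n)->(n)", "core-v1", 1) != make_key("(m,n)->(n)", "core-v1", 2)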
...
@@ -365,6 +365,7 @@ def numba_funcify_Elemwise(op, node, **kwargs):
         output_bc_patterns_enc,
         output_dtypes_enc,
         inplace_pattern_enc,
+        True,  # allow_core_scalar
         (),  # constant_inputs
         inputs,
         core_output_shapes,  # core_shapes
...
@@ -470,6 +470,7 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
         output_bc_patterns,
         output_dtypes,
         inplace_pattern,
+        True,  # allow_core_scalar
         (rng,),
         dist_params,
         (numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),),
...
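Why the flag is True for Elemwise and RandomVariable but False for Blockwise (editor's inference from this commit): the former feed their core computation scalar-by-scalar, while Blockwise invokes an Op whose perform expects ndarray inputs, exactly as in the Python backend. Schematically, with hypothetical core callables:

    import numpy as np

    def elemwise_core(x):   # allow_core_scalar=True: sees bare scalars
        return x + 1

    def blockwise_core(x):  # allow_core_scalar=False: sees 0-d arrays
        assert isinstance(x, np.ndarray) and x.ndim == 0
        return x + 1

    elemwise_core(np.float64(0.0))
    blockwise_core(np.asarray(0.0))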
@@ -82,6 +82,7 @@ def _vectorized(
     output_bc_patterns,
     output_dtypes,
     inplace_pattern,
+    allow_core_scalar,
     constant_inputs_types,
     input_types,
     output_core_shape_types,
@@ -93,6 +94,7 @@ def _vectorized(
         output_bc_patterns,
         output_dtypes,
         inplace_pattern,
+        allow_core_scalar,
         constant_inputs_types,
         input_types,
         output_core_shape_types,
@@ -119,6 +121,10 @@ def _vectorized(
     inplace_pattern = inplace_pattern.literal_value
     inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode()))

+    if not isinstance(allow_core_scalar, types.Literal):
+        raise TypeError("allow_core_scalar must be literal.")
+    allow_core_scalar = allow_core_scalar.literal_value
+
     batch_ndim = len(input_bc_patterns[0])
     nin = len(constant_inputs_types) + len(input_types)
     nout = len(output_bc_patterns)
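The check above follows Numba's compile-time literal convention: the flag must arrive as a `types.Literal` so the scalar-vs-array decision is resolved during typing, at zero runtime cost. A self-contained sketch of the same pattern using `@overload` (illustrative names; the commit itself does this inside an intrinsic):

    from numba import njit
    from numba.core import types
    from numba.extending import overload

    def core_kind(flag):
        raise NotImplementedError("stub; replaced by the overload when jitted")

    @overload(core_kind, prefer_literal=True)
    def _core_kind(flag):
        # Reject non-literal flags, then branch on the literal value while
        # building the implementation -- mirroring allow_core_scalar above.
        if not isinstance(flag, types.Literal):
            raise TypeError("flag must be literal.")
        if flag.literal_value:
            return lambda flag: "scalar core"
        return lambda flag: "array core"

    @njit
    def demo():
        # Boolean constants are typed as literals at the call site.
        return core_kind(True), core_kind(False)

    print(demo())  # ('scalar core', 'array core')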
@@ -142,8 +148,7 @@ def _vectorized(
     core_input_types = []
     for input_type, bc_pattern in zip(input_types, input_bc_patterns, strict=True):
         core_ndim = input_type.ndim - len(bc_pattern)
-        # TODO: Reconsider this
-        if core_ndim == 0:
+        if allow_core_scalar and core_ndim == 0:
             core_input_type = input_type.dtype
         else:
             core_input_type = types.Array(
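In Numba type terms, the branch above decides what the compiled core function receives for a `core_ndim == 0` input: a bare scalar dtype when `allow_core_scalar` is set (Elemwise, RandomVariable), or a 0-d array when it is not (Blockwise), matching what perform gets in the Python backend. A small sketch with illustrative values:

    from numba.core import types

    input_type = types.Array(types.float64, 1, "C")  # e.g. a vector input
    bc_pattern = (False,)                            # one batch dim, no core dims
    core_ndim = input_type.ndim - len(bc_pattern)    # -> 0

    scalar_core = input_type.dtype                                     # Elemwise path: float64
    array_core = types.Array(input_type.dtype, 0, input_type.layout)   # Blockwise path: 0-d array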
@@ -196,7 +201,7 @@ def _vectorized(
         sig,
         args,
     ):
-        [_, _, _, _, _, constant_inputs, inputs, output_core_shapes, size] = args
+        [_, _, _, _, _, _, constant_inputs, inputs, output_core_shapes, size] = args

         constant_inputs = cgutils.unpack_tuple(builder, constant_inputs)
         inputs = cgutils.unpack_tuple(builder, inputs)
@@ -256,6 +261,7 @@ def _vectorized(
             output_bc_patterns_val,
             input_types,
             output_types,
+            core_scalar=allow_core_scalar,
         )

     if len(outputs) == 1:
@@ -429,6 +435,7 @@ def make_loop_call(
     output_bc: tuple[tuple[bool, ...], ...],
     input_types: tuple[Any, ...],
     output_types: tuple[Any, ...],
+    core_scalar: bool = True,
 ):
     safe = (False, False)
@@ -486,7 +493,7 @@ def make_loop_call(
             idxs_bc,
             *safe,
         )
-        if core_ndim == 0:
+        if core_scalar and core_ndim == 0:
             # Retrive scalar item at index
             val = builder.load(ptr)
             # val.set_metadata("alias.scope", input_scope_set)
@@ -499,15 +506,19 @@ def make_loop_call(
                 dtype=input_type.dtype, ndim=core_ndim, layout=input_type.layout
             )
             core_array = context.make_array(core_arry_type)(context, builder)
-            core_shape = cgutils.unpack_tuple(builder, input.shape)[-core_ndim:]
-            core_strides = cgutils.unpack_tuple(builder, input.strides)[-core_ndim:]
+            core_shape = cgutils.unpack_tuple(builder, input.shape)[
+                input_type.ndim - core_ndim :
+            ]
+            core_strides = cgutils.unpack_tuple(builder, input.strides)[
+                input_type.ndim - core_ndim :
+            ]
             itemsize = context.get_abi_sizeof(context.get_data_type(input_type.dtype))
             context.populate_array(
                 core_array,
                 # TODO whey do we need to bitcast?
                 data=builder.bitcast(ptr, core_array.data.type),
-                shape=cgutils.pack_array(builder, core_shape),
-                strides=cgutils.pack_array(builder, core_strides),
+                shape=core_shape,
+                strides=core_strides,
                 itemsize=context.get_constant(types.intp, itemsize),
                 # TODO what is meminfo about?
                 meminfo=None,
...
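Editor's note on the slicing change in the last hunk: with `core_scalar=False`, a `core_ndim == 0` input now reaches the array-building path, and a slice written as `[-core_ndim:]` is wrong there, because `-0` is just `0` and the slice keeps every batch dimension. The explicit start index yields the intended empty core shape:

    batch_shape = (5, 3)  # unpacked shape of an input whose core has 0 dims
    core_ndim = 0

    print(batch_shape[-core_ndim:])                    # (5, 3): -0 == 0, keeps everything
    print(batch_shape[len(batch_shape) - core_ndim:])  # (): the correct empty core shape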
@@ -2,10 +2,12 @@ import numpy as np
 import pytest

 from pytensor import function
-from pytensor.tensor import lvector, tensor, tensor3
+from pytensor.graph import Apply
+from pytensor.scalar import ScalarOp
+from pytensor.tensor import TensorVariable, lvector, tensor, tensor3, vector
 from pytensor.tensor.basic import Alloc, ARange, constant
 from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
-from pytensor.tensor.elemwise import DimShuffle
+from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.nlinalg import SVD, Det
 from pytensor.tensor.slinalg import Cholesky, cholesky
 from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
@@ -90,3 +92,39 @@ def test_blockwise_scalar_dimshuffle():
     )
     out = blockwise_scalar_ds(x)
     compare_numba_and_py([x], [out], [np.arange(9)], eval_obj_mode=False)
+
+
+def test_blockwise_vs_elemwise_scalar_op():
+    # Regression test for https://github.com/pymc-devs/pytensor/issues/1760
+    class TestScalarOp(ScalarOp):
+        def make_node(self, x):
+            return Apply(self, [x], [x.type()])
+
+        def perform(self, node, inputs, outputs):
+            [x] = inputs
+            if isinstance(node.inputs[0], TensorVariable):
+                assert isinstance(x, np.ndarray)
+            else:
+                assert isinstance(x, np.number | float)
+            out = x + 1
+            if isinstance(node.outputs[0], TensorVariable):
+                out = np.asarray(out)
+            outputs[0][0] = out
+
+    x = vector("x")
+
+    y = Elemwise(TestScalarOp())(x)
+    with pytest.warns(
+        UserWarning,
+        match="Numba will use object mode to run TestScalarOp's perform method",
+    ):
+        fn = function([x], y, mode="NUMBA")
+    np.testing.assert_allclose(fn(np.zeros((3,))), [1, 1, 1])
+
+    z = Blockwise(TestScalarOp(), signature="()->()")(x)
+    with pytest.warns(
+        UserWarning,
+        match="Numba will use object mode to run TestScalarOp's perform method",
+    ):
+        fn = function([x], z, mode="NUMBA")
+    np.testing.assert_allclose(fn(np.zeros((3,))), [1, 1, 1])