提交 2fcb9b2c authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Fix tests and fix scalar numba return types

上级 48f4db7f
......@@ -204,7 +204,7 @@ enable_slice_boxing()
def to_scalar(x):
raise NotImplementedError()
return np.asarray(x).item()
@numba.extending.overload(to_scalar)
......@@ -543,7 +543,7 @@ def {fn_name}({", ".join(input_names)}):
{index_prologue}
{indices_creation_src}
{index_body}
return z
return np.asarray(z)
"""
return subtensor_def_src
......@@ -665,7 +665,7 @@ def numba_funcify_Shape_i(op, **kwargs):
@numba_njit
def shape_i(x):
return np.shape(x)[i]
return np.asarray(np.shape(x)[i])
return shape_i
......
......@@ -9,6 +9,7 @@ import numba
import numpy as np
from numba import TypingError, types
from numba.core import cgutils
from numba.core.extending import overload
from numba.np import arrayobj
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
......@@ -174,6 +175,7 @@ def create_axis_reducer(
ndim: int,
dtype: numba.types.Type,
keepdims: bool = False,
return_scalar=False,
) -> numba.core.dispatcher.Dispatcher:
r"""Create Python function that performs a NumPy-like reduction on a given axis.
......@@ -284,6 +286,8 @@ def {reduce_elemwise_fn_name}(x):
inplace_update_statement = indent(inplace_update_statement, " " * 4 * 2)
return_expr = "res" if keepdims else "res.item()"
if not return_scalar:
return_expr = f"np.asarray({return_expr})"
reduce_elemwise_def_src = f"""
def {reduce_elemwise_fn_name}(x):
......@@ -305,7 +309,13 @@ def {reduce_elemwise_fn_name}(x):
def create_multiaxis_reducer(
scalar_op, identity, axes, ndim, dtype, input_name="input"
scalar_op,
identity,
axes,
ndim,
dtype,
input_name="input",
return_scalar=False,
):
r"""Construct a function that reduces multiple axes.
......@@ -336,6 +346,8 @@ def create_multiaxis_reducer(
The number of dimensions of the result.
dtype:
The data type of the result.
return_scalar:
If True, return a scalar, otherwise an array.
Returns
=======
......@@ -370,10 +382,17 @@ def create_multiaxis_reducer(
)
careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4)
if not return_scalar:
pre_result = "np.asarray"
post_result = ""
else:
pre_result = "np.asarray"
post_result = ".item()"
careduce_def_src = f"""
def {careduce_fn_name}({input_name}):
{careduce_assign_lines}
return np.asarray({var_name})
return {pre_result}({var_name}){post_result}
"""
careduce_fn = compile_function_src(
......@@ -383,7 +402,7 @@ def {careduce_fn_name}({input_name}):
return careduce_fn
def jit_compile_reducer(node, fn, **kwds):
def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
"""Compile Python source for reduction loops using additional optimizations.
Parameters
......@@ -400,7 +419,7 @@ def jit_compile_reducer(node, fn, **kwds):
A :func:`numba.njit`-compiled function.
"""
signature = create_numba_signature(node, reduce_to_scalar=True)
signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar)
# Eagerly compile the function using increased optimizations. This should
# help improve nested loop reductions.
......@@ -618,23 +637,58 @@ def numba_funcify_Elemwise(op, node, **kwargs):
inplace_pattern = tuple(op.inplace_pattern.items())
# numba doesn't support nested literals right now...
input_bc_patterns = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode()
output_bc_patterns = base64.encodebytes(pickle.dumps(output_bc_patterns)).decode()
output_dtypes = base64.encodebytes(pickle.dumps(output_dtypes)).decode()
inplace_pattern = base64.encodebytes(pickle.dumps(inplace_pattern)).decode()
input_bc_patterns_enc = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode()
output_bc_patterns_enc = base64.encodebytes(
pickle.dumps(output_bc_patterns)
).decode()
output_dtypes_enc = base64.encodebytes(pickle.dumps(output_dtypes)).decode()
inplace_pattern_enc = base64.encodebytes(pickle.dumps(inplace_pattern)).decode()
@numba_njit
def elemwise_wrapper(*inputs):
return _vectorized(
scalar_op_fn,
input_bc_patterns,
output_bc_patterns,
output_dtypes,
inplace_pattern,
input_bc_patterns_enc,
output_bc_patterns_enc,
output_dtypes_enc,
inplace_pattern_enc,
inputs,
)
return elemwise_wrapper
# Pure python implementation, that will be used in tests
def elemwise(*inputs):
    # Pure-Python reference implementation of the vectorized element-wise
    # op, used when this function is called outside of jitted code (e.g.
    # in tests).  NOTE(review): this is a closure — it reads
    # `input_bc_patterns`, `output_bc_patterns`, `output_dtypes` and
    # `scalar_op_fn` from the enclosing function's scope; confirm against
    # the full file.
    inputs = [np.asarray(input) for input in inputs]
    inputs_bc = np.broadcast_arrays(*inputs)
    shape = inputs[0].shape
    # Reject length-1 axes that are broadcast against a longer axis when
    # the op's broadcast pattern does not allow it.
    for input, bc in zip(inputs, input_bc_patterns):
        for length, allow_bc, iter_length in zip(input.shape, bc, shape):
            if length == 1 and shape and iter_length != 1 and not allow_bc:
                raise ValueError("Broadcast not allowed.")

    # Allocate one output array per declared output dtype.
    outputs = []
    for dtype in output_dtypes:
        outputs.append(np.empty(shape, dtype=dtype))

    # Apply the scalar function element by element; a single (non-tuple)
    # result is normalized to a 1-tuple so multi-output ops share a path.
    for idx in np.ndindex(shape):
        vals = [input[idx] for input in inputs_bc]
        outs = scalar_op_fn(*vals)
        if not isinstance(outs, tuple):
            outs = (outs,)
        for out, out_val in zip(outputs, outs):
            out[idx] = out_val

    # Sum each output over its broadcast axes (keepdims=True preserves
    # the rank of the result).
    outputs_summed = []
    for output, bc in zip(outputs, output_bc_patterns):
        axes = tuple(np.nonzero(bc)[0])
        outputs_summed.append(output.sum(axes, keepdims=True))
    if len(outputs_summed) != 1:
        return tuple(outputs_summed)
    return outputs_summed[0]
@overload(elemwise)
def ov_elemwise(*inputs):
    # Numba overload: when `elemwise` is called from jitted code,
    # dispatch to the compiled `elemwise_wrapper` (defined in the
    # enclosing scope) instead of the pure-Python implementation.
    return elemwise_wrapper
return elemwise
@numba_funcify.register(Sum)
......@@ -643,7 +697,7 @@ def numba_funcify_Sum(op, node, **kwargs):
if axes is None:
axes = list(range(node.inputs[0].ndim))
axes = list(axes)
axes = tuple(axes)
ndim_input = node.inputs[0].ndim
......@@ -658,15 +712,16 @@ def numba_funcify_Sum(op, node, **kwargs):
@numba_njit(fastmath=True)
def impl_sum(array):
# TODO The accumulation itself should happen in acc_dtype...
return np.asarray(array.sum()).astype(np_acc_dtype)
return np.asarray(array.sum(), dtype=np_acc_dtype)
else:
elif len(axes) == 0:
@numba_njit(fastmath=True)
def impl_sum(array):
# TODO The accumulation itself should happen in acc_dtype...
return array.sum(axes).astype(np_acc_dtype)
return array
else:
impl_sum = numba_funcify_CAReduce(op, node, **kwargs)
return impl_sum
......@@ -705,7 +760,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
input_name=input_name,
)
careduce_fn = jit_compile_reducer(node, careduce_py_fn)
careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False)
return careduce_fn
......@@ -888,7 +943,12 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
if axis is not None:
axis = normalize_axis_index(axis, x_at.ndim)
reduce_max_py = create_axis_reducer(
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
scalar_maximum,
-np.inf,
axis,
x_at.ndim,
x_dtype,
keepdims=True,
)
reduce_sum_py = create_axis_reducer(
add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
......@@ -935,10 +995,17 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
reduce_max_py_fn = create_multiaxis_reducer(
scalar_maximum, -np.inf, axes, x_ndim, x_dtype
scalar_maximum,
-np.inf,
axes,
x_ndim,
x_dtype,
return_scalar=False,
)
reduce_max = jit_compile_reducer(
Apply(node.op, node.inputs, [node.outputs[0].clone()]), reduce_max_py_fn
Apply(node.op, node.inputs, [node.outputs[0].clone()]),
reduce_max_py_fn,
reduce_to_scalar=False,
)
reduced_x_ndim = x_ndim - len(axes) + 1
......
......@@ -117,19 +117,6 @@ def make_loop_call(
# context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape)
# Lower the code of the scalar function so that we can use it in the inner loop
# Caching is set to false to avoid a numba bug TODO ref?
inner_func = context.compile_subroutine(
builder,
# I don't quite understand why we need to access `dispatcher` here.
# The object does seem to be a dispatcher already? But it is missing
# attributes...
scalar_func.dispatcher,
scalar_signature,
caching=False,
)
inner = inner_func.fndesc
# Extract shape and stride information from the array.
# For later use in the loop body to do the indexing
def extract_array(aryty, obj):
......@@ -191,14 +178,15 @@ def make_loop_call(
# val.set_metadata("noalias", output_scope_set)
input_vals.append(val)
# Call scalar function
output_values = context.call_internal(
builder,
inner,
scalar_signature,
input_vals,
)
if isinstance(scalar_signature.return_type, types.Tuple):
inner_codegen = context.get_function(scalar_func, scalar_signature)
if isinstance(
scalar_signature.args[0], (types.StarArgTuple, types.StarArgUniTuple)
):
input_vals = [context.make_tuple(builder, scalar_signature.args[0], input_vals)]
output_values = inner_codegen(builder, input_vals)
if isinstance(scalar_signature.return_type, (types.Tuple, types.UniTuple)):
output_values = cgutils.unpack_tuple(builder, output_values)
else:
output_values = [output_values]
......
......@@ -364,6 +364,7 @@ def numba_funcify_BroadcastTo(op, node, **kwargs):
lambda _: 0, len(node.inputs) - 1
)
# TODO broadcastable checks
@numba_basic.numba_njit
def broadcast_to(x, *shape):
scalars_shape = create_zeros_tuple()
......
......@@ -38,6 +38,9 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
# TODO: Do we need to cache these functions so that we don't end up
# compiling the same Numba function over and over again?
if not hasattr(op, "nfunc_spec"):
return generate_fallback_impl(op, node, **kwargs)
scalar_func_path = op.nfunc_spec[0]
scalar_func_numba = None
......
......@@ -17,7 +17,11 @@ from pytensor.tensor.type import TensorType
def idx_to_str(
array_name: str, offset: int, size: Optional[str] = None, idx_symbol: str = "i"
array_name: str,
offset: int,
size: Optional[str] = None,
idx_symbol: str = "i",
allow_scalar=False,
) -> str:
if offset < 0:
indices = f"{idx_symbol} + {array_name}.shape[0] - {offset}"
......@@ -32,7 +36,10 @@ def idx_to_str(
# compensate for this poor `Op`/rewrite design and implementation.
indices = f"({indices}) % {size}"
return f"{array_name}[{indices}]"
if allow_scalar:
return f"{array_name}[{indices}]"
else:
return f"np.asarray({array_name}[{indices}])"
@overload(range)
......@@ -115,7 +122,9 @@ def numba_funcify_Scan(op, node, **kwargs):
indexed_inner_in_str = (
storage_name
if tap_offset is None
else idx_to_str(storage_name, tap_offset, size=storage_size_var)
else idx_to_str(
storage_name, tap_offset, size=storage_size_var, allow_scalar=False
)
)
inner_in_exprs.append(indexed_inner_in_str)
......@@ -232,7 +241,12 @@ def numba_funcify_Scan(op, node, **kwargs):
)
for out_tap in output_taps:
inner_out_to_outer_in_stmts.append(
idx_to_str(storage_name, out_tap, size=storage_size_name)
idx_to_str(
storage_name,
out_tap,
size=storage_size_name,
allow_scalar=True,
)
)
add_output_storage_post_proc_stmt(
......@@ -269,7 +283,7 @@ def numba_funcify_Scan(op, node, **kwargs):
storage_size_name = f"{outer_in_name}_len"
inner_out_to_outer_in_stmts.append(
idx_to_str(storage_name, 0, size=storage_size_name)
idx_to_str(storage_name, 0, size=storage_size_name, allow_scalar=True)
)
add_output_storage_post_proc_stmt(storage_name, (0,), storage_size_name)
......
......@@ -27,9 +27,9 @@ class NumbaLinker(JITLinker):
return numba_funcify(fgraph, **kwargs)
def jit_compile(self, fn):
import numba
from pytensor.link.numba.dispatch.basic import numba_njit
jitted_fn = numba.njit(fn)
jitted_fn = numba_njit(fn)
return jitted_fn
def create_thunk_inputs(self, storage_map):
......
......@@ -27,6 +27,7 @@ from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_typify
from pytensor.link.numba.linker import NumbaLinker
from pytensor.raise_op import assert_op
from pytensor.scalar.basic import ScalarOp, as_scalar
from pytensor.tensor import blas
from pytensor.tensor import subtensor as at_subtensor
from pytensor.tensor.elemwise import Elemwise
......@@ -63,6 +64,33 @@ class MySingleOut(Op):
outputs[0][0] = res
class ScalarMyMultiOut(ScalarOp):
    """Test helper: a scalar `Op` with two inputs and two outputs.

    Each output is twice the corresponding input.
    """

    # Input/output arity read by the Op machinery.
    nin = 2
    nout = 2

    @staticmethod
    def impl(a, b):
        # Pure-Python implementation: double each argument.
        res1 = 2 * a
        res2 = 2 * b
        return [res1, res2]

    def make_node(self, a, b):
        # Coerce both inputs to scalar variables and build an Apply node
        # with one output per input, preserving each input's type.
        a = as_scalar(a)
        b = as_scalar(b)
        return Apply(self, [a, b], [a.type(), b.type()])

    def perform(self, node, inputs, outputs):
        # Evaluate `impl` and write the results into the per-output
        # storage cells.
        res1, res2 = self.impl(inputs[0], inputs[1])
        outputs[0][0] = res1
        outputs[1][0] = res2
# Module-level Elemwise instance used by the tests.  The `ufunc`
# attribute mimics a NumPy ufunc by exposing `nin`/`nout` on the plain
# `impl` function — NOTE(review): presumably the Elemwise dispatch reads
# these attributes; confirm against the Elemwise implementation.
scalar_my_multi_out = Elemwise(ScalarMyMultiOut())
scalar_my_multi_out.ufunc = ScalarMyMultiOut.impl
scalar_my_multi_out.ufunc.nin = 2
scalar_my_multi_out.ufunc.nout = 2
class MyMultiOut(Op):
nin = 2
nout = 2
......@@ -86,7 +114,6 @@ my_multi_out = Elemwise(MyMultiOut())
my_multi_out.ufunc = MyMultiOut.impl
my_multi_out.ufunc.nin = 2
my_multi_out.ufunc.nout = 2
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts)
......@@ -988,8 +1015,8 @@ def test_config_options_parallel():
x = at.dvector()
with config.change_flags(numba__vectorize_target="parallel"):
pytensor_numba_fn = function([x], x * 2, mode=numba_mode)
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
pytensor_numba_fn = function([x], at.sum(x), mode=numba_mode)
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert numba_mul_fn.targetoptions["parallel"] is True
......@@ -997,8 +1024,9 @@ def test_config_options_fastmath():
x = at.dvector()
with config.change_flags(numba__fastmath=True):
pytensor_numba_fn = function([x], x * 2, mode=numba_mode)
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
pytensor_numba_fn = function([x], at.sum(x), mode=numba_mode)
print(list(pytensor_numba_fn.vm.jit_fn.py_func.__globals__.keys()))
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert numba_mul_fn.targetoptions["fastmath"] is True
......@@ -1006,16 +1034,14 @@ def test_config_options_cached():
x = at.dvector()
with config.change_flags(numba__cache=True):
pytensor_numba_fn = function([x], x * 2, mode=numba_mode)
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
assert not isinstance(
numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache
)
pytensor_numba_fn = function([x], at.sum(x), mode=numba_mode)
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert not isinstance(numba_mul_fn._cache, numba.core.caching.NullCache)
with config.change_flags(numba__cache=False):
pytensor_numba_fn = function([x], x * 2, mode=numba_mode)
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["mul"]
assert isinstance(numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache)
pytensor_numba_fn = function([x], at.sum(x), mode=numba_mode)
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert isinstance(numba_mul_fn._cache, numba.core.caching.NullCache)
def test_scalar_return_value_conversion():
......
......@@ -16,7 +16,7 @@ from pytensor.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZero
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from tests.link.numba.test_basic import (
compare_numba_and_py,
my_multi_out,
scalar_my_multi_out,
set_test_value,
)
......@@ -99,8 +99,8 @@ rng = np.random.default_rng(42849)
rng.standard_normal(100).astype(config.floatX),
rng.standard_normal(100).astype(config.floatX),
],
lambda x, y: my_multi_out(x, y),
NotImplementedError,
lambda x, y: scalar_my_multi_out(x, y),
None,
),
],
)
......
......@@ -32,6 +32,7 @@ def test_Bartlett(val):
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
assert_fn=lambda x, y: np.testing.assert_allclose(x, y, atol=1e-15),
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论