提交 2fcb9b2c authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Fix tests and fix scalar numba return types

上级 48f4db7f
...@@ -204,7 +204,7 @@ enable_slice_boxing() ...@@ -204,7 +204,7 @@ enable_slice_boxing()
def to_scalar(x): def to_scalar(x):
raise NotImplementedError() return np.asarray(x).item()
@numba.extending.overload(to_scalar) @numba.extending.overload(to_scalar)
...@@ -543,7 +543,7 @@ def {fn_name}({", ".join(input_names)}): ...@@ -543,7 +543,7 @@ def {fn_name}({", ".join(input_names)}):
{index_prologue} {index_prologue}
{indices_creation_src} {indices_creation_src}
{index_body} {index_body}
return z return np.asarray(z)
""" """
return subtensor_def_src return subtensor_def_src
...@@ -665,7 +665,7 @@ def numba_funcify_Shape_i(op, **kwargs): ...@@ -665,7 +665,7 @@ def numba_funcify_Shape_i(op, **kwargs):
@numba_njit @numba_njit
def shape_i(x): def shape_i(x):
return np.shape(x)[i] return np.asarray(np.shape(x)[i])
return shape_i return shape_i
......
...@@ -9,6 +9,7 @@ import numba ...@@ -9,6 +9,7 @@ import numba
import numpy as np import numpy as np
from numba import TypingError, types from numba import TypingError, types
from numba.core import cgutils from numba.core import cgutils
from numba.core.extending import overload
from numba.np import arrayobj from numba.np import arrayobj
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
...@@ -174,6 +175,7 @@ def create_axis_reducer( ...@@ -174,6 +175,7 @@ def create_axis_reducer(
ndim: int, ndim: int,
dtype: numba.types.Type, dtype: numba.types.Type,
keepdims: bool = False, keepdims: bool = False,
return_scalar=False,
) -> numba.core.dispatcher.Dispatcher: ) -> numba.core.dispatcher.Dispatcher:
r"""Create Python function that performs a NumPy-like reduction on a given axis. r"""Create Python function that performs a NumPy-like reduction on a given axis.
...@@ -284,6 +286,8 @@ def {reduce_elemwise_fn_name}(x): ...@@ -284,6 +286,8 @@ def {reduce_elemwise_fn_name}(x):
inplace_update_statement = indent(inplace_update_statement, " " * 4 * 2) inplace_update_statement = indent(inplace_update_statement, " " * 4 * 2)
return_expr = "res" if keepdims else "res.item()" return_expr = "res" if keepdims else "res.item()"
if not return_scalar:
return_expr = f"np.asarray({return_expr})"
reduce_elemwise_def_src = f""" reduce_elemwise_def_src = f"""
def {reduce_elemwise_fn_name}(x): def {reduce_elemwise_fn_name}(x):
...@@ -305,7 +309,13 @@ def {reduce_elemwise_fn_name}(x): ...@@ -305,7 +309,13 @@ def {reduce_elemwise_fn_name}(x):
def create_multiaxis_reducer( def create_multiaxis_reducer(
scalar_op, identity, axes, ndim, dtype, input_name="input" scalar_op,
identity,
axes,
ndim,
dtype,
input_name="input",
return_scalar=False,
): ):
r"""Construct a function that reduces multiple axes. r"""Construct a function that reduces multiple axes.
...@@ -336,6 +346,8 @@ def create_multiaxis_reducer( ...@@ -336,6 +346,8 @@ def create_multiaxis_reducer(
The number of dimensions of the result. The number of dimensions of the result.
dtype: dtype:
The data type of the result. The data type of the result.
return_scalar:
If True, return a scalar, otherwise an array.
Returns Returns
======= =======
...@@ -370,10 +382,17 @@ def create_multiaxis_reducer( ...@@ -370,10 +382,17 @@ def create_multiaxis_reducer(
) )
careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4) careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4)
if not return_scalar:
pre_result = "np.asarray"
post_result = ""
else:
pre_result = "np.asarray"
post_result = ".item()"
careduce_def_src = f""" careduce_def_src = f"""
def {careduce_fn_name}({input_name}): def {careduce_fn_name}({input_name}):
{careduce_assign_lines} {careduce_assign_lines}
return np.asarray({var_name}) return {pre_result}({var_name}){post_result}
""" """
careduce_fn = compile_function_src( careduce_fn = compile_function_src(
...@@ -383,7 +402,7 @@ def {careduce_fn_name}({input_name}): ...@@ -383,7 +402,7 @@ def {careduce_fn_name}({input_name}):
return careduce_fn return careduce_fn
def jit_compile_reducer(node, fn, **kwds): def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
"""Compile Python source for reduction loops using additional optimizations. """Compile Python source for reduction loops using additional optimizations.
Parameters Parameters
...@@ -400,7 +419,7 @@ def jit_compile_reducer(node, fn, **kwds): ...@@ -400,7 +419,7 @@ def jit_compile_reducer(node, fn, **kwds):
A :func:`numba.njit`-compiled function. A :func:`numba.njit`-compiled function.
""" """
signature = create_numba_signature(node, reduce_to_scalar=True) signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar)
# Eagerly compile the function using increased optimizations. This should # Eagerly compile the function using increased optimizations. This should
# help improve nested loop reductions. # help improve nested loop reductions.
...@@ -618,24 +637,59 @@ def numba_funcify_Elemwise(op, node, **kwargs): ...@@ -618,24 +637,59 @@ def numba_funcify_Elemwise(op, node, **kwargs):
inplace_pattern = tuple(op.inplace_pattern.items()) inplace_pattern = tuple(op.inplace_pattern.items())
# numba doesn't support nested literals right now... # numba doesn't support nested literals right now...
input_bc_patterns = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode() input_bc_patterns_enc = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode()
output_bc_patterns = base64.encodebytes(pickle.dumps(output_bc_patterns)).decode() output_bc_patterns_enc = base64.encodebytes(
output_dtypes = base64.encodebytes(pickle.dumps(output_dtypes)).decode() pickle.dumps(output_bc_patterns)
inplace_pattern = base64.encodebytes(pickle.dumps(inplace_pattern)).decode() ).decode()
output_dtypes_enc = base64.encodebytes(pickle.dumps(output_dtypes)).decode()
inplace_pattern_enc = base64.encodebytes(pickle.dumps(inplace_pattern)).decode()
@numba_njit
def elemwise_wrapper(*inputs): def elemwise_wrapper(*inputs):
return _vectorized( return _vectorized(
scalar_op_fn, scalar_op_fn,
input_bc_patterns, input_bc_patterns_enc,
output_bc_patterns, output_bc_patterns_enc,
output_dtypes, output_dtypes_enc,
inplace_pattern, inplace_pattern_enc,
inputs, inputs,
) )
# Pure python implementation, that will be used in tests
def elemwise(*inputs):
inputs = [np.asarray(input) for input in inputs]
inputs_bc = np.broadcast_arrays(*inputs)
shape = inputs[0].shape
for input, bc in zip(inputs, input_bc_patterns):
for length, allow_bc, iter_length in zip(input.shape, bc, shape):
if length == 1 and shape and iter_length != 1 and not allow_bc:
raise ValueError("Broadcast not allowed.")
outputs = []
for dtype in output_dtypes:
outputs.append(np.empty(shape, dtype=dtype))
for idx in np.ndindex(shape):
vals = [input[idx] for input in inputs_bc]
outs = scalar_op_fn(*vals)
if not isinstance(outs, tuple):
outs = (outs,)
for out, out_val in zip(outputs, outs):
out[idx] = out_val
outputs_summed = []
for output, bc in zip(outputs, output_bc_patterns):
axes = tuple(np.nonzero(bc)[0])
outputs_summed.append(output.sum(axes, keepdims=True))
if len(outputs_summed) != 1:
return tuple(outputs_summed)
return outputs_summed[0]
@overload(elemwise)
def ov_elemwise(*inputs):
return elemwise_wrapper return elemwise_wrapper
return elemwise
@numba_funcify.register(Sum) @numba_funcify.register(Sum)
def numba_funcify_Sum(op, node, **kwargs): def numba_funcify_Sum(op, node, **kwargs):
...@@ -643,7 +697,7 @@ def numba_funcify_Sum(op, node, **kwargs): ...@@ -643,7 +697,7 @@ def numba_funcify_Sum(op, node, **kwargs):
if axes is None: if axes is None:
axes = list(range(node.inputs[0].ndim)) axes = list(range(node.inputs[0].ndim))
axes = list(axes) axes = tuple(axes)
ndim_input = node.inputs[0].ndim ndim_input = node.inputs[0].ndim
...@@ -658,15 +712,16 @@ def numba_funcify_Sum(op, node, **kwargs): ...@@ -658,15 +712,16 @@ def numba_funcify_Sum(op, node, **kwargs):
@numba_njit(fastmath=True) @numba_njit(fastmath=True)
def impl_sum(array): def impl_sum(array):
# TODO The accumulation itself should happen in acc_dtype... return np.asarray(array.sum(), dtype=np_acc_dtype)
return np.asarray(array.sum()).astype(np_acc_dtype)
else: elif len(axes) == 0:
@numba_njit(fastmath=True) @numba_njit(fastmath=True)
def impl_sum(array): def impl_sum(array):
# TODO The accumulation itself should happen in acc_dtype... return array
return array.sum(axes).astype(np_acc_dtype)
else:
impl_sum = numba_funcify_CAReduce(op, node, **kwargs)
return impl_sum return impl_sum
...@@ -705,7 +760,7 @@ def numba_funcify_CAReduce(op, node, **kwargs): ...@@ -705,7 +760,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
input_name=input_name, input_name=input_name,
) )
careduce_fn = jit_compile_reducer(node, careduce_py_fn) careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False)
return careduce_fn return careduce_fn
...@@ -888,7 +943,12 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): ...@@ -888,7 +943,12 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
if axis is not None: if axis is not None:
axis = normalize_axis_index(axis, x_at.ndim) axis = normalize_axis_index(axis, x_at.ndim)
reduce_max_py = create_axis_reducer( reduce_max_py = create_axis_reducer(
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True scalar_maximum,
-np.inf,
axis,
x_at.ndim,
x_dtype,
keepdims=True,
) )
reduce_sum_py = create_axis_reducer( reduce_sum_py = create_axis_reducer(
add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
...@@ -935,10 +995,17 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs): ...@@ -935,10 +995,17 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
keep_axes = tuple(i for i in range(x_ndim) if i not in axes) keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
reduce_max_py_fn = create_multiaxis_reducer( reduce_max_py_fn = create_multiaxis_reducer(
scalar_maximum, -np.inf, axes, x_ndim, x_dtype scalar_maximum,
-np.inf,
axes,
x_ndim,
x_dtype,
return_scalar=False,
) )
reduce_max = jit_compile_reducer( reduce_max = jit_compile_reducer(
Apply(node.op, node.inputs, [node.outputs[0].clone()]), reduce_max_py_fn Apply(node.op, node.inputs, [node.outputs[0].clone()]),
reduce_max_py_fn,
reduce_to_scalar=False,
) )
reduced_x_ndim = x_ndim - len(axes) + 1 reduced_x_ndim = x_ndim - len(axes) + 1
......
...@@ -117,19 +117,6 @@ def make_loop_call( ...@@ -117,19 +117,6 @@ def make_loop_call(
# context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape) # context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape)
# Lower the code of the scalar function so that we can use it in the inner loop
# Caching is set to false to avoid a numba bug TODO ref?
inner_func = context.compile_subroutine(
builder,
# I don't quite understand why we need to access `dispatcher` here.
# The object does seem to be a dispatcher already? But it is missing
# attributes...
scalar_func.dispatcher,
scalar_signature,
caching=False,
)
inner = inner_func.fndesc
# Extract shape and stride information from the array. # Extract shape and stride information from the array.
# For later use in the loop body to do the indexing # For later use in the loop body to do the indexing
def extract_array(aryty, obj): def extract_array(aryty, obj):
...@@ -191,14 +178,15 @@ def make_loop_call( ...@@ -191,14 +178,15 @@ def make_loop_call(
# val.set_metadata("noalias", output_scope_set) # val.set_metadata("noalias", output_scope_set)
input_vals.append(val) input_vals.append(val)
# Call scalar function inner_codegen = context.get_function(scalar_func, scalar_signature)
output_values = context.call_internal(
builder, if isinstance(
inner, scalar_signature.args[0], (types.StarArgTuple, types.StarArgUniTuple)
scalar_signature, ):
input_vals, input_vals = [context.make_tuple(builder, scalar_signature.args[0], input_vals)]
) output_values = inner_codegen(builder, input_vals)
if isinstance(scalar_signature.return_type, types.Tuple):
if isinstance(scalar_signature.return_type, (types.Tuple, types.UniTuple)):
output_values = cgutils.unpack_tuple(builder, output_values) output_values = cgutils.unpack_tuple(builder, output_values)
else: else:
output_values = [output_values] output_values = [output_values]
......
...@@ -364,6 +364,7 @@ def numba_funcify_BroadcastTo(op, node, **kwargs): ...@@ -364,6 +364,7 @@ def numba_funcify_BroadcastTo(op, node, **kwargs):
lambda _: 0, len(node.inputs) - 1 lambda _: 0, len(node.inputs) - 1
) )
# TODO broadcastable checks
@numba_basic.numba_njit @numba_basic.numba_njit
def broadcast_to(x, *shape): def broadcast_to(x, *shape):
scalars_shape = create_zeros_tuple() scalars_shape = create_zeros_tuple()
......
...@@ -38,6 +38,9 @@ def numba_funcify_ScalarOp(op, node, **kwargs): ...@@ -38,6 +38,9 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
# TODO: Do we need to cache these functions so that we don't end up # TODO: Do we need to cache these functions so that we don't end up
# compiling the same Numba function over and over again? # compiling the same Numba function over and over again?
if not hasattr(op, "nfunc_spec"):
return generate_fallback_impl(op, node, **kwargs)
scalar_func_path = op.nfunc_spec[0] scalar_func_path = op.nfunc_spec[0]
scalar_func_numba = None scalar_func_numba = None
......
...@@ -17,7 +17,11 @@ from pytensor.tensor.type import TensorType ...@@ -17,7 +17,11 @@ from pytensor.tensor.type import TensorType
def idx_to_str( def idx_to_str(
array_name: str, offset: int, size: Optional[str] = None, idx_symbol: str = "i" array_name: str,
offset: int,
size: Optional[str] = None,
idx_symbol: str = "i",
allow_scalar=False,
) -> str: ) -> str:
if offset < 0: if offset < 0:
indices = f"{idx_symbol} + {array_name}.shape[0] - {offset}" indices = f"{idx_symbol} + {array_name}.shape[0] - {offset}"
...@@ -32,7 +36,10 @@ def idx_to_str( ...@@ -32,7 +36,10 @@ def idx_to_str(
# compensate for this poor `Op`/rewrite design and implementation. # compensate for this poor `Op`/rewrite design and implementation.
indices = f"({indices}) % {size}" indices = f"({indices}) % {size}"
if allow_scalar:
return f"{array_name}[{indices}]" return f"{array_name}[{indices}]"
else:
return f"np.asarray({array_name}[{indices}])"
@overload(range) @overload(range)
...@@ -115,7 +122,9 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -115,7 +122,9 @@ def numba_funcify_Scan(op, node, **kwargs):
indexed_inner_in_str = ( indexed_inner_in_str = (
storage_name storage_name
if tap_offset is None if tap_offset is None
else idx_to_str(storage_name, tap_offset, size=storage_size_var) else idx_to_str(
storage_name, tap_offset, size=storage_size_var, allow_scalar=False
)
) )
inner_in_exprs.append(indexed_inner_in_str) inner_in_exprs.append(indexed_inner_in_str)
...@@ -232,7 +241,12 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -232,7 +241,12 @@ def numba_funcify_Scan(op, node, **kwargs):
) )
for out_tap in output_taps: for out_tap in output_taps:
inner_out_to_outer_in_stmts.append( inner_out_to_outer_in_stmts.append(
idx_to_str(storage_name, out_tap, size=storage_size_name) idx_to_str(
storage_name,
out_tap,
size=storage_size_name,
allow_scalar=True,
)
) )
add_output_storage_post_proc_stmt( add_output_storage_post_proc_stmt(
...@@ -269,7 +283,7 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -269,7 +283,7 @@ def numba_funcify_Scan(op, node, **kwargs):
storage_size_name = f"{outer_in_name}_len" storage_size_name = f"{outer_in_name}_len"
inner_out_to_outer_in_stmts.append( inner_out_to_outer_in_stmts.append(
idx_to_str(storage_name, 0, size=storage_size_name) idx_to_str(storage_name, 0, size=storage_size_name, allow_scalar=True)
) )
add_output_storage_post_proc_stmt(storage_name, (0,), storage_size_name) add_output_storage_post_proc_stmt(storage_name, (0,), storage_size_name)
......
...@@ -27,9 +27,9 @@ class NumbaLinker(JITLinker): ...@@ -27,9 +27,9 @@ class NumbaLinker(JITLinker):
return numba_funcify(fgraph, **kwargs) return numba_funcify(fgraph, **kwargs)
def jit_compile(self, fn): def jit_compile(self, fn):
import numba from pytensor.link.numba.dispatch.basic import numba_njit
jitted_fn = numba.njit(fn) jitted_fn = numba_njit(fn)
return jitted_fn return jitted_fn
def create_thunk_inputs(self, storage_map): def create_thunk_inputs(self, storage_map):
......
...@@ -27,6 +27,7 @@ from pytensor.link.numba.dispatch import basic as numba_basic ...@@ -27,6 +27,7 @@ from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_typify from pytensor.link.numba.dispatch import numba_typify
from pytensor.link.numba.linker import NumbaLinker from pytensor.link.numba.linker import NumbaLinker
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.scalar.basic import ScalarOp, as_scalar
from pytensor.tensor import blas from pytensor.tensor import blas
from pytensor.tensor import subtensor as at_subtensor from pytensor.tensor import subtensor as at_subtensor
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
...@@ -63,6 +64,33 @@ class MySingleOut(Op): ...@@ -63,6 +64,33 @@ class MySingleOut(Op):
outputs[0][0] = res outputs[0][0] = res
class ScalarMyMultiOut(ScalarOp):
nin = 2
nout = 2
@staticmethod
def impl(a, b):
res1 = 2 * a
res2 = 2 * b
return [res1, res2]
def make_node(self, a, b):
a = as_scalar(a)
b = as_scalar(b)
return Apply(self, [a, b], [a.type(), b.type()])
def perform(self, node, inputs, outputs):
res1, res2 = self.impl(inputs[0], inputs[1])
outputs[0][0] = res1
outputs[1][0] = res2
scalar_my_multi_out = Elemwise(ScalarMyMultiOut())
scalar_my_multi_out.ufunc = ScalarMyMultiOut.impl
scalar_my_multi_out.ufunc.nin = 2
scalar_my_multi_out.ufunc.nout = 2
class MyMultiOut(Op): class MyMultiOut(Op):
nin = 2 nin = 2
nout = 2 nout = 2
...@@ -86,7 +114,6 @@ my_multi_out = Elemwise(MyMultiOut()) ...@@ -86,7 +114,6 @@ my_multi_out = Elemwise(MyMultiOut())
my_multi_out.ufunc = MyMultiOut.impl my_multi_out.ufunc = MyMultiOut.impl
my_multi_out.ufunc.nin = 2 my_multi_out.ufunc.nin = 2
my_multi_out.ufunc.nout = 2 my_multi_out.ufunc.nout = 2
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts) numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts) py_mode = Mode("py", opts)
...@@ -988,8 +1015,8 @@ def test_config_options_parallel(): ...@@ -988,8 +1015,8 @@ def test_config_options_parallel():
x = at.dvector() x = at.dvector()
with config.change_flags(numba__vectorize_target="parallel"): with config.change_flags(numba__vectorize_target="parallel"):
pytensor_numba_fn = function([x], x * 2, mode=numba_mode) pytensor_numba_fn = function([x], at.sum(x), mode=numba_mode)
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["mul"] numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert numba_mul_fn.targetoptions["parallel"] is True assert numba_mul_fn.targetoptions["parallel"] is True
...@@ -997,8 +1024,9 @@ def test_config_options_fastmath(): ...@@ -997,8 +1024,9 @@ def test_config_options_fastmath():
x = at.dvector() x = at.dvector()
with config.change_flags(numba__fastmath=True): with config.change_flags(numba__fastmath=True):
pytensor_numba_fn = function([x], x * 2, mode=numba_mode) pytensor_numba_fn = function([x], at.sum(x), mode=numba_mode)
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["mul"] print(list(pytensor_numba_fn.vm.jit_fn.py_func.__globals__.keys()))
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert numba_mul_fn.targetoptions["fastmath"] is True assert numba_mul_fn.targetoptions["fastmath"] is True
...@@ -1006,16 +1034,14 @@ def test_config_options_cached(): ...@@ -1006,16 +1034,14 @@ def test_config_options_cached():
x = at.dvector() x = at.dvector()
with config.change_flags(numba__cache=True): with config.change_flags(numba__cache=True):
pytensor_numba_fn = function([x], x * 2, mode=numba_mode) pytensor_numba_fn = function([x], at.sum(x), mode=numba_mode)
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["mul"] numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert not isinstance( assert not isinstance(numba_mul_fn._cache, numba.core.caching.NullCache)
numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache
)
with config.change_flags(numba__cache=False): with config.change_flags(numba__cache=False):
pytensor_numba_fn = function([x], x * 2, mode=numba_mode) pytensor_numba_fn = function([x], at.sum(x), mode=numba_mode)
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["mul"] numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert isinstance(numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache) assert isinstance(numba_mul_fn._cache, numba.core.caching.NullCache)
def test_scalar_return_value_conversion(): def test_scalar_return_value_conversion():
......
...@@ -16,7 +16,7 @@ from pytensor.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZero ...@@ -16,7 +16,7 @@ from pytensor.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZero
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from tests.link.numba.test_basic import ( from tests.link.numba.test_basic import (
compare_numba_and_py, compare_numba_and_py,
my_multi_out, scalar_my_multi_out,
set_test_value, set_test_value,
) )
...@@ -99,8 +99,8 @@ rng = np.random.default_rng(42849) ...@@ -99,8 +99,8 @@ rng = np.random.default_rng(42849)
rng.standard_normal(100).astype(config.floatX), rng.standard_normal(100).astype(config.floatX),
rng.standard_normal(100).astype(config.floatX), rng.standard_normal(100).astype(config.floatX),
], ],
lambda x, y: my_multi_out(x, y), lambda x, y: scalar_my_multi_out(x, y),
NotImplementedError, None,
), ),
], ],
) )
......
...@@ -32,6 +32,7 @@ def test_Bartlett(val): ...@@ -32,6 +32,7 @@ def test_Bartlett(val):
for i in g_fg.inputs for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant)) if not isinstance(i, (SharedVariable, Constant))
], ],
assert_fn=lambda x, y: np.testing.assert_allclose(x, y, atol=1e-15),
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论