Unverified 提交 b75c18fe authored 作者: Abhinav Khot's avatar Abhinav Khot 提交者: GitHub

Add numba overload for Nonzero (#1289)

* Add numba overload for Nonzero * added numba backend and testsfor Nonzero * Added numba backend for Nonzero * Modified the tests and the dispatch for efficiency
上级 39704d10
......@@ -33,6 +33,7 @@ from pytensor.link.utils import (
from pytensor.scalar.basic import ScalarType
from pytensor.scalar.math import Softplus
from pytensor.sparse import SparseTensorType
from pytensor.tensor.basic import Nonzero
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
......@@ -657,3 +658,15 @@ def numba_funcify_IfElse(op, **kwargs):
return res[0]
return ifelse
@numba_funcify.register(Nonzero)
def numba_funcify_Nonzero(op, node, **kwargs):
@numba_njit
def nonzero(a):
result_tuple = np.nonzero(a)
if a.ndim == 1:
return result_tuple[0]
return list(result_tuple)
return nonzero
......@@ -293,7 +293,6 @@ def compare_numba_and_py(
)
test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs
numba_res = pytensor_numba_fn(*test_inputs_copy)
if isinstance(graph_outputs, tuple | list):
for j, p in zip(numba_res, py_res, strict=True):
assert_fn(j, p)
......@@ -899,3 +898,17 @@ def test_function_overhead(mode, benchmark):
assert np.sum(fn(test_x)) == 1000
benchmark(fn, test_x)
@pytest.mark.parametrize(
"input_data",
[np.array([1, 0, 3]), np.array([[0, 1], [2, 0]]), np.array([[0, 0], [0, 0]])],
)
def test_Nonzero(input_data):
a = pt.tensor("a", shape=(None,) * input_data.ndim)
graph_outputs = pt.nonzero(a)
compare_numba_and_py(
graph_inputs=[a], graph_outputs=graph_outputs, test_inputs=[input_data]
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论