Unverified 提交 b75c18fe authored 作者: Abhinav Khot's avatar Abhinav Khot 提交者: GitHub

Add numba overload for Nonzero (#1289)

* Add numba overload for Nonzero * added numba backend and testsfor Nonzero * Added numba backend for Nonzero * Modified the tests and the dispatch for efficiency
上级 39704d10
...@@ -33,6 +33,7 @@ from pytensor.link.utils import ( ...@@ -33,6 +33,7 @@ from pytensor.link.utils import (
from pytensor.scalar.basic import ScalarType from pytensor.scalar.basic import ScalarType
from pytensor.scalar.math import Softplus from pytensor.scalar.math import Softplus
from pytensor.sparse import SparseTensorType from pytensor.sparse import SparseTensorType
from pytensor.tensor.basic import Nonzero
from pytensor.tensor.blas import BatchedDot from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot from pytensor.tensor.math import Dot
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
...@@ -657,3 +658,15 @@ def numba_funcify_IfElse(op, **kwargs): ...@@ -657,3 +658,15 @@ def numba_funcify_IfElse(op, **kwargs):
return res[0] return res[0]
return ifelse return ifelse
@numba_funcify.register(Nonzero)
def numba_funcify_Nonzero(op, node, **kwargs):
@numba_njit
def nonzero(a):
result_tuple = np.nonzero(a)
if a.ndim == 1:
return result_tuple[0]
return list(result_tuple)
return nonzero
...@@ -293,7 +293,6 @@ def compare_numba_and_py( ...@@ -293,7 +293,6 @@ def compare_numba_and_py(
) )
test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs
numba_res = pytensor_numba_fn(*test_inputs_copy) numba_res = pytensor_numba_fn(*test_inputs_copy)
if isinstance(graph_outputs, tuple | list): if isinstance(graph_outputs, tuple | list):
for j, p in zip(numba_res, py_res, strict=True): for j, p in zip(numba_res, py_res, strict=True):
assert_fn(j, p) assert_fn(j, p)
...@@ -899,3 +898,17 @@ def test_function_overhead(mode, benchmark): ...@@ -899,3 +898,17 @@ def test_function_overhead(mode, benchmark):
assert np.sum(fn(test_x)) == 1000 assert np.sum(fn(test_x)) == 1000
benchmark(fn, test_x) benchmark(fn, test_x)
@pytest.mark.parametrize(
"input_data",
[np.array([1, 0, 3]), np.array([[0, 1], [2, 0]]), np.array([[0, 0], [0, 0]])],
)
def test_Nonzero(input_data):
a = pt.tensor("a", shape=(None,) * input_data.ndim)
graph_outputs = pt.nonzero(a)
compare_numba_and_py(
graph_inputs=[a], graph_outputs=graph_outputs, test_inputs=[input_data]
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论