提交 78f4d2da authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Move NonZero Op dispatcher to tensor_basic

上级 306aceb5
......@@ -24,7 +24,6 @@ from pytensor.link.utils import (
)
from pytensor.scalar.basic import ScalarType
from pytensor.sparse import SparseTensorType
from pytensor.tensor.basic import Nonzero
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot
from pytensor.tensor.type import TensorType
......@@ -457,15 +456,3 @@ def numba_funcify_IfElse(op, **kwargs):
return res[0]
return ifelse
@numba_funcify.register(Nonzero)
def numba_funcify_Nonzero(op, node, **kwargs):
@numba_njit
def nonzero(a):
result_tuple = np.nonzero(a)
if a.ndim == 1:
return result_tuple[0]
return list(result_tuple)
return nonzero
......@@ -3,7 +3,11 @@ from textwrap import indent
import numpy as np
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import create_tuple_string, numba_funcify
from pytensor.link.numba.dispatch.basic import (
create_tuple_string,
numba_funcify,
numba_njit,
)
from pytensor.link.utils import compile_function_src, unique_name_generator
from pytensor.tensor.basic import (
Alloc,
......@@ -13,6 +17,7 @@ from pytensor.tensor.basic import (
Eye,
Join,
MakeVector,
Nonzero,
ScalarFromTensor,
Split,
TensorFromScalar,
......@@ -235,3 +240,15 @@ def numba_funcify_ScalarFromTensor(op, **kwargs):
return numba_basic.to_scalar(x)
return scalar_from_tensor
@numba_funcify.register(Nonzero)
def numba_funcify_Nonzero(op, node, **kwargs):
@numba_njit
def nonzero(a):
result_tuple = np.nonzero(a)
if a.ndim == 1:
return result_tuple[0]
return list(result_tuple)
return nonzero
......@@ -718,20 +718,6 @@ def test_function_overhead(mode, benchmark):
benchmark(fn, test_x)
@pytest.mark.parametrize(
"input_data",
[np.array([1, 0, 3]), np.array([[0, 1], [2, 0]]), np.array([[0, 0], [0, 0]])],
)
def test_Nonzero(input_data):
a = pt.tensor("a", shape=(None,) * input_data.ndim)
graph_outputs = pt.nonzero(a)
compare_numba_and_py(
graph_inputs=[a], graph_outputs=graph_outputs, test_inputs=[input_data]
)
@pytest.mark.parametrize("dtype", ("float64", "float32", "mixed"))
def test_mat_vec_dot_performance(dtype, benchmark):
A = tensor("A", shape=(512, 512), dtype="float64" if dtype == "mixed" else dtype)
......
......@@ -326,3 +326,17 @@ def test_Eye(n, m, k, dtype):
g,
[n_test, m_test] if m is not None else [n_test],
)
@pytest.mark.parametrize(
"input_data",
[np.array([1, 0, 3]), np.array([[0, 1], [2, 0]]), np.array([[0, 0], [0, 0]])],
)
def test_Nonzero(input_data):
a = pt.tensor("a", shape=(None,) * input_data.ndim)
graph_outputs = pt.nonzero(a)
compare_numba_and_py(
graph_inputs=[a], graph_outputs=graph_outputs, test_inputs=[input_data]
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论