提交 78f4d2da authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Move NonZero Op dispatcher to tensor_basic

上级 306aceb5
...@@ -24,7 +24,6 @@ from pytensor.link.utils import ( ...@@ -24,7 +24,6 @@ from pytensor.link.utils import (
) )
from pytensor.scalar.basic import ScalarType from pytensor.scalar.basic import ScalarType
from pytensor.sparse import SparseTensorType from pytensor.sparse import SparseTensorType
from pytensor.tensor.basic import Nonzero
from pytensor.tensor.blas import BatchedDot from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot from pytensor.tensor.math import Dot
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
...@@ -457,15 +456,3 @@ def numba_funcify_IfElse(op, **kwargs): ...@@ -457,15 +456,3 @@ def numba_funcify_IfElse(op, **kwargs):
return res[0] return res[0]
return ifelse return ifelse
@numba_funcify.register(Nonzero)
def numba_funcify_Nonzero(op, node, **kwargs):
    """Build the Numba implementation of the ``Nonzero`` op.

    Mirrors ``np.nonzero``: a 1-d input yields a single index array,
    while a higher-rank input yields one index array per dimension,
    returned as a list.
    """

    @numba_njit
    def nonzero(a):
        indices = np.nonzero(a)
        if a.ndim != 1:
            return list(indices)
        return indices[0]

    return nonzero
...@@ -3,7 +3,11 @@ from textwrap import indent ...@@ -3,7 +3,11 @@ from textwrap import indent
import numpy as np import numpy as np
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import create_tuple_string, numba_funcify from pytensor.link.numba.dispatch.basic import (
create_tuple_string,
numba_funcify,
numba_njit,
)
from pytensor.link.utils import compile_function_src, unique_name_generator from pytensor.link.utils import compile_function_src, unique_name_generator
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
...@@ -13,6 +17,7 @@ from pytensor.tensor.basic import ( ...@@ -13,6 +17,7 @@ from pytensor.tensor.basic import (
Eye, Eye,
Join, Join,
MakeVector, MakeVector,
Nonzero,
ScalarFromTensor, ScalarFromTensor,
Split, Split,
TensorFromScalar, TensorFromScalar,
...@@ -235,3 +240,15 @@ def numba_funcify_ScalarFromTensor(op, **kwargs): ...@@ -235,3 +240,15 @@ def numba_funcify_ScalarFromTensor(op, **kwargs):
return numba_basic.to_scalar(x) return numba_basic.to_scalar(x)
return scalar_from_tensor return scalar_from_tensor
@numba_funcify.register(Nonzero)
def numba_funcify_Nonzero(op, node, **kwargs):
    """Register the Numba kernel for the ``Nonzero`` op.

    Follows ``np.nonzero`` semantics: for a vector input the single
    index array is returned directly; otherwise the per-dimension
    index arrays are returned as a list.
    """

    @numba_njit
    def nonzero(a):
        index_arrays = np.nonzero(a)
        if a.ndim == 1:
            return index_arrays[0]
        return list(index_arrays)

    return nonzero
...@@ -718,20 +718,6 @@ def test_function_overhead(mode, benchmark): ...@@ -718,20 +718,6 @@ def test_function_overhead(mode, benchmark):
benchmark(fn, test_x) benchmark(fn, test_x)
@pytest.mark.parametrize(
    "input_data",
    [np.array([1, 0, 3]), np.array([[0, 1], [2, 0]]), np.array([[0, 0], [0, 0]])],
)
def test_Nonzero(input_data):
    """Check the Numba backend matches Python/NumPy for ``pt.nonzero``."""
    symbolic_input = pt.tensor("a", shape=(None,) * input_data.ndim)
    compare_numba_and_py(
        graph_inputs=[symbolic_input],
        graph_outputs=pt.nonzero(symbolic_input),
        test_inputs=[input_data],
    )
@pytest.mark.parametrize("dtype", ("float64", "float32", "mixed")) @pytest.mark.parametrize("dtype", ("float64", "float32", "mixed"))
def test_mat_vec_dot_performance(dtype, benchmark): def test_mat_vec_dot_performance(dtype, benchmark):
A = tensor("A", shape=(512, 512), dtype="float64" if dtype == "mixed" else dtype) A = tensor("A", shape=(512, 512), dtype="float64" if dtype == "mixed" else dtype)
......
...@@ -326,3 +326,17 @@ def test_Eye(n, m, k, dtype): ...@@ -326,3 +326,17 @@ def test_Eye(n, m, k, dtype):
g, g,
[n_test, m_test] if m is not None else [n_test], [n_test, m_test] if m is not None else [n_test],
) )
@pytest.mark.parametrize(
    "input_data",
    [np.array([1, 0, 3]), np.array([[0, 1], [2, 0]]), np.array([[0, 0], [0, 0]])],
)
def test_Nonzero(input_data):
    """Compare the compiled Numba graph for ``pt.nonzero`` against the
    pure-Python reference on 1-d, 2-d, and all-zero inputs."""
    tensor_var = pt.tensor("a", shape=(None,) * input_data.ndim)
    outputs = pt.nonzero(tensor_var)
    compare_numba_and_py(
        graph_inputs=[tensor_var],
        graph_outputs=outputs,
        test_inputs=[input_data],
    )
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论