提交 f85b63ca authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove stale Assert tests

These tests were covering things that don't exist anymore. params in python perform method of Ops, or misbehavior of an Op not respecting the signature
上级 3f457d07
...@@ -233,37 +233,23 @@ def generate_fallback_impl(op, node, storage_map=None, **kwargs): ...@@ -233,37 +233,23 @@ def generate_fallback_impl(op, node, storage_map=None, **kwargs):
node.dprint(depth=5, print_type=True) node.dprint(depth=5, print_type=True)
n_outputs = len(node.outputs) n_outputs = len(node.outputs)
single_out = n_outputs == 1
if n_outputs > 1: if single_out:
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
else:
ret_sig = get_numba_type(node.outputs[0].type) ret_sig = get_numba_type(node.outputs[0].type)
output_types = tuple(out.type for out in node.outputs)
def py_perform(inputs):
outputs = [[None] for i in range(n_outputs)]
op.perform(node, inputs, outputs)
return outputs
if n_outputs == 1:
def py_perform_return(inputs):
return output_types[0].filter(py_perform(inputs)[0][0])
else: else:
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
def py_perform_return(inputs): def py_perform(inputs):
# zip strict not specified because we are in a hot loop output_storage = [[None] for _i in range(n_outputs)]
return tuple( op.perform(node, inputs, output_storage)
out_type.filter(out[0]) outputs = tuple(o[0] for o in output_storage)
for out_type, out in zip(output_types, py_perform(inputs)) return outputs[0] if single_out else outputs
)
@numba_njit @numba_njit
def perform(*inputs): def perform(*inputs):
with numba.objmode(ret=ret_sig): with numba.objmode(ret=ret_sig):
ret = py_perform_return(inputs) ret = py_perform(inputs)
return ret return ret
return perform return perform
......
...@@ -26,7 +26,6 @@ from pytensor.graph.type import Type ...@@ -26,7 +26,6 @@ from pytensor.graph.type import Type
from pytensor.ifelse import ifelse from pytensor.ifelse import ifelse
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.linker import NumbaLinker from pytensor.link.numba.linker import NumbaLinker
from pytensor.raise_op import assert_op
from pytensor.scalar.basic import ScalarOp, as_scalar from pytensor.scalar.basic import ScalarOp, as_scalar
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
...@@ -372,32 +371,6 @@ def test_perform(inputs, op, exc): ...@@ -372,32 +371,6 @@ def test_perform(inputs, op, exc):
) )
def test_perform_params():
"""This tests for `Op.perform` implementations that require the `params` arguments."""
x = pt.vector(shape=(2,))
x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
out = assert_op(x, np.array(True))
compare_numba_and_py([x], out, [x_test_value])
def test_perform_type_convert():
"""This tests the use of `Type.filter` in `objmode`.
The `Op.perform` takes a single input that it returns as-is, but it gets a
native scalar and it's supposed to return an `np.ndarray`.
"""
x = pt.vector()
x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
out = assert_op(x.sum(), np.array(True))
compare_numba_and_py([x], out, [x_test_value])
def test_shared(): def test_shared():
a = shared(np.array([1, 2, 3], dtype=config.floatX)) a = shared(np.array([1, 2, 3], dtype=config.floatX))
......
...@@ -5,6 +5,7 @@ import pytest ...@@ -5,6 +5,7 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config from pytensor import config
from pytensor.raise_op import assert_op
from pytensor.tensor import extra_ops from pytensor.tensor import extra_ops
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py
...@@ -383,3 +384,12 @@ def test_Searchsorted(a, v, side, sorter, exc): ...@@ -383,3 +384,12 @@ def test_Searchsorted(a, v, side, sorter, exc):
g, g,
[test_a, test_v] if sorter is None else [test_a, test_v, test_sorter], [test_a, test_v] if sorter is None else [test_a, test_v, test_sorter],
) )
def test_check_and_raise():
x = pt.vector()
x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
out = assert_op(x.sum(), np.array(True))
compare_numba_and_py([x], out, [x_test_value])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论