提交 f85b63ca authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove stale Assert tests

These tests were covering things that don't exist anymore. params in python perform method of Ops, or misbehavior of an Op not respecting the signature
上级 3f457d07
......@@ -233,37 +233,23 @@ def generate_fallback_impl(op, node, storage_map=None, **kwargs):
node.dprint(depth=5, print_type=True)
n_outputs = len(node.outputs)
single_out = n_outputs == 1
if n_outputs > 1:
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
else:
if single_out:
ret_sig = get_numba_type(node.outputs[0].type)
output_types = tuple(out.type for out in node.outputs)
def py_perform(inputs):
outputs = [[None] for i in range(n_outputs)]
op.perform(node, inputs, outputs)
return outputs
if n_outputs == 1:
def py_perform_return(inputs):
return output_types[0].filter(py_perform(inputs)[0][0])
else:
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
def py_perform_return(inputs):
# zip strict not specified because we are in a hot loop
return tuple(
out_type.filter(out[0])
for out_type, out in zip(output_types, py_perform(inputs))
)
def py_perform(inputs):
output_storage = [[None] for _i in range(n_outputs)]
op.perform(node, inputs, output_storage)
outputs = tuple(o[0] for o in output_storage)
return outputs[0] if single_out else outputs
@numba_njit
def perform(*inputs):
with numba.objmode(ret=ret_sig):
ret = py_perform_return(inputs)
ret = py_perform(inputs)
return ret
return perform
......
......@@ -26,7 +26,6 @@ from pytensor.graph.type import Type
from pytensor.ifelse import ifelse
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.linker import NumbaLinker
from pytensor.raise_op import assert_op
from pytensor.scalar.basic import ScalarOp, as_scalar
from pytensor.tensor.elemwise import Elemwise
......@@ -372,32 +371,6 @@ def test_perform(inputs, op, exc):
)
def test_perform_params():
"""This tests for `Op.perform` implementations that require the `params` arguments."""
x = pt.vector(shape=(2,))
x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
out = assert_op(x, np.array(True))
compare_numba_and_py([x], out, [x_test_value])
def test_perform_type_convert():
"""This tests the use of `Type.filter` in `objmode`.
The `Op.perform` takes a single input that it returns as-is, but it gets a
native scalar and it's supposed to return an `np.ndarray`.
"""
x = pt.vector()
x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
out = assert_op(x.sum(), np.array(True))
compare_numba_and_py([x], out, [x_test_value])
def test_shared():
a = shared(np.array([1, 2, 3], dtype=config.floatX))
......
......@@ -5,6 +5,7 @@ import pytest
import pytensor.tensor as pt
from pytensor import config
from pytensor.raise_op import assert_op
from pytensor.tensor import extra_ops
from tests.link.numba.test_basic import compare_numba_and_py
......@@ -383,3 +384,12 @@ def test_Searchsorted(a, v, side, sorter, exc):
g,
[test_a, test_v] if sorter is None else [test_a, test_v, test_sorter],
)
def test_check_and_raise():
x = pt.vector()
x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
out = assert_op(x.sum(), np.array(True))
compare_numba_and_py([x], out, [x_test_value])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论