提交 51430183 authored 作者: Brendan Murphy's avatar Brendan Murphy 提交者: Ricardo Vieira

Fix test for neg on unsigned

Due to changes in numpy conversion rules (NEP 50), overflows are not ignored; in particular, negating a unsigned int causes an overflow error. The test for `neg` has been changed to check that this error is raised.
上级 4c8c8b6e
......@@ -23,6 +23,7 @@ from pytensor.graph.basic import Variable, ancestors, applys_between, equal_comp
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import vectorize_node
from pytensor.link.c.basic import DualLinker
from pytensor.npy_2_compat import using_numpy_2
from pytensor.printing import pprint
from pytensor.raise_op import Assert
from pytensor.tensor import blas, blas_c
......@@ -391,11 +392,20 @@ TestAbsBroadcast = makeBroadcastTester(
grad=_grad_broadcast_unary_normal,
)
# in numpy >= 2.0, negating a uint raises an error
neg_good = _good_broadcast_unary_normal.copy()
if using_numpy_2:
neg_bad = {"uint8": neg_good.pop("uint8"), "uint16": neg_good.pop("uint16")}
else:
neg_bad = None
TestNegBroadcast = makeBroadcastTester(
op=neg,
expected=lambda x: -x,
good=_good_broadcast_unary_normal,
good=neg_good,
grad=_grad_broadcast_unary_normal,
bad_compile=neg_bad,
)
TestSgnBroadcast = makeBroadcastTester(
......
......@@ -339,6 +339,7 @@ def makeTester(
good=None,
bad_build=None,
bad_runtime=None,
bad_compile=None,
grad=None,
mode=None,
grad_rtol=None,
......@@ -373,6 +374,7 @@ def makeTester(
_test_memmap = test_memmap
_check_name = check_name
_grad_eps = grad_eps
_bad_compile = bad_compile or {}
class Checker:
op = staticmethod(_op)
......@@ -382,6 +384,7 @@ def makeTester(
good = _good
bad_build = _bad_build
bad_runtime = _bad_runtime
bad_compile = _bad_compile
grad = _grad
mode = _mode
skip = skip_
......@@ -539,6 +542,24 @@ def makeTester(
# instantiated on the following bad inputs: %s"
# % (self.op, testname, node, inputs))
@config.change_flags(compute_test_value="off")
@pytest.mark.skipif(skip, reason="Skipped")
def test_bad_compile(self):
for testname, inputs in self.bad_compile.items():
inputrs = [shared(input) for input in inputs]
try:
node = safe_make_node(self.op, *inputrs)
except Exception as exc:
err_msg = (
f"Test {self.op}::{testname}: Error occurred while trying"
f" to make a node with inputs {inputs}"
)
exc.args += (err_msg,)
raise
with pytest.raises(Exception):
inplace_func([], node.outputs, mode=mode, name="test_bad_runtime")
@config.change_flags(compute_test_value="off")
@pytest.mark.skipif(skip, reason="Skipped")
def test_bad_runtime(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论