提交 6176f28c authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Remove unhelpful assert error messages in `test_math_opt.TestLocalErf(c)`

上级 4fd52e4e
......@@ -2751,17 +2751,11 @@ class TestLocalErf:
x = vector()
f = function([x], 1 + erf(x), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [
mul,
erfc,
], f.maker.fgraph.toposort()
assert [n.op for n in f.maker.fgraph.toposort()] == [mul, erfc]
f(val)
f = function([x], erf(x) + 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [
mul,
erfc,
], f.maker.fgraph.toposort()
assert [n.op for n in f.maker.fgraph.toposort()] == [mul, erfc]
f(val)
f = function([x], erf(x) + 2, mode=self.mode)
......@@ -2777,29 +2771,23 @@ class TestLocalErf:
x = vector()
f = function([x], 1 - erf(x), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [
erfc
], f.maker.fgraph.toposort()
assert [n.op for n in f.maker.fgraph.toposort()] == [erfc]
f(val)
f = function([x], 1 + (-erf(x)), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [
erfc
], f.maker.fgraph.toposort()
assert [n.op for n in f.maker.fgraph.toposort()] == [erfc]
f = function([x], (-erf(x)) + 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [
erfc
], f.maker.fgraph.toposort()
assert [n.op for n in f.maker.fgraph.toposort()] == [erfc]
f = function([x], 2 - erf(x), mode=self.mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2, f.maker.fgraph.toposort()
assert topo[0].op == erf, f.maker.fgraph.toposort()
assert isinstance(topo[1].op, Elemwise), f.maker.fgraph.toposort()
assert len(topo) == 2
assert topo[0].op == erf
assert isinstance(topo[1].op, Elemwise)
assert isinstance(topo[1].op.scalar_op, aes.Add) or isinstance(
topo[1].op.scalar_op, aes.Sub
), f.maker.fgraph.toposort()
)
def test_local_erf_minus_one(self):
val = np.asarray([-30, -3, -2, -1, 0, 1, 2, 3, 30], dtype=config.floatX)
......@@ -2847,22 +2835,18 @@ class TestLocalErfc:
x = vector("x")
f = function([x], 1 - erfc(x), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [
erf
], f.maker.fgraph.toposort()
assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
f(val)
f = function([x], (-erfc(x)) + 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [
erf
], f.maker.fgraph.toposort()
assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
f = function([x], 2 - erfc(x), mode=self.mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2, f.maker.fgraph.toposort()
assert topo[0].op == erfc, f.maker.fgraph.toposort()
assert isinstance(topo[1].op, Elemwise), f.maker.fgraph.toposort()
assert isinstance(topo[1].op.scalar_op, aes.Sub), f.maker.fgraph.toposort()
assert len(topo) == 2
assert topo[0].op == erfc
assert isinstance(topo[1].op, Elemwise)
assert isinstance(topo[1].op.scalar_op, aes.Sub)
def test_local_erf_neg_minus_one(self):
# test opt: (-1)+erfc(-x)=>erf(x)
......@@ -2870,20 +2854,14 @@ class TestLocalErfc:
x = vector("x")
f = function([x], -1 + erfc(-x), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [
erf
], f.maker.fgraph.toposort()
assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
f(val)
f = function([x], erfc(-x) - 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [
erf
], f.maker.fgraph.toposort()
assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
f = function([x], erfc(-x) + (-1), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [
erf
], f.maker.fgraph.toposort()
assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
@pytest.mark.xfail()
def test_local_log_erfc(self):
......@@ -2902,17 +2880,17 @@ class TestLocalErfc:
mode_fusion.check_isfinite = False
f = function([x], log(erfc(x)), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 23, len(f.maker.fgraph.apply_nodes)
assert len(f.maker.fgraph.apply_nodes) == 23
assert f.maker.fgraph.outputs[0].dtype == config.floatX
assert all(np.isfinite(f(val)))
f = function([x], log(erfc(-x)), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 24, len(f.maker.fgraph.apply_nodes)
assert len(f.maker.fgraph.apply_nodes) == 24
assert f.maker.fgraph.outputs[0].dtype == config.floatX
assert all(np.isfinite(f(-val)))
f = function([x], log(erfc(x)), mode=mode_fusion)
assert len(f.maker.fgraph.apply_nodes) == 1, len(f.maker.fgraph.apply_nodes)
assert len(f.maker.fgraph.apply_nodes) == 1
assert f.maker.fgraph.outputs[0].dtype == config.floatX
assert (
len(
......@@ -2921,10 +2899,6 @@ class TestLocalErfc:
.op.scalar_op.fgraph.apply_nodes
)
== 22
), len(
f.maker.fgraph.toposort()[0]
.fgraph.toposort()[0]
.op.scalar_op.fgraph.apply_nodes
)
# TODO: fix this problem
assert not (
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论