提交 6176f28c authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Remove unhelpful assert error messages in `test_math_opt.TestLocalErf(c)`

上级 4fd52e4e
...@@ -2751,17 +2751,11 @@ class TestLocalErf: ...@@ -2751,17 +2751,11 @@ class TestLocalErf:
x = vector() x = vector()
f = function([x], 1 + erf(x), mode=self.mode) f = function([x], 1 + erf(x), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [ assert [n.op for n in f.maker.fgraph.toposort()] == [mul, erfc]
mul,
erfc,
], f.maker.fgraph.toposort()
f(val) f(val)
f = function([x], erf(x) + 1, mode=self.mode) f = function([x], erf(x) + 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [ assert [n.op for n in f.maker.fgraph.toposort()] == [mul, erfc]
mul,
erfc,
], f.maker.fgraph.toposort()
f(val) f(val)
f = function([x], erf(x) + 2, mode=self.mode) f = function([x], erf(x) + 2, mode=self.mode)
...@@ -2777,29 +2771,23 @@ class TestLocalErf: ...@@ -2777,29 +2771,23 @@ class TestLocalErf:
x = vector() x = vector()
f = function([x], 1 - erf(x), mode=self.mode) f = function([x], 1 - erf(x), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [ assert [n.op for n in f.maker.fgraph.toposort()] == [erfc]
erfc
], f.maker.fgraph.toposort()
f(val) f(val)
f = function([x], 1 + (-erf(x)), mode=self.mode) f = function([x], 1 + (-erf(x)), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [ assert [n.op for n in f.maker.fgraph.toposort()] == [erfc]
erfc
], f.maker.fgraph.toposort()
f = function([x], (-erf(x)) + 1, mode=self.mode) f = function([x], (-erf(x)) + 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [ assert [n.op for n in f.maker.fgraph.toposort()] == [erfc]
erfc
], f.maker.fgraph.toposort()
f = function([x], 2 - erf(x), mode=self.mode) f = function([x], 2 - erf(x), mode=self.mode)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 2, f.maker.fgraph.toposort() assert len(topo) == 2
assert topo[0].op == erf, f.maker.fgraph.toposort() assert topo[0].op == erf
assert isinstance(topo[1].op, Elemwise), f.maker.fgraph.toposort() assert isinstance(topo[1].op, Elemwise)
assert isinstance(topo[1].op.scalar_op, aes.Add) or isinstance( assert isinstance(topo[1].op.scalar_op, aes.Add) or isinstance(
topo[1].op.scalar_op, aes.Sub topo[1].op.scalar_op, aes.Sub
), f.maker.fgraph.toposort() )
def test_local_erf_minus_one(self): def test_local_erf_minus_one(self):
val = np.asarray([-30, -3, -2, -1, 0, 1, 2, 3, 30], dtype=config.floatX) val = np.asarray([-30, -3, -2, -1, 0, 1, 2, 3, 30], dtype=config.floatX)
...@@ -2847,22 +2835,18 @@ class TestLocalErfc: ...@@ -2847,22 +2835,18 @@ class TestLocalErfc:
x = vector("x") x = vector("x")
f = function([x], 1 - erfc(x), mode=self.mode) f = function([x], 1 - erfc(x), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [ assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
erf
], f.maker.fgraph.toposort()
f(val) f(val)
f = function([x], (-erfc(x)) + 1, mode=self.mode) f = function([x], (-erfc(x)) + 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [ assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
erf
], f.maker.fgraph.toposort()
f = function([x], 2 - erfc(x), mode=self.mode) f = function([x], 2 - erfc(x), mode=self.mode)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) == 2, f.maker.fgraph.toposort() assert len(topo) == 2
assert topo[0].op == erfc, f.maker.fgraph.toposort() assert topo[0].op == erfc
assert isinstance(topo[1].op, Elemwise), f.maker.fgraph.toposort() assert isinstance(topo[1].op, Elemwise)
assert isinstance(topo[1].op.scalar_op, aes.Sub), f.maker.fgraph.toposort() assert isinstance(topo[1].op.scalar_op, aes.Sub)
def test_local_erf_neg_minus_one(self): def test_local_erf_neg_minus_one(self):
# test opt: (-1)+erfc(-x)=>erf(x) # test opt: (-1)+erfc(-x)=>erf(x)
...@@ -2870,20 +2854,14 @@ class TestLocalErfc: ...@@ -2870,20 +2854,14 @@ class TestLocalErfc:
x = vector("x") x = vector("x")
f = function([x], -1 + erfc(-x), mode=self.mode) f = function([x], -1 + erfc(-x), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [ assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
erf
], f.maker.fgraph.toposort()
f(val) f(val)
f = function([x], erfc(-x) - 1, mode=self.mode) f = function([x], erfc(-x) - 1, mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [ assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
erf
], f.maker.fgraph.toposort()
f = function([x], erfc(-x) + (-1), mode=self.mode) f = function([x], erfc(-x) + (-1), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [ assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
erf
], f.maker.fgraph.toposort()
@pytest.mark.xfail() @pytest.mark.xfail()
def test_local_log_erfc(self): def test_local_log_erfc(self):
...@@ -2902,17 +2880,17 @@ class TestLocalErfc: ...@@ -2902,17 +2880,17 @@ class TestLocalErfc:
mode_fusion.check_isfinite = False mode_fusion.check_isfinite = False
f = function([x], log(erfc(x)), mode=mode) f = function([x], log(erfc(x)), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 23, len(f.maker.fgraph.apply_nodes) assert len(f.maker.fgraph.apply_nodes) == 23
assert f.maker.fgraph.outputs[0].dtype == config.floatX assert f.maker.fgraph.outputs[0].dtype == config.floatX
assert all(np.isfinite(f(val))) assert all(np.isfinite(f(val)))
f = function([x], log(erfc(-x)), mode=mode) f = function([x], log(erfc(-x)), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 24, len(f.maker.fgraph.apply_nodes) assert len(f.maker.fgraph.apply_nodes) == 24
assert f.maker.fgraph.outputs[0].dtype == config.floatX assert f.maker.fgraph.outputs[0].dtype == config.floatX
assert all(np.isfinite(f(-val))) assert all(np.isfinite(f(-val)))
f = function([x], log(erfc(x)), mode=mode_fusion) f = function([x], log(erfc(x)), mode=mode_fusion)
assert len(f.maker.fgraph.apply_nodes) == 1, len(f.maker.fgraph.apply_nodes) assert len(f.maker.fgraph.apply_nodes) == 1
assert f.maker.fgraph.outputs[0].dtype == config.floatX assert f.maker.fgraph.outputs[0].dtype == config.floatX
assert ( assert (
len( len(
...@@ -2921,10 +2899,6 @@ class TestLocalErfc: ...@@ -2921,10 +2899,6 @@ class TestLocalErfc:
.op.scalar_op.fgraph.apply_nodes .op.scalar_op.fgraph.apply_nodes
) )
== 22 == 22
), len(
f.maker.fgraph.toposort()[0]
.fgraph.toposort()[0]
.op.scalar_op.fgraph.apply_nodes
) )
# TODO: fix this problem # TODO: fix this problem
assert not ( assert not (
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论