提交 12b6144c authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Enable `test_local_loc_erfc`

上级 67ff0800
......@@ -2869,7 +2869,6 @@ class TestLocalErfc:
f = function([x], erfc(-1.0 * x) + (-1), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
@pytest.mark.xfail()
def test_local_log_erfc(self):
val = [-30, -27, -26, -11, -10, -3, -2, -1, 0, 1, 2, 3, 10, 11, 26, 27, 28, 30]
if config.mode in ["DebugMode", "DEBUG_MODE", "FAST_COMPILE"]:
......@@ -2886,42 +2885,27 @@ class TestLocalErfc:
mode_fusion.check_isfinite = False
f = function([x], log(erfc(x)), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 23
assert len(f.maker.fgraph.apply_nodes) == 22
assert f.maker.fgraph.outputs[0].dtype == config.floatX
assert all(np.isfinite(f(val)))
f = function([x], log(erfc(-x)), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 24
assert len(f.maker.fgraph.apply_nodes) == 23
assert f.maker.fgraph.outputs[0].dtype == config.floatX
assert all(np.isfinite(f(-val)))
f = function([x], log(erfc(x)), mode=mode_fusion)
assert len(f.maker.fgraph.apply_nodes) == 1
assert f.maker.fgraph.outputs[0].dtype == config.floatX
assert (
len(
f.maker.fgraph.toposort()[0]
.fgraph.toposort()[0]
.op.scalar_op.fgraph.apply_nodes
)
== 22
)
# TODO: fix this problem
assert not (
config.floatX == "float32"
and config.mode
in [
"DebugMode",
"DEBUG_MODE",
]
), (
"The python code upcast somewhere internally "
"some value of float32 to python float for "
"part of its computation. That make that the "
"c and python code don't generate the same value. "
"You can ignore this error."
)
assert all(np.isfinite(f(val)))
assert len(f.maker.fgraph.toposort()[0].op.scalar_op.fgraph.apply_nodes) == 22
# TODO: fix this problem: The python code upcast somewhere internally
# some value of float32 to python float for part of its computation.
# That makes the c and python code generate sligtly different values
if not (
config.floatX == "float32" and config.mode in ["DebugMode", "DEBUG_MODE"]
):
assert all(np.isfinite(f(val)))
@np.errstate(divide="ignore", invalid="ignore")
def test_local_grad_log_erfc_neg(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论