提交 12b6144c authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Enable `test_local_loc_erfc`

上级 67ff0800
...@@ -2869,7 +2869,6 @@ class TestLocalErfc: ...@@ -2869,7 +2869,6 @@ class TestLocalErfc:
f = function([x], erfc(-1.0 * x) + (-1), mode=self.mode) f = function([x], erfc(-1.0 * x) + (-1), mode=self.mode)
assert [n.op for n in f.maker.fgraph.toposort()] == [erf] assert [n.op for n in f.maker.fgraph.toposort()] == [erf]
@pytest.mark.xfail()
def test_local_log_erfc(self): def test_local_log_erfc(self):
val = [-30, -27, -26, -11, -10, -3, -2, -1, 0, 1, 2, 3, 10, 11, 26, 27, 28, 30] val = [-30, -27, -26, -11, -10, -3, -2, -1, 0, 1, 2, 3, 10, 11, 26, 27, 28, 30]
if config.mode in ["DebugMode", "DEBUG_MODE", "FAST_COMPILE"]: if config.mode in ["DebugMode", "DEBUG_MODE", "FAST_COMPILE"]:
...@@ -2886,42 +2885,27 @@ class TestLocalErfc: ...@@ -2886,42 +2885,27 @@ class TestLocalErfc:
mode_fusion.check_isfinite = False mode_fusion.check_isfinite = False
f = function([x], log(erfc(x)), mode=mode) f = function([x], log(erfc(x)), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 23 assert len(f.maker.fgraph.apply_nodes) == 22
assert f.maker.fgraph.outputs[0].dtype == config.floatX assert f.maker.fgraph.outputs[0].dtype == config.floatX
assert all(np.isfinite(f(val))) assert all(np.isfinite(f(val)))
f = function([x], log(erfc(-x)), mode=mode) f = function([x], log(erfc(-x)), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 24 assert len(f.maker.fgraph.apply_nodes) == 23
assert f.maker.fgraph.outputs[0].dtype == config.floatX assert f.maker.fgraph.outputs[0].dtype == config.floatX
assert all(np.isfinite(f(-val))) assert all(np.isfinite(f(-val)))
f = function([x], log(erfc(x)), mode=mode_fusion) f = function([x], log(erfc(x)), mode=mode_fusion)
assert len(f.maker.fgraph.apply_nodes) == 1 assert len(f.maker.fgraph.apply_nodes) == 1
assert f.maker.fgraph.outputs[0].dtype == config.floatX assert f.maker.fgraph.outputs[0].dtype == config.floatX
assert ( assert len(f.maker.fgraph.toposort()[0].op.scalar_op.fgraph.apply_nodes) == 22
len(
f.maker.fgraph.toposort()[0] # TODO: fix this problem: The python code upcast somewhere internally
.fgraph.toposort()[0] # some value of float32 to python float for part of its computation.
.op.scalar_op.fgraph.apply_nodes # That makes the c and python code generate sligtly different values
) if not (
== 22 config.floatX == "float32" and config.mode in ["DebugMode", "DEBUG_MODE"]
) ):
# TODO: fix this problem assert all(np.isfinite(f(val)))
assert not (
config.floatX == "float32"
and config.mode
in [
"DebugMode",
"DEBUG_MODE",
]
), (
"The python code upcast somewhere internally "
"some value of float32 to python float for "
"part of its computation. That make that the "
"c and python code don't generate the same value. "
"You can ignore this error."
)
assert all(np.isfinite(f(val)))
@np.errstate(divide="ignore", invalid="ignore") @np.errstate(divide="ignore", invalid="ignore")
def test_local_grad_log_erfc_neg(self): def test_local_grad_log_erfc_neg(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论