提交 57a1eb72 authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Register `log1p_neg_sigmoid` rewrite in specialize

上级 a7840844
...@@ -3096,7 +3096,8 @@ log1p_neg_sigmoid = PatternSub( ...@@ -3096,7 +3096,8 @@ log1p_neg_sigmoid = PatternSub(
register_stabilize(logsigm_to_softplus, name="logsigm_to_softplus") register_stabilize(logsigm_to_softplus, name="logsigm_to_softplus")
register_stabilize(log1msigm_to_softplus, name="log1msigm_to_softplus") register_stabilize(log1msigm_to_softplus, name="log1msigm_to_softplus")
register_stabilize(log1pexp_to_softplus, name="log1pexp_to_softplus") register_stabilize(log1pexp_to_softplus, name="log1pexp_to_softplus")
register_stabilize(log1p_neg_sigmoid, name="log1p_neg_sigmoid,") register_stabilize(log1p_neg_sigmoid, name="log1p_neg_sigmoid")
register_specialize(log1p_neg_sigmoid, name="log1p_neg_sigmoid")
def is_1pexp(t, only_process_constants=True): def is_1pexp(t, only_process_constants=True):
......
...@@ -4456,6 +4456,19 @@ class TestSoftplusOpts: ...@@ -4456,6 +4456,19 @@ class TestSoftplusOpts:
assert isinstance(topo[0].op.scalar_op, aesara.scalar.Softplus) assert isinstance(topo[0].op.scalar_op, aesara.scalar.Softplus)
f(np.random.random((54)).astype(config.floatX)) f(np.random.random((54)).astype(config.floatX))
def test_log1p_neg_sigmoid_to_softpuls(self):
x = scalar()
out = log1p(-sigmoid(x))
f = aesara.function([x], out, mode=self.m)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
assert isinstance(topo[0].op.scalar_op, aesara.scalar.Softplus)
assert isinstance(topo[1].op.scalar_op, aesara.scalar.Neg)
# This value would underflow to -inf without rewrite
assert np.isclose(f(37.0), -37.0)
class TestSigmoidUtils: class TestSigmoidUtils:
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论