Unverified 提交 08c63f71 authored 作者: ricardoV94's avatar ricardoV94 提交者: GitHub

Implement improved numerically-stable softplus (#262)

* Implement improved Softplus
* Extend gradient test to wider range
* Add numerical accuracy test
* Fix indentation
上级 f123525c
...@@ -111,6 +111,13 @@ class TestSoftplus: ...@@ -111,6 +111,13 @@ class TestSoftplus:
def test_elemwise(self): def test_elemwise(self):
utt.verify_grad(softplus, [np.random.rand(3, 4)]) utt.verify_grad(softplus, [np.random.rand(3, 4)])
def test_accuracy(self):
# Test all approximations are working (cutoff points are -37, 18, 33.3)
x_test = np.array([-40.0, -17.5, 17.5, 18.5, 40.0])
y_th = softplus(x_test).eval()
y_np = np.log1p(np.exp(x_test))
np.testing.assert_allclose(y_th, y_np, rtol=10e-10)
class TestSigmoidOpts: class TestSigmoidOpts:
def get_mode(self, excluding=None): def get_mode(self, excluding=None):
......
...@@ -348,22 +348,35 @@ theano.compile.optdb["uncanonicalize"].register( ...@@ -348,22 +348,35 @@ theano.compile.optdb["uncanonicalize"].register(
class ScalarSoftplus(scalar.UnaryScalarOp): class ScalarSoftplus(scalar.UnaryScalarOp):
""" r"""
This helps numerical stability. Compute log(1 + exp(x)), also known as softplus or log1pexp
This function is numerically more stable than the naive approach.
For details, see
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
References
----------
.. [Machler2012] Martin Mächler (2012).
"Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
""" """
@staticmethod @staticmethod
def static_impl(x): def static_impl(x):
if x < -30.0:
return 0.0
if x > 30.0:
return x
# If x is an int8 or uint8, numpy.exp will compute the result in # If x is an int8 or uint8, numpy.exp will compute the result in
# half-precision (float16), where we want float32. # half-precision (float16), where we want float32.
x_dtype = str(getattr(x, "dtype", "")) not_int8 = str(getattr(x, "dtype", "")) not in ("int8", "uint8")
if x_dtype in ("int8", "uint8"): if x < -37.0:
return np.log1p(np.exp(x, sig="f")) return np.exp(x) if not_int8 else np.exp(x, signature="f")
return np.log1p(np.exp(x)) elif x < 18.0:
return (
np.log1p(np.exp(x)) if not_int8 else np.log1p(np.exp(x, signature="f"))
)
elif x < 33.3:
return x + np.exp(-x) if not_int8 else x + np.exp(-x, signature="f")
else:
return x
def impl(self, x): def impl(self, x):
return ScalarSoftplus.static_impl(x) return ScalarSoftplus.static_impl(x)
...@@ -378,11 +391,13 @@ class ScalarSoftplus(scalar.UnaryScalarOp): ...@@ -378,11 +391,13 @@ class ScalarSoftplus(scalar.UnaryScalarOp):
(z,) = out (z,) = out
# These constants were obtained by looking at the output of # These constants were obtained by looking at the output of
# python commands like: # python commands like:
# import numpy, theano
# dt='float32' # or float64
# for i in range(750): # for i in range(750):
# print i, repr(numpy.log1p(numpy.exp(_asarray([i,-i], dtype=dt)))) # print i, repr(numpy.log1p(numpy.exp(_asarray([i,-i], dtype=dt))))
# the boundary checks prevent us from generating inf # the upper boundary check prevents us from generating inf, whereas the
# lower boundary check prevents using exp when the result will be 0 anyway
# float16 limits: -17.0, 6.0
# We use the float32 limits for float16 for now as the # We use the float32 limits for float16 for now as the
# computation will happen in float32 anyway. # computation will happen in float32 anyway.
if ( if (
...@@ -390,12 +405,28 @@ class ScalarSoftplus(scalar.UnaryScalarOp): ...@@ -390,12 +405,28 @@ class ScalarSoftplus(scalar.UnaryScalarOp):
or node.inputs[0].type == scalar.float16 or node.inputs[0].type == scalar.float16
): ):
return ( return (
"""%(z)s = %(x)s < -103.0f ? 0.0 : %(x)s > 14.0f ? %(x)s : log1p(exp(%(x)s));""" """
%(z)s = (
%(x)s < -103.0f ? 0.0 :
%(x)s < -37.0f ? exp(%(x)s) :
%(x)s < 18.0f ? log1p(exp(%(x)s)) :
%(x)s < 33.3f ? %(x)s + exp(-%(x)s) :
%(x)s
);
"""
% locals() % locals()
) )
elif node.inputs[0].type == scalar.float64: elif node.inputs[0].type == scalar.float64:
return ( return (
"""%(z)s = %(x)s < -745.0 ? 0.0 : %(x)s > 16.0 ? %(x)s : log1p(exp(%(x)s));""" """
%(z)s = (
%(x)s < -745.0 ? 0.0 :
%(x)s < -37.0 ? exp(%(x)s) :
%(x)s < 18.0 ? log1p(exp(%(x)s)) :
%(x)s < 33.3 ? %(x)s + exp(-%(x)s) :
%(x)s
);
"""
% locals() % locals()
) )
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论