Commit 8968a387 authored by Ricardo, committed by Ricardo Vieira

Harmonize softplus implementations

Parent 671a821d
@@ -125,10 +125,18 @@ def jax_funcify_Psi(op, node, **kwargs):
 @jax_funcify.register(Softplus)
 def jax_funcify_Softplus(op, **kwargs):
     def softplus(x):
-        # This expression is numerically equivalent to the PyTensor one
-        # It just contains one "speed" optimization less than the PyTensor counterpart
         return jnp.where(
-            x < -37.0, jnp.exp(x), jnp.where(x > 33.3, x, jnp.log1p(jnp.exp(x)))
+            x < -37.0,
+            jnp.exp(x),
+            jnp.where(
+                x < 18.0,
+                jnp.log1p(jnp.exp(x)),
+                jnp.where(
+                    x < 33.3,
+                    x + jnp.exp(-x),
+                    x,
+                ),
+            ),
         )
 
     return softplus
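
Not part of the commit: a minimal NumPy sketch of the same four-branch rule, handy for sanity-checking the thresholds outside JAX. The helper name softplus_piecewise is hypothetical; np.logaddexp(0, x) serves as an independent numerically stable reference for log(1 + exp(x)).

import numpy as np

def softplus_piecewise(x):
    # Hypothetical mirror of the harmonized four-branch rule above.
    x = np.asarray(x, dtype=np.float64)
    # np.where evaluates every branch eagerly, so silence the harmless
    # overflow warnings raised by branches that are ultimately discarded.
    with np.errstate(over="ignore"):
        return np.where(
            x < -37.0,
            np.exp(x),                # softplus(x) ~= exp(x) far left
            np.where(
                x < 18.0,
                np.log1p(np.exp(x)),  # exact formula, safe in mid-range
                np.where(
                    x < 33.3,
                    x + np.exp(-x),   # log1p(exp(x)) = x + log1p(exp(-x)) ~= x + exp(-x)
                    x,                # exp(-x) below double eps: softplus(x) ~= x
                ),
            ),
        )

x = np.array([-800.0, -50.0, 0.0, 20.0, 800.0])
print(softplus_piecewise(x))   # elementwise equal to the stable reference below
print(np.logaddexp(0.0, x))    # log(exp(0) + exp(x)), computed stably by NumPy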
@@ -6,6 +6,7 @@ As SciPy is not always available, we treat them separately.
 import os
 import warnings
+from textwrap import dedent
 
 import numpy as np
 import scipy.special
...@@ -1134,7 +1135,8 @@ class Softplus(UnaryScalarOp): ...@@ -1134,7 +1135,8 @@ class Softplus(UnaryScalarOp):
r""" r"""
Compute log(1 + exp(x)), also known as softplus or log1pexp Compute log(1 + exp(x)), also known as softplus or log1pexp
This function is numerically more stable than the naive approach. This function is numerically faster than the naive approach, and does not overflow
for large values of x.
For details, see For details, see
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
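
An aside, not in the diff: the updated docstring's overflow claim is easy to verify. The naive formula computes exp(x) first, which exceeds the float64 maximum (~1.8e308) near x = 710, while a stable softplus simply returns x:

import numpy as np

x = np.float64(800.0)
with np.errstate(over="ignore"):
    print(np.log(1.0 + np.exp(x)))  # inf: exp(800) overflows float64
print(np.logaddexp(0.0, x))         # 800.0: stable softplus reference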
@@ -1172,44 +1174,30 @@ class Softplus(UnaryScalarOp):
     def c_code(self, node, name, inp, out, sub):
         (x,) = inp
         (z,) = out
-        # The boundary constants were obtained by looking at the output of
-        # python commands like:
-        #  import numpy, pytensor
-        #  dt='float32'  # or float64
-        #  for i in range(750):
-        #      print i, repr(numpy.log1p(numpy.exp(_asarray([i,-i], dtype=dt))))
-        # the upper boundary check prevents us from generating inf, whereas the
-        # lower boundary check prevents using exp when the result will be 0 anyway.
-        # The intermediate constants are taken from Machler (2012).
-        # We use the float32 limits for float16 for now as the
-        # computation will happen in float32 anyway.
+        # We use the same limits for all precisions, which may be suboptimal. The
+        # reference paper only looked at double precision.
         if node.inputs[0].type in float_types:
             if node.inputs[0].type == float64:
-                return (
-                    """
-                    %(z)s = (
-                        %(x)s < -745.0 ? 0.0 :
-                        %(x)s < -37.0 ? exp(%(x)s) :
-                        %(x)s < 18.0 ? log1p(exp(%(x)s)) :
-                        %(x)s < 33.3 ? %(x)s + exp(-%(x)s) :
-                        %(x)s
-                    );
-                    """
-                    % locals()
-                )
+                return dedent(
+                    f"""
+                    {z} = (
+                        {x} < -37.0 ? exp({x}) :
+                        {x} < 18.0 ? log1p(exp({x})) :
+                        {x} < 33.3 ? {x} + exp(-{x}) :
+                        {x}
+                    );
+                    """
+                )
             else:
-                return (
-                    """
-                    %(z)s = (
-                        %(x)s < -103.0f ? 0.0 :
-                        %(x)s < -37.0f ? exp(%(x)s) :
-                        %(x)s < 18.0f ? log1p(exp(%(x)s)) :
-                        %(x)s < 33.3f ? %(x)s + exp(-%(x)s) :
-                        %(x)s
-                    );
-                    """
-                    % locals()
-                )
+                return dedent(
+                    f"""
+                    {z} = (
+                        {x} < -37.0f ? exp({x}) :
+                        {x} < 18.0f ? log1p(exp({x})) :
+                        {x} < 33.3f ? {x} + exp(-{x}) :
+                        {x}
+                    );
+                    """
+                )
         else:
             raise NotImplementedError("only floatingpoint is implemented")
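
A hedged aside, not in the commit: the new comment concedes the limits were derived for double precision only. A quick NumPy check (assuming agreement to within float64 eps, ~2.2e-16, is the design goal) shows the adjacent branch formulas coincide at each switch point:

import numpy as np

def rel_diff(a, b):
    return abs(a - b) / max(abs(a), abs(b))

# At x = -37: exp(x) vs log1p(exp(x)); mathematical gap ~exp(x)/2 ~ 4e-17
print(rel_diff(np.exp(-37.0), np.log1p(np.exp(-37.0))))

# At x = 18: log1p(exp(x)) vs x + exp(-x); gap ~exp(-2x)/(2x) ~ 6e-18
print(rel_diff(np.log1p(np.exp(18.0)), 18.0 + np.exp(-18.0)))

# At x = 33.3: dropping exp(-x) costs ~exp(-33.3)/33.3 ~ 1e-16
print(rel_diff(33.3 + np.exp(-33.3), np.float64(33.3)))

Each line prints 0.0 or a value below float64 eps, so the piecewise function is continuous to working precision at -37, 18, and 33.3 in double; the float32 branch reuses these constants, which is what the comment flags as possibly suboptimal.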
@@ -1217,7 +1205,7 @@ class Softplus(UnaryScalarOp):
     def c_code_cache_version(self):
         v = super().c_code_cache_version()
         if v:
-            return (2,) + v
+            return (3,) + v
         else:
             return v