提交 8968a387 authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Harmonize softplus implementations

上级 671a821d
......@@ -125,10 +125,18 @@ def jax_funcify_Psi(op, node, **kwargs):
@jax_funcify.register(Softplus)
def jax_funcify_Softplus(op, **kwargs):
def softplus(x):
# This expression is numerically equivalent to the PyTensor one
# It just contains one "speed" optimization less than the PyTensor counterpart
return jnp.where(
x < -37.0, jnp.exp(x), jnp.where(x > 33.3, x, jnp.log1p(jnp.exp(x)))
x < -37.0,
jnp.exp(x),
jnp.where(
x < 18.0,
jnp.log1p(jnp.exp(x)),
jnp.where(
x < 33.3,
x + jnp.exp(-x),
x,
),
),
)
return softplus
......@@ -6,6 +6,7 @@ As SciPy is not always available, we treat them separately.
import os
import warnings
from textwrap import dedent
import numpy as np
import scipy.special
......@@ -1134,7 +1135,8 @@ class Softplus(UnaryScalarOp):
r"""
Compute log(1 + exp(x)), also known as softplus or log1pexp
This function is numerically more stable than the naive approach.
This function is numerically faster than the naive approach, and does not overflow
for large values of x.
For details, see
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
......@@ -1172,44 +1174,30 @@ class Softplus(UnaryScalarOp):
def c_code(self, node, name, inp, out, sub):
(x,) = inp
(z,) = out
# The boundary constants were obtained by looking at the output of
# python commands like:
# import numpy, pytensor
# dt='float32' # or float64
# for i in range(750):
# print i, repr(numpy.log1p(numpy.exp(_asarray([i,-i], dtype=dt))))
# the upper boundary check prevents us from generating inf, whereas the
# the lower boundary check prevents using exp when the result will be 0 anyway.
# The intermediate constants are taken from Machler (2012).
# We use the float32 limits for float16 for now as the
# computation will happen in float32 anyway.
# We use the same limits for all precisions, which may be suboptimal. The reference
# paper only looked at double precision
if node.inputs[0].type in float_types:
if node.inputs[0].type == float64:
return (
"""
%(z)s = (
%(x)s < -745.0 ? 0.0 :
%(x)s < -37.0 ? exp(%(x)s) :
%(x)s < 18.0 ? log1p(exp(%(x)s)) :
%(x)s < 33.3 ? %(x)s + exp(-%(x)s) :
%(x)s
return dedent(
f"""
{z} = (
{x} < -37.0 ? exp({x}) :
{x} < 18.0 ? log1p(exp({x})) :
{x} < 33.3 ? {x} + exp(-{x}) :
{x}
);
"""
% locals()
)
else:
return (
"""
%(z)s = (
%(x)s < -103.0f ? 0.0 :
%(x)s < -37.0f ? exp(%(x)s) :
%(x)s < 18.0f ? log1p(exp(%(x)s)) :
%(x)s < 33.3f ? %(x)s + exp(-%(x)s) :
%(x)s
return dedent(
f"""
{z} = (
{x} < -37.0f ? exp({x}) :
{x} < 18.0f ? log1p(exp({x})) :
{x} < 33.3f ? {x} + exp(-{x}) :
{x}
);
"""
% locals()
)
else:
raise NotImplementedError("only floatingpoint is implemented")
......@@ -1217,7 +1205,7 @@ class Softplus(UnaryScalarOp):
def c_code_cache_version(self):
v = super().c_code_cache_version()
if v:
return (2,) + v
return (3,) + v
else:
return v
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论