提交 ff8b586c authored 作者: Ricardo's avatar Ricardo 提交者: Thomas Wiecki

Move Softplus components to respective modules

Also adds an in-place version, and fixes the JAX code to be equivalent to the Aesara implementation
上级 0d658a6d
......@@ -1515,7 +1515,6 @@ class ProfileStats:
from aesara import scalar as aes
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import Dot
from aesara.tensor.nnet.sigm import ScalarSoftplus
from aesara.tensor.random.op import RandomVariable
scalar_op_amdlibm_no_speed_up = [
......@@ -1568,7 +1567,7 @@ class ProfileStats:
aes.Cosh,
aes.Sinh,
aes.Sigmoid,
ScalarSoftplus,
aes.Softplus,
]
def get_scalar_ops(s):
......
......@@ -13,6 +13,7 @@ from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.ifelse import IfElse
from aesara.link.utils import fgraph_to_python
from aesara.scalar import Softplus
from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from aesara.scan.op import Scan
from aesara.scan.utils import scan_args as ScanArgs
......@@ -53,7 +54,6 @@ from aesara.tensor.nlinalg import (
QRIncomplete,
)
from aesara.tensor.nnet.basic import LogSoftmax, Softmax
from aesara.tensor.nnet.sigm import ScalarSoftplus
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.slinalg import Cholesky, Solve
......@@ -191,12 +191,16 @@ def jax_funcify_LogSoftmax(op, **kwargs):
return log_softmax
@jax_funcify.register(Softplus)
def jax_funcify_Softplus(op, **kwargs):
    """Return a JAX implementation for the scalar `Softplus` `Op`.

    The returned callable evaluates log(1 + exp(x)) with the same
    piecewise thresholds (-37.0 and 33.3) as the Aesara scalar op.
    """

    def softplus(x):
        # This expression is numerically equivalent to the Aesara one.
        # It just contains one "speed" optimization less than the Aesara
        # counterpart (the mid-range `x + exp(-x)` branch is folded into
        # the generic `log1p(exp(x))` branch, which is equally accurate
        # below the 33.3 cutoff).
        return jnp.where(
            x < -37.0, jnp.exp(x), jnp.where(x > 33.3, x, jnp.log1p(jnp.exp(x)))
        )

    return softplus
@jax_funcify.register(Second)
......
......@@ -975,3 +975,98 @@ class Sigmoid(UnaryScalarOp):
sigmoid = Sigmoid(upgrade_to_float, name="sigmoid")
class Softplus(UnaryScalarOp):
    r"""
    Compute log(1 + exp(x)), also known as softplus or log1pexp

    This function is numerically more stable than the naive approach.

    For details, see
    https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf

    References
    ----------
    .. [Machler2012] Martin Mächler (2012).
        "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"

    """

    @staticmethod
    def static_impl(x):
        # Piecewise evaluation; the thresholds (-37.0, 18.0, 33.3) follow
        # Machler (2012) so each regime avoids overflow/underflow.
        # If x is an int8 or uint8, numpy.exp will compute the result in
        # half-precision (float16), where we want float32.
        not_int8 = str(getattr(x, "dtype", "")) not in ("int8", "uint8")
        if x < -37.0:
            # In this regime log1p(exp(x)) == exp(x) to double precision.
            return np.exp(x) if not_int8 else np.exp(x, signature="f")
        elif x < 18.0:
            # Generic stable form.
            return (
                np.log1p(np.exp(x)) if not_int8 else np.log1p(np.exp(x, signature="f"))
            )
        elif x < 33.3:
            # Mid-range: log1p(exp(x)) == x + exp(-x) to double precision.
            return x + np.exp(-x) if not_int8 else x + np.exp(-x, signature="f")
        else:
            # For large x, softplus(x) == x to machine precision.
            return x

    def impl(self, x):
        # Python-level scalar implementation (used when C code is unavailable).
        return Softplus.static_impl(x)

    def grad(self, inp, grads):
        # d/dx log(1 + exp(x)) = sigmoid(x), so chain with the output gradient.
        (x,) = inp
        (gz,) = grads
        return [gz * sigmoid(x)]

    def c_code(self, node, name, inp, out, sub):
        (x,) = inp
        (z,) = out
        # NOTE: the C templates below are interpolated with `% locals()`,
        # so the local names `x` and `z` must not be renamed.
        #
        # The boundary constants were obtained by looking at the output of
        # python commands like:
        #  import numpy, aesara
        #  dt='float32'  # or float64
        #  for i in range(750):
        #      print i, repr(numpy.log1p(numpy.exp(_asarray([i,-i], dtype=dt))))
        # the upper boundary check prevents us from generating inf, whereas the
        # the lower boundary check prevents using exp when the result will be 0 anyway.
        # The intermediate constants are taken from Machler (2012).

        # We use the float32 limits for float16 for now as the
        # computation will happen in float32 anyway.
        if node.inputs[0].type in float_types:
            if node.inputs[0].type == float64:
                return (
                    """
                    %(z)s = (
                        %(x)s < -745.0 ? 0.0 :
                        %(x)s < -37.0 ? exp(%(x)s) :
                        %(x)s < 18.0 ? log1p(exp(%(x)s)) :
                        %(x)s < 33.3 ? %(x)s + exp(-%(x)s) :
                        %(x)s
                    );
                    """
                    % locals()
                )
            else:
                # float32 / float16 limits.
                return (
                    """
                    %(z)s = (
                        %(x)s < -103.0f ? 0.0 :
                        %(x)s < -37.0f ? exp(%(x)s) :
                        %(x)s < 18.0f ? log1p(exp(%(x)s)) :
                        %(x)s < 33.3f ? %(x)s + exp(-%(x)s) :
                        %(x)s
                    );
                    """
                    % locals()
                )
        else:
            raise NotImplementedError("only floatingpoint is implemented")

    def c_code_cache_version(self):
        # Prepend (2,) to the parent version so previously compiled kernels
        # are invalidated when this implementation changes.
        v = super().c_code_cache_version()
        if v:
            return (2,) + v
        else:
            return v


softplus = Softplus(upgrade_to_float, name="scalar_softplus")
......@@ -313,6 +313,11 @@ def sigmoid_inplace(x):
"""Logistic sigmoid function (1 / (1 + exp(x)), also known as expit or inverse logit"""
# NOTE(review): `@scalar_elemwise` builds the actual Op from this stub;
# presumably the in-place variant overwrites its input buffer — confirm
# against the decorator's implementation.
@scalar_elemwise
def softplus_inplace(x):
    """Compute log(1 + exp(x)), also known as softplus or log1pexp"""
# NOTE(review): only `a` appears in this stub's signature; the `b` referenced
# by the docstring presumably comes from the Op that `@scalar_elemwise`
# constructs — confirm against the decorator's implementation.
@scalar_elemwise
def second_inplace(a):
    """Fill `a` with `b`"""
......
......@@ -1407,6 +1407,11 @@ def sigmoid(x):
"""Logistic sigmoid function (1 / (1 + exp(x)), also known as expit or inverse logit"""
# NOTE(review): stub only — `@scalar_elemwise` presumably wires this name to
# the scalar `Softplus` op, which handles numerical stability; confirm.
@scalar_elemwise
def softplus(x):
    """Compute log(1 + exp(x)), also known as softplus or log1pexp"""
# NOTE(review): stub only — the elementwise implementation is generated by
# `@scalar_elemwise`; this body exists to carry the name and docstring.
@scalar_elemwise
def real(z):
    """Return real component of complex-valued tensor `z`"""
......@@ -2834,6 +2839,7 @@ __all__ = [
"i1",
"iv",
"sigmoid",
"softplus",
"real",
"imag",
"angle",
......
......@@ -43,4 +43,4 @@ from aesara.tensor.nnet.basic import (
softsign,
)
from aesara.tensor.nnet.batchnorm import batch_normalization
from aesara.tensor.nnet.sigm import hard_sigmoid, softplus, ultra_fast_sigmoid
from aesara.tensor.nnet.sigm import hard_sigmoid, ultra_fast_sigmoid
......@@ -54,11 +54,11 @@ from aesara.tensor.math import (
neg,
or_,
sigmoid,
softplus,
)
from aesara.tensor.math import sum as aet_sum
from aesara.tensor.math import tanh, tensordot, true_div
from aesara.tensor.nnet.blocksparse import sparse_block_dot
from aesara.tensor.nnet.sigm import softplus
from aesara.tensor.shape import shape, shape_padleft
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
from aesara.tensor.type import (
......
......@@ -31,6 +31,7 @@ from aesara.tensor.math import (
mul,
neg,
sigmoid,
softplus,
sub,
true_div,
)
......@@ -189,102 +190,6 @@ aesara.compile.optdb["uncanonicalize"].register(
)
class ScalarSoftplus(aes.UnaryScalarOp):
    r"""
    Compute log(1 + exp(x)), also known as softplus or log1pexp

    This function is numerically more stable than the naive approach.

    For details, see
    https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf

    References
    ----------
    .. [Machler2012] Martin Mächler (2012).
        "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"

    """

    # FIX: `static_impl` was a plain function in the class body.  Accessing it
    # through an *instance* (`self.static_impl(x)`) would bind `x` as `self`
    # and mis-call it.  Marking it `@staticmethod` makes both access paths
    # safe and matches the replacement `aesara.scalar.Softplus` op.
    @staticmethod
    def static_impl(x):
        # Piecewise evaluation; thresholds (-37.0, 18.0, 33.3) follow
        # Machler (2012) so each regime avoids overflow/underflow.
        # If x is an int8 or uint8, numpy.exp will compute the result in
        # half-precision (float16), where we want float32.
        not_int8 = str(getattr(x, "dtype", "")) not in ("int8", "uint8")
        if x < -37.0:
            # log1p(exp(x)) == exp(x) to double precision here.
            return np.exp(x) if not_int8 else np.exp(x, signature="f")
        elif x < 18.0:
            # Generic stable form.
            return (
                np.log1p(np.exp(x)) if not_int8 else np.log1p(np.exp(x, signature="f"))
            )
        elif x < 33.3:
            # Mid-range: log1p(exp(x)) == x + exp(-x) to double precision.
            return x + np.exp(-x) if not_int8 else x + np.exp(-x, signature="f")
        else:
            # For large x, softplus(x) == x to machine precision.
            return x

    def impl(self, x):
        # Python-level scalar implementation (used when C code is unavailable).
        return ScalarSoftplus.static_impl(x)

    def grad(self, inp, grads):
        # d/dx log(1 + exp(x)) = sigmoid(x), so chain with the output gradient.
        (x,) = inp
        (gz,) = grads
        return [gz * scalar_sigmoid(x)]

    def c_code(self, node, name, inp, out, sub):
        (x,) = inp
        (z,) = out
        # NOTE: the C templates below are interpolated with `% locals()`,
        # so the local names `x` and `z` must not be renamed.
        #
        # The boundary constants were obtained by looking at the output of
        # python commands like:
        #  import numpy, aesara
        #  dt='float32'  # or float64
        #  for i in range(750):
        #      print i, repr(numpy.log1p(numpy.exp(_asarray([i,-i], dtype=dt))))
        # the upper boundary check prevents us from generating inf, whereas the
        # the lower boundary check prevents using exp when the result will be 0 anyway.
        # The intermediate constants are taken from Machler (2012).

        # We use the float32 limits for float16 for now as the
        # computation will happen in float32 anyway.
        # (membership test replaces the equivalent `== float32 or == float16`
        # chain; the template strings are unchanged so the compiled-code
        # cache version below remains valid.)
        if node.inputs[0].type in (aes.float32, aes.float16):
            return (
                """
                %(z)s = (
                    %(x)s < -103.0f ? 0.0 :
                    %(x)s < -37.0f ? exp(%(x)s) :
                    %(x)s < 18.0f ? log1p(exp(%(x)s)) :
                    %(x)s < 33.3f ? %(x)s + exp(-%(x)s) :
                    %(x)s
                );
                """
                % locals()
            )
        elif node.inputs[0].type == aes.float64:
            return (
                """
                %(z)s = (
                    %(x)s < -745.0 ? 0.0 :
                    %(x)s < -37.0 ? exp(%(x)s) :
                    %(x)s < 18.0 ? log1p(exp(%(x)s)) :
                    %(x)s < 33.3 ? %(x)s + exp(-%(x)s) :
                    %(x)s
                );
                """
                % locals()
            )
        else:
            raise NotImplementedError("only floatingpoint is implemented")

    def c_code_cache_version(self):
        # Prepend (2,) to the parent version so previously compiled kernels
        # are invalidated when this implementation changes.
        v = super().c_code_cache_version()
        if v:
            return (2,) + v
        else:
            return v


scalar_softplus = ScalarSoftplus(aes.upgrade_to_float, name="scalar_softplus")
# Tensor (elementwise) version of the scalar op, with pretty-printing.
softplus = Elemwise(scalar_softplus, name="softplus")
pprint.assign(softplus, printing.FunctionPrinter("softplus"))
def _skip_mul_1(r):
if r.owner and r.owner.op == mul:
not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]
......
......@@ -31,7 +31,7 @@ from aesara.tensor.math import MaxAndArgmax
from aesara.tensor.math import all as aet_all
from aesara.tensor.math import clip, cosh, gammaln, log
from aesara.tensor.math import max as aet_max
from aesara.tensor.math import maximum, prod, sigmoid
from aesara.tensor.math import maximum, prod, sigmoid, softplus
from aesara.tensor.math import sum as aet_sum
from aesara.tensor.random.basic import RandomVariable, normal
from aesara.tensor.random.utils import RandomStream
......@@ -952,7 +952,7 @@ def test_nnet():
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = aet_nnet.softplus(x)
out = softplus(x)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
......
......@@ -5,11 +5,11 @@ import aesara.tensor as aet
from aesara.configdefaults import config
from aesara.graph.opt import check_stack_trace
from aesara.graph.toolbox import is_same_graph
from aesara.tensor import sigmoid
from aesara.scalar import Softplus
from aesara.tensor import sigmoid, softplus
from aesara.tensor.inplace import neg_inplace, sigmoid_inplace
from aesara.tensor.math import clip, exp, log, mul, neg
from aesara.tensor.nnet.sigm import (
ScalarSoftplus,
compute_mul,
hard_sigmoid,
is_1pexp,
......@@ -17,7 +17,6 @@ from aesara.tensor.nnet.sigm import (
perform_sigm_times_exp,
register_local_1msigmoid,
simplify_mul,
softplus,
ultra_fast_sigmoid,
)
from aesara.tensor.shape import Reshape
......@@ -474,7 +473,7 @@ class TestSoftplusOpts:
topo = f.maker.fgraph.toposort()
assert len(topo) == 3
assert isinstance(topo[0].op.scalar_op, aesara.scalar.Neg)
assert isinstance(topo[1].op.scalar_op, ScalarSoftplus)
assert isinstance(topo[1].op.scalar_op, Softplus)
assert isinstance(topo[2].op.scalar_op, aesara.scalar.Neg)
f(np.random.rand(54).astype(config.floatX))
......@@ -485,7 +484,7 @@ class TestSoftplusOpts:
f = aesara.function([x], out, mode=self.m)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
assert isinstance(topo[0].op.scalar_op, ScalarSoftplus)
assert isinstance(topo[0].op.scalar_op, Softplus)
assert isinstance(topo[1].op.scalar_op, aesara.scalar.Neg)
# assert check_stack_trace(f, ops_to_check='all')
f(np.random.rand(54, 11).astype(config.floatX))
......@@ -498,7 +497,7 @@ class TestSoftplusOpts:
topo = f.maker.fgraph.toposort()
assert len(topo) == 3
assert aet.is_flat(topo[0].outputs[0])
assert isinstance(topo[1].op.scalar_op, ScalarSoftplus)
assert isinstance(topo[1].op.scalar_op, Softplus)
assert isinstance(topo[2].op.scalar_op, aesara.scalar.Neg)
f(np.random.rand(54, 11).astype(config.floatX))
......@@ -511,7 +510,7 @@ class TestSoftplusOpts:
assert any(
isinstance(
getattr(node.op, "scalar_op", None),
ScalarSoftplus,
Softplus,
)
for node in topo
)
......@@ -531,7 +530,7 @@ class TestSoftplusOpts:
# assert check_stack_trace(f, ops_to_check='all')
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op.scalar_op, ScalarSoftplus)
assert isinstance(topo[0].op.scalar_op, Softplus)
f(np.random.rand(54).astype(config.floatX))
......
......@@ -82,11 +82,10 @@ from aesara.tensor.math import (
)
from aesara.tensor.math import pow as aet_pow
from aesara.tensor.math import round as aet_round
from aesara.tensor.math import sin, sinh, sqr, sqrt, sub
from aesara.tensor.math import sin, sinh, softplus, sqr, sqrt, sub
from aesara.tensor.math import sum as aet_sum
from aesara.tensor.math import tan, tanh, true_div, xor
from aesara.tensor.math_opt import local_lift_transpose_through_dot
from aesara.tensor.nnet.sigm import softplus
from aesara.tensor.shape import Reshape, Shape_i, SpecifyShape, reshape, specify_shape
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论