Commit ff8b586c authored by Ricardo, committed by Thomas Wiecki

Move Softplus components to respective modules

Also adds an inplace version and fixes the JAX code to be equivalent to the Aesara implementation.
Parent 0d658a6d
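The behavioral change in the JAX backend is easiest to see numerically: the old JAX expression clamped softplus to exactly 0.0 below x = -30, while the Aesara scalar op falls back to exp(x) below -37 and only saturates to x above 33.3. A minimal numpy sketch of the two expressions (not part of the commit; numpy stands in for jnp):

```python
import numpy as np

x = np.array([-40.0, -35.0, 0.0, 35.0])

# Old JAX expression: everything below -30 is clamped to exactly 0.0
old = np.where(x < -30.0, 0.0, np.where(x > 30.0, x, np.log1p(np.exp(x))))

# New JAX expression: below -37 it falls back to exp(x), and the log1p
# branch now extends down to -37, matching the Aesara implementation
new = np.where(x < -37.0, np.exp(x), np.where(x > 33.3, x, np.log1p(np.exp(x))))

print(old)  # ~[0.0,     0.0,     0.693, 35.0]
print(new)  # ~[4.2e-18, 6.3e-16, 0.693, 35.0]
```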
@@ -1515,7 +1515,6 @@ class ProfileStats:
 from aesara import scalar as aes
 from aesara.tensor.elemwise import Elemwise
 from aesara.tensor.math import Dot
-from aesara.tensor.nnet.sigm import ScalarSoftplus
 from aesara.tensor.random.op import RandomVariable
 
 scalar_op_amdlibm_no_speed_up = [
@@ -1568,7 +1567,7 @@ class ProfileStats:
     aes.Cosh,
     aes.Sinh,
     aes.Sigmoid,
-    ScalarSoftplus,
+    aes.Softplus,
 ]
 
 def get_scalar_ops(s):
...
@@ -13,6 +13,7 @@ from aesara.configdefaults import config
 from aesara.graph.fg import FunctionGraph
 from aesara.ifelse import IfElse
 from aesara.link.utils import fgraph_to_python
+from aesara.scalar import Softplus
 from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
 from aesara.scan.op import Scan
 from aesara.scan.utils import scan_args as ScanArgs
@@ -53,7 +54,6 @@ from aesara.tensor.nlinalg import (
     QRIncomplete,
 )
 from aesara.tensor.nnet.basic import LogSoftmax, Softmax
-from aesara.tensor.nnet.sigm import ScalarSoftplus
 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
 from aesara.tensor.slinalg import Cholesky, Solve
@@ -191,12 +191,16 @@ def jax_funcify_LogSoftmax(op, **kwargs):
     return log_softmax
 
 
-@jax_funcify.register(ScalarSoftplus)
-def jax_funcify_ScalarSoftplus(op, **kwargs):
-    def scalarsoftplus(x):
-        return jnp.where(x < -30.0, 0.0, jnp.where(x > 30.0, x, jnp.log1p(jnp.exp(x))))
+@jax_funcify.register(Softplus)
+def jax_funcify_Softplus(op, **kwargs):
+    def softplus(x):
+        # This expression is numerically equivalent to the Aesara one
+        # It just contains one "speed" optimization less than the Aesara counterpart
+        return jnp.where(
+            x < -37.0, jnp.exp(x), jnp.where(x > 33.3, x, jnp.log1p(jnp.exp(x)))
+        )
 
-    return scalarsoftplus
+    return softplus
 
 
 @jax_funcify.register(Second)
...
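A quick self-check (not part of the commit) that the new two-branch JAX expression agrees with the four-branch Aesara `static_impl` below; plain numpy stands in for jnp, and only the float64 path is modeled:

```python
import numpy as np

def aesara_branches(x):
    # mirrors Softplus.static_impl (float64 path, int8 handling omitted)
    if x < -37.0:
        return np.exp(x)
    elif x < 18.0:
        return np.log1p(np.exp(x))
    elif x < 33.3:
        return x + np.exp(-x)  # the extra "speed" branch the JAX version skips
    else:
        return x

def jax_style(x):
    # the new jnp.where expression, written with numpy
    return np.where(x < -37.0, np.exp(x), np.where(x > 33.3, x, np.log1p(np.exp(x))))

# one probe point per branch, plus the boundaries
for v in [-50.0, -37.5, -5.0, 0.0, 17.9, 20.0, 33.0, 40.0]:
    assert np.isclose(aesara_branches(v), jax_style(v), rtol=1e-12), v
```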
@@ -975,3 +975,98 @@ class Sigmoid(UnaryScalarOp):
 sigmoid = Sigmoid(upgrade_to_float, name="sigmoid")
 
 
+class Softplus(UnaryScalarOp):
+    r"""
+    Compute log(1 + exp(x)), also known as softplus or log1pexp
+
+    This function is numerically more stable than the naive approach.
+
+    For details, see
+    https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
+
+    References
+    ----------
+    .. [Machler2012] Martin Mächler (2012).
+       "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
+    """
+
+    @staticmethod
+    def static_impl(x):
+        # If x is an int8 or uint8, numpy.exp will compute the result in
+        # half-precision (float16), where we want float32.
+        not_int8 = str(getattr(x, "dtype", "")) not in ("int8", "uint8")
+        if x < -37.0:
+            return np.exp(x) if not_int8 else np.exp(x, signature="f")
+        elif x < 18.0:
+            return (
+                np.log1p(np.exp(x)) if not_int8 else np.log1p(np.exp(x, signature="f"))
+            )
+        elif x < 33.3:
+            return x + np.exp(-x) if not_int8 else x + np.exp(-x, signature="f")
+        else:
+            return x
+
+    def impl(self, x):
+        return Softplus.static_impl(x)
+
+    def grad(self, inp, grads):
+        (x,) = inp
+        (gz,) = grads
+        return [gz * sigmoid(x)]
+
+    def c_code(self, node, name, inp, out, sub):
+        (x,) = inp
+        (z,) = out
+        # The boundary constants were obtained by looking at the output of
+        # python commands like:
+        #   import numpy, aesara
+        #   dt='float32'  # or float64
+        #   for i in range(750):
+        #       print i, repr(numpy.log1p(numpy.exp(_asarray([i,-i], dtype=dt))))
+        # The upper boundary check prevents us from generating inf, whereas
+        # the lower boundary check prevents using exp when the result will be 0 anyway.
+        # The intermediate constants are taken from Machler (2012).
+        # We use the float32 limits for float16 for now as the
+        # computation will happen in float32 anyway.
+        if node.inputs[0].type in float_types:
+            if node.inputs[0].type == float64:
+                return (
+                    """
+                    %(z)s = (
+                        %(x)s < -745.0 ? 0.0 :
+                        %(x)s < -37.0 ? exp(%(x)s) :
+                        %(x)s < 18.0 ? log1p(exp(%(x)s)) :
+                        %(x)s < 33.3 ? %(x)s + exp(-%(x)s) :
+                        %(x)s
+                    );
+                    """
+                    % locals()
+                )
+            else:
+                return (
+                    """
+                    %(z)s = (
+                        %(x)s < -103.0f ? 0.0 :
+                        %(x)s < -37.0f ? exp(%(x)s) :
+                        %(x)s < 18.0f ? log1p(exp(%(x)s)) :
+                        %(x)s < 33.3f ? %(x)s + exp(-%(x)s) :
+                        %(x)s
+                    );
+                    """
+                    % locals()
+                )
+        else:
+            raise NotImplementedError("only floating point is implemented")
+
+    def c_code_cache_version(self):
+        v = super().c_code_cache_version()
+        if v:
+            return (2,) + v
+        else:
+            return v
+
+
+softplus = Softplus(upgrade_to_float, name="scalar_softplus")
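One non-obvious detail in `static_impl` above is the `signature="f"` fallback: numpy computes ufuncs on (u)int8 inputs in half precision by default, which the op deliberately avoids. A two-line check (outside the commit; with recent numpy the single-dtype `signature` form may emit a deprecation warning):

```python
import numpy as np

# numpy promotes (u)int8 inputs through float16, hence the not_int8 guard
print(np.exp(np.int8(3)).dtype)                 # float16
print(np.exp(np.int8(3), signature="f").dtype)  # float32, as the op forces
```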
@@ -313,6 +313,11 @@ def sigmoid_inplace(x):
     """Logistic sigmoid function (1 / (1 + exp(-x))), also known as expit or inverse logit"""
 
 
+@scalar_elemwise
+def softplus_inplace(x):
+    """Compute log(1 + exp(x)), also known as softplus or log1pexp"""
+
+
 @scalar_elemwise
 def second_inplace(a):
     """Fill `a` with `b`"""
...
@@ -1407,6 +1407,11 @@ def sigmoid(x):
     """Logistic sigmoid function (1 / (1 + exp(-x))), also known as expit or inverse logit"""
 
 
+@scalar_elemwise
+def softplus(x):
+    """Compute log(1 + exp(x)), also known as softplus or log1pexp"""
+
+
 @scalar_elemwise
 def real(z):
     """Return real component of complex-valued tensor `z`"""
@@ -2834,6 +2839,7 @@ __all__ = [
     "i1",
     "iv",
     "sigmoid",
+    "softplus",
     "real",
     "imag",
     "angle",
...
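With `softplus` now exported from `aesara.tensor.math`, a hedged usage sketch (assuming a build of this branch is installed) looks like:

```python
import aesara
import aesara.tensor as aet
from aesara.tensor.math import softplus

x = aet.dvector("x")
f = aesara.function([x], softplus(x))
# exp(x) branch, log1p(exp(x)) branch, and the x passthrough branch
print(f([-40.0, 0.0, 40.0]))  # ~[4.2e-18, 0.693, 40.0]
```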
@@ -43,4 +43,4 @@ from aesara.tensor.nnet.basic import (
     softsign,
 )
 from aesara.tensor.nnet.batchnorm import batch_normalization
-from aesara.tensor.nnet.sigm import hard_sigmoid, softplus, ultra_fast_sigmoid
+from aesara.tensor.nnet.sigm import hard_sigmoid, ultra_fast_sigmoid
@@ -54,11 +54,11 @@ from aesara.tensor.math import (
     neg,
     or_,
     sigmoid,
+    softplus,
 )
 from aesara.tensor.math import sum as aet_sum
 from aesara.tensor.math import tanh, tensordot, true_div
 from aesara.tensor.nnet.blocksparse import sparse_block_dot
-from aesara.tensor.nnet.sigm import softplus
 from aesara.tensor.shape import shape, shape_padleft
 from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
 from aesara.tensor.type import (
...
@@ -31,6 +31,7 @@ from aesara.tensor.math import (
     mul,
     neg,
     sigmoid,
+    softplus,
     sub,
     true_div,
 )
@@ -189,102 +190,6 @@ aesara.compile.optdb["uncanonicalize"].register(
 )
 
 
-class ScalarSoftplus(aes.UnaryScalarOp):
-    r"""
-    Compute log(1 + exp(x)), also known as softplus or log1pexp
-
-    This function is numerically more stable than the naive approach.
-
-    For details, see
-    https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
-
-    References
-    ----------
-    .. [Machler2012] Martin Mächler (2012).
-       "Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
-    """
-
-    def static_impl(x):
-        # If x is an int8 or uint8, numpy.exp will compute the result in
-        # half-precision (float16), where we want float32.
-        not_int8 = str(getattr(x, "dtype", "")) not in ("int8", "uint8")
-        if x < -37.0:
-            return np.exp(x) if not_int8 else np.exp(x, signature="f")
-        elif x < 18.0:
-            return (
-                np.log1p(np.exp(x)) if not_int8 else np.log1p(np.exp(x, signature="f"))
-            )
-        elif x < 33.3:
-            return x + np.exp(-x) if not_int8 else x + np.exp(-x, signature="f")
-        else:
-            return x
-
-    def impl(self, x):
-        return ScalarSoftplus.static_impl(x)
-
-    def grad(self, inp, grads):
-        (x,) = inp
-        (gz,) = grads
-        return [gz * scalar_sigmoid(x)]
-
-    def c_code(self, node, name, inp, out, sub):
-        (x,) = inp
-        (z,) = out
-        # The boundary constants were obtained by looking at the output of
-        # python commands like:
-        #   import numpy, aesara
-        #   dt='float32'  # or float64
-        #   for i in range(750):
-        #       print i, repr(numpy.log1p(numpy.exp(_asarray([i,-i], dtype=dt))))
-        # the upper boundary check prevents us from generating inf, whereas the
-        # the lower boundary check prevents using exp when the result will be 0 anyway.
-        # The intermediate constants are taken from Machler (2012).
-        # We use the float32 limits for float16 for now as the
-        # computation will happen in float32 anyway.
-        if node.inputs[0].type == aes.float32 or node.inputs[0].type == aes.float16:
-            return (
-                """
-                %(z)s = (
-                    %(x)s < -103.0f ? 0.0 :
-                    %(x)s < -37.0f ? exp(%(x)s) :
-                    %(x)s < 18.0f ? log1p(exp(%(x)s)) :
-                    %(x)s < 33.3f ? %(x)s + exp(-%(x)s) :
-                    %(x)s
-                );
-                """
-                % locals()
-            )
-        elif node.inputs[0].type == aes.float64:
-            return (
-                """
-                %(z)s = (
-                    %(x)s < -745.0 ? 0.0 :
-                    %(x)s < -37.0 ? exp(%(x)s) :
-                    %(x)s < 18.0 ? log1p(exp(%(x)s)) :
-                    %(x)s < 33.3 ? %(x)s + exp(-%(x)s) :
-                    %(x)s
-                );
-                """
-                % locals()
-            )
-        else:
-            raise NotImplementedError("only floatingpoint is implemented")
-
-    def c_code_cache_version(self):
-        v = super().c_code_cache_version()
-        if v:
-            return (2,) + v
-        else:
-            return v
-
-
-scalar_softplus = ScalarSoftplus(aes.upgrade_to_float, name="scalar_softplus")
-softplus = Elemwise(scalar_softplus, name="softplus")
-pprint.assign(softplus, printing.FunctionPrinter("softplus"))
 
 
 def _skip_mul_1(r):
     if r.owner and r.owner.op == mul:
         not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]
...
@@ -31,7 +31,7 @@ from aesara.tensor.math import MaxAndArgmax
 from aesara.tensor.math import all as aet_all
 from aesara.tensor.math import clip, cosh, gammaln, log
 from aesara.tensor.math import max as aet_max
-from aesara.tensor.math import maximum, prod, sigmoid
+from aesara.tensor.math import maximum, prod, sigmoid, softplus
 from aesara.tensor.math import sum as aet_sum
 from aesara.tensor.random.basic import RandomVariable, normal
 from aesara.tensor.random.utils import RandomStream
@@ -952,7 +952,7 @@ def test_nnet():
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
 
-    out = aet_nnet.softplus(x)
+    out = softplus(x)
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
...
@@ -5,11 +5,11 @@ import aesara.tensor as aet
 from aesara.configdefaults import config
 from aesara.graph.opt import check_stack_trace
 from aesara.graph.toolbox import is_same_graph
-from aesara.tensor import sigmoid
+from aesara.scalar import Softplus
+from aesara.tensor import sigmoid, softplus
 from aesara.tensor.inplace import neg_inplace, sigmoid_inplace
 from aesara.tensor.math import clip, exp, log, mul, neg
 from aesara.tensor.nnet.sigm import (
-    ScalarSoftplus,
     compute_mul,
     hard_sigmoid,
     is_1pexp,
@@ -17,7 +17,6 @@ from aesara.tensor.nnet.sigm import (
     perform_sigm_times_exp,
     register_local_1msigmoid,
     simplify_mul,
-    softplus,
     ultra_fast_sigmoid,
 )
 from aesara.tensor.shape import Reshape
@@ -474,7 +473,7 @@ class TestSoftplusOpts:
         topo = f.maker.fgraph.toposort()
         assert len(topo) == 3
         assert isinstance(topo[0].op.scalar_op, aesara.scalar.Neg)
-        assert isinstance(topo[1].op.scalar_op, ScalarSoftplus)
+        assert isinstance(topo[1].op.scalar_op, Softplus)
         assert isinstance(topo[2].op.scalar_op, aesara.scalar.Neg)
 
         f(np.random.rand(54).astype(config.floatX))
@@ -485,7 +484,7 @@ class TestSoftplusOpts:
         f = aesara.function([x], out, mode=self.m)
         topo = f.maker.fgraph.toposort()
         assert len(topo) == 2
-        assert isinstance(topo[0].op.scalar_op, ScalarSoftplus)
+        assert isinstance(topo[0].op.scalar_op, Softplus)
         assert isinstance(topo[1].op.scalar_op, aesara.scalar.Neg)
         # assert check_stack_trace(f, ops_to_check='all')
 
         f(np.random.rand(54, 11).astype(config.floatX))
@@ -498,7 +497,7 @@ class TestSoftplusOpts:
         topo = f.maker.fgraph.toposort()
         assert len(topo) == 3
         assert aet.is_flat(topo[0].outputs[0])
-        assert isinstance(topo[1].op.scalar_op, ScalarSoftplus)
+        assert isinstance(topo[1].op.scalar_op, Softplus)
         assert isinstance(topo[2].op.scalar_op, aesara.scalar.Neg)
 
         f(np.random.rand(54, 11).astype(config.floatX))
@@ -511,7 +510,7 @@ class TestSoftplusOpts:
         assert any(
             isinstance(
                 getattr(node.op, "scalar_op", None),
-                ScalarSoftplus,
+                Softplus,
             )
             for node in topo
         )
@@ -531,7 +530,7 @@ class TestSoftplusOpts:
         # assert check_stack_trace(f, ops_to_check='all')
         topo = f.maker.fgraph.toposort()
         assert len(topo) == 1
-        assert isinstance(topo[0].op.scalar_op, ScalarSoftplus)
+        assert isinstance(topo[0].op.scalar_op, Softplus)
 
         f(np.random.rand(54).astype(config.floatX))
...
@@ -82,11 +82,10 @@ from aesara.tensor.math import (
 )
 from aesara.tensor.math import pow as aet_pow
 from aesara.tensor.math import round as aet_round
-from aesara.tensor.math import sin, sinh, sqr, sqrt, sub
+from aesara.tensor.math import sin, sinh, softplus, sqr, sqrt, sub
 from aesara.tensor.math import sum as aet_sum
 from aesara.tensor.math import tan, tanh, true_div, xor
 from aesara.tensor.math_opt import local_lift_transpose_through_dot
-from aesara.tensor.nnet.sigm import softplus
 from aesara.tensor.shape import Reshape, Shape_i, SpecifyShape, reshape, specify_shape
 from aesara.tensor.subtensor import (
     AdvancedIncSubtensor,
...