提交 ec51faa6 authored 作者: Ricardo's avatar Ricardo 提交者: Thomas Wiecki

Move sigmoid opt to math_opt

上级 ff8b586c
...@@ -4069,7 +4069,7 @@ def local_flatten_lift(fgraph, node): ...@@ -4069,7 +4069,7 @@ def local_flatten_lift(fgraph, node):
Flatten(UnaryElemwise(x)) -> UnaryElemwise(Flatten(x)) Flatten(UnaryElemwise(x)) -> UnaryElemwise(Flatten(x))
This optimization is needed by optimization This optimization is needed by optimization
nnet/sigm.py:log1msigm_to_softplus to get applied when there is a flatten. log1msigm_to_softplus to get applied when there is a flatten.
""" """
if ( if (
...@@ -4295,7 +4295,7 @@ def local_reshape_lift(fgraph, node): ...@@ -4295,7 +4295,7 @@ def local_reshape_lift(fgraph, node):
Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x)) Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x))
This optimization is needed by optimization This optimization is needed by optimization
nnet/sigm.py:log1msigm_to_softplus to get applied when there is a reshape. log1msigm_to_softplus to get applied when there is a reshape.
""" """
if ( if (
......
...@@ -82,7 +82,7 @@ from aesara.tensor.math import ( ...@@ -82,7 +82,7 @@ from aesara.tensor.math import (
from aesara.tensor.math import max as aet_max from aesara.tensor.math import max as aet_max
from aesara.tensor.math import maximum, mul, neg from aesara.tensor.math import maximum, mul, neg
from aesara.tensor.math import pow as aet_pow from aesara.tensor.math import pow as aet_pow
from aesara.tensor.math import prod, sgn, sqr, sqrt, sub from aesara.tensor.math import prod, sgn, sigmoid, softplus, sqr, sqrt, sub
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.math import true_div from aesara.tensor.math import true_div
from aesara.tensor.shape import Shape, Shape_i from aesara.tensor.shape import Shape, Shape_i
...@@ -2993,3 +2993,656 @@ fuse_seqopt.register( ...@@ -2993,3 +2993,656 @@ fuse_seqopt.register(
"fast_run", "fast_run",
"fusion", "fusion",
) )
def _skip_mul_1(r):
if r.owner and r.owner.op == mul:
not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]
if len(not_is_1) == 1:
return not_is_1[0]
def _is_1(expr):
"""
Returns
-------
bool
True iff expr is a constant close to 1.
"""
try:
v = get_scalar_constant_value(expr)
return np.allclose(v, 1)
except NotScalarConstantError:
return False
logsigm_to_softplus = PatternSub(
(log, (sigmoid, "x")),
(neg, (softplus, (neg, "x"))),
allow_multiple_clients=True,
values_eq_approx=values_eq_approx_remove_inf,
skip_identities_fn=_skip_mul_1,
)
log1msigm_to_softplus = PatternSub(
(log, (sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x"))),
(neg, (softplus, "x")),
allow_multiple_clients=True,
values_eq_approx=values_eq_approx_remove_inf,
skip_identities_fn=_skip_mul_1,
)
log1pexp_to_softplus = PatternSub(
(log1p, (exp, "x")),
(softplus, "x"),
values_eq_approx=values_eq_approx_remove_inf,
allow_multiple_clients=True,
)
log1p_neg_sigmoid = PatternSub(
(log1p, (neg, (sigmoid, "x"))),
(neg, (softplus, "x")),
values_eq_approx=values_eq_approx_remove_inf,
allow_multiple_clients=True,
)
register_stabilize(logsigm_to_softplus, name="logsigm_to_softplus")
register_stabilize(log1msigm_to_softplus, name="log1msigm_to_softplus")
register_stabilize(log1pexp_to_softplus, name="log1pexp_to_softplus")
register_stabilize(log1p_neg_sigmoid, name="log1p_neg_sigmoid,")
def is_1pexp(t, only_process_constants=True):
"""
Returns
-------
object
If 't' is of the form (1+exp(x)), return (False, x).
Else return None.
"""
if t.owner and t.owner.op == add:
scalars, scalar_inputs, nonconsts = scalarconsts_rest(
t.owner.inputs, only_process_constants=only_process_constants
)
# scalar_inputs are potentially dimshuffled and filled with scalars
if len(nonconsts) == 1:
maybe_exp = nonconsts[0]
if maybe_exp.owner and maybe_exp.owner.op == exp:
# Verify that the constant terms sum to 1.
if scalars:
scal_sum = scalars[0]
for s in scalars[1:]:
scal_sum = scal_sum + s
if np.allclose(scal_sum, 1):
return False, maybe_exp.owner.inputs[0]
# Before 7987b51 there used to be a bug where *any* constant
# was considered as if it was equal to 1, and thus this
# function would incorrectly identify it as (1 + exp(x)).
if config.warn__identify_1pexp_bug:
warnings.warn(
"Although your current code is fine, please note that "
"Aesara versions prior to 0.5 (more specifically, "
"prior to commit 7987b51 on 2011-12-18) may have "
"yielded an incorrect result. To remove this warning, "
"either set the `warn__identify_1pexp_bug` config "
"option to False, or `warn__ignore_bug_before` to at "
"least '0.4.1'."
)
return None
def is_exp(var):
"""
Match a variable with either of the `exp(x)` or `-exp(x)` patterns.
Parameters
----------
var
The Variable to analyze.
Returns
-------
tuple
A pair (b, x) with `b` a boolean set to True if `var` is of the
form `-exp(x)` and False if `var` is of the form `exp(x)`. If `var`
cannot be cast into either form, then return `None`.
"""
_neg = False
neg_info = is_neg(var)
if neg_info is not None:
_neg = True
var = neg_info
if var.owner and var.owner.op == exp:
return _neg, var.owner.inputs[0]
def is_mul(var):
"""
Match a variable with `x * y * z * ...`.
Parameters
----------
var
The Variable to analyze.
Returns
-------
object
A list [x, y, z, ...] if `var` is of the form `x * y * z * ...`,
or None if `var` cannot be cast into this form.
"""
if var.owner and var.owner.op == mul:
return var.owner.inputs
else:
return None
def partition_num_or_denom(r, f):
if r.owner and r.owner.op == mul:
a = r.owner.inputs
else:
a = [r]
# ugly 2.4-compatible thing
f_terms = []
_neg = False
rest = []
for t in a:
f_t = f(t)
if f_t is None:
rest.append(t)
else:
neg_t, f_t = f_t
f_terms.append(f_t)
_neg ^= neg_t # bit flip if neg_t is true
return f_terms, rest, _neg
def is_neg(var):
"""
Match a variable with the `-x` pattern.
Parameters
----------
var
The Variable to analyze.
Returns
-------
object
`x` if `var` is of the form `-x`, or None otherwise.
"""
var_node = var.owner
if not var_node:
return None
# First match against `neg`.
if var_node.op == neg:
return var_node.inputs[0]
# Then match against a multiplication by -1.
if var_node.op == mul and len(var_node.inputs) >= 2:
for idx, mul_input in enumerate(var_node.inputs):
try:
constant = get_scalar_constant_value(mul_input)
is_minus_1 = np.allclose(constant, -1)
except NotScalarConstantError:
is_minus_1 = False
if is_minus_1:
# Found a multiplication by -1.
if len(var_node.inputs) == 2:
# Only return the other input.
return var_node.inputs[1 - idx]
else:
# Return the multiplication of all other inputs.
return mul(*(var_node.inputs[0:idx] + var_node.inputs[idx + 1 :]))
# No match.
return None
@register_stabilize
@local_optimizer([true_div])
def local_exp_over_1_plus_exp(fgraph, node):
"""
exp(x)/(1+exp(x)) -> sigm(x)
c/(1+exp(x)) -> c*sigm(-x)
"""
# this optimization should be done for numerical stability
# so we don't care to check client counts
if node.op == true_div:
# find all the exp() terms in the numerator
num, denom = node.inputs
num_exp_x, num_rest, num_neg = partition_num_or_denom(num, is_exp)
denom_1pexp, denom_rest, denom_neg = partition_num_or_denom(denom, is_1pexp)
sigmoids = []
for t in denom_1pexp:
if t in num_exp_x:
# case: exp(x) /(1+exp(x))
sigmoids.append(sigmoid(t))
del num_exp_x[num_exp_x.index(t)]
else:
# case: 1/(1+exp(x))
sigmoids.append(sigmoid(-t))
copy_stack_trace(node.outputs[0], sigmoids[-1])
if not sigmoids: # we didn't find any. abort
return
# put the new numerator together
new_num = sigmoids + [exp(t) for t in num_exp_x] + num_rest
if len(new_num) == 1:
new_num = new_num[0]
else:
new_num = mul(*new_num)
if num_neg ^ denom_neg:
new_num = -new_num
copy_stack_trace(num, new_num)
if len(denom_rest) == 0:
return [new_num]
elif len(denom_rest) == 1:
out = new_num / denom_rest[0]
else:
out = new_num / mul(*denom_rest)
copy_stack_trace(node.outputs[0], out)
return [out]
def parse_mul_tree(root):
"""
Parse a tree of multiplications starting at the given root.
Parameters
----------
root
The variable at the root of the tree.
Returns
-------
object
A tree where each non-leaf node corresponds to a multiplication
in the computation of `root`, represented by the list of its inputs.
Each input is a pair [n, x] with `n` a boolean value indicating whether
sub-tree `x` should be negated.
Examples
--------
x * y -> [False, [[False, x], [False, y]]]
-(x * y) -> [True, [[False, x], [False, y]]]
-x * y -> [False, [[True, x], [False, y]]]
-x -> [True, x]
(x * y) * -z -> [False, [[False, [[False, x], [False, y]]],
[True, z]]]
"""
# Is it a multiplication?
mul_info = is_mul(root)
if mul_info is None:
# Is it a negation?
neg_info = is_neg(root)
if neg_info is None:
# Keep the root "as is".
return [False, root]
else:
# Recurse, inverting the negation.
neg, sub_tree = parse_mul_tree(neg_info)
return [not neg, sub_tree]
else:
# Recurse into inputs.
return [False, list(map(parse_mul_tree, mul_info))]
def replace_leaf(arg, leaves, new_leaves, op, neg):
"""
Attempt to replace a leaf of a multiplication tree.
We search for a leaf in `leaves` whose argument is `arg`, and if we find
one, we remove it from `leaves` and add to `new_leaves` a leaf with
argument `arg` and variable `op(arg)`.
Parameters
----------
arg
The argument of the leaf we are looking for.
leaves
List of leaves to look into. Each leaf should be a pair
(x, l) with `x` the argument of the Op found in the leaf, and `l` the
actual leaf as found in a multiplication tree output by `parse_mul_tree`
(i.e. a pair [boolean, variable]).
new_leaves
If a replacement occurred, then the leaf is removed from `leaves`
and added to the list `new_leaves` (after being modified by `op`).
op
A function that, when applied to `arg`, returns the Variable
we want to replace the original leaf variable with.
neg : bool
If True, then the boolean value associated to the leaf should
be swapped. If False, then this value should remain unchanged.
Returns
-------
bool
True if a replacement occurred, or False otherwise.
"""
for idx, x in enumerate(leaves):
if x[0] == arg:
x[1][0] ^= neg
x[1][1] = op(arg)
leaves.pop(idx)
new_leaves.append(x)
return True
return False
def simplify_mul(tree):
"""
Simplify a multiplication tree.
Parameters
----------
tree
A multiplication tree (as output by `parse_mul_tree`).
Returns
-------
object
A multiplication tree computing the same output as `tree` but without
useless multiplications by 1 nor -1 (identified by leaves of the form
[False, None] or [True, None] respectively). Useless multiplications
(with less than two inputs) are also removed from the tree.
"""
neg, inputs = tree
if isinstance(inputs, list):
# Recurse through inputs.
s_inputs = []
for s_i in map(simplify_mul, inputs):
if s_i[1] is None:
# Multiplication by +/-1.
neg ^= s_i[0]
else:
s_inputs.append(s_i)
if not s_inputs:
# The multiplication is empty.
rval = [neg, None]
elif len(s_inputs) == 1:
# The multiplication has a single input.
s_inputs[0][0] ^= neg
rval = s_inputs[0]
else:
rval = [neg, s_inputs]
else:
rval = tree
# print 'simplify_mul: %s -> %s' % (tree, rval)
return rval
def compute_mul(tree):
"""
Compute the Variable that is the output of a multiplication tree.
This is the inverse of the operation performed by `parse_mul_tree`, i.e.
compute_mul(parse_mul_tree(tree)) == tree.
Parameters
----------
tree
A multiplication tree (as output by `parse_mul_tree`).
Returns
-------
object
A Variable that computes the multiplication represented by the tree.
"""
neg, inputs = tree
if inputs is None:
raise AssertionError(
"Function `compute_mul` found a missing leaf, did you forget to "
"call `simplify_mul` on the tree first?"
)
elif isinstance(inputs, list):
# Recurse through inputs.
rval = mul(*list(map(compute_mul, inputs)))
else:
rval = inputs
if neg:
rval = -rval
return rval
def perform_sigm_times_exp(
tree,
exp_x=None,
exp_minus_x=None,
sigm_x=None,
sigm_minus_x=None,
parent=None,
child_idx=None,
full_tree=None,
):
"""
Core processing of the `local_sigm_times_exp` optimization.
This recursive function operates on a multiplication tree as output by
`parse_mul_tree`. It walks through the tree and modifies it in-place
by replacing matching pairs (exp, sigmoid) with the desired optimized
version.
Parameters
----------
tree
The sub-tree to operate on.
exp_x
List of arguments x so that `exp(x)` exists somewhere in the whole
multiplication tree. Each argument is a pair (x, leaf) with `x` the
argument of the exponential, and `leaf` the corresponding leaf in the
multiplication tree (of the form [n, exp(x)] -- see `parse_mul_tree`).
If None, this argument is initialized to an empty list.
exp_minus_x
Similar to `exp_x`, but for `exp(-x)`.
sigm_x
Similar to `exp_x`, but for `sigmoid(x)`.
sigm_minus_x
Similar to `exp_x`, but for `sigmoid(-x)`.
parent
Parent of `tree` (None if `tree` is the global root).
child_idx
Index of `tree` in its parent's inputs (None if `tree` is the global
root).
full_tree
The global multiplication tree (should not be set except by recursive
calls to this function). Used for debugging only.
Returns
-------
bool
True if a modification was performed somewhere in the whole multiplication
tree, or False otherwise.
"""
if exp_x is None:
exp_x = []
if exp_minus_x is None:
exp_minus_x = []
if sigm_x is None:
sigm_x = []
if sigm_minus_x is None:
sigm_minus_x = []
if full_tree is None:
full_tree = tree
if False: # Debug code.
print("<perform_sigm_times_exp>")
print(f" full_tree = {full_tree}")
print(f" tree = {tree}")
print(f" exp_x = {exp_x}")
print(f" exp_minus_x = {exp_minus_x}")
print(f" sigm_x = {sigm_x}")
print(f" sigm_minus_x= {sigm_minus_x}")
neg, inputs = tree
if isinstance(inputs, list):
# Recurse through inputs of the multiplication.
rval = False
for sub_idx, sub_tree in enumerate(inputs):
rval |= perform_sigm_times_exp(
tree=sub_tree,
parent=tree,
child_idx=sub_idx,
exp_x=exp_x,
exp_minus_x=exp_minus_x,
sigm_x=sigm_x,
sigm_minus_x=sigm_minus_x,
full_tree=full_tree,
)
return rval
else:
# Reached a leaf: if it is an exponential or a sigmoid, then we
# first attempt to find a match in leaves already visited.
# If there is such a match, we modify the already-visited leaf
# accordingly: for instance if we visited a leaf sigmoid(x), then
# find later a -exp(-x), we replace the previous leaf by
# -sigmoid(-x) and remove the -exp(-x) from the tree.
# If no match is found, then we register this leaf so that it can
# be found later while walking the tree.
var = inputs
keep_it = False
exp_info = is_exp(var)
if exp_info is not None:
exp_neg, exp_arg = exp_info
neg ^= exp_neg
neg_arg = is_neg(exp_arg)
if neg_arg is None:
if not replace_leaf(exp_arg, sigm_minus_x, sigm_x, sigmoid, neg):
exp_x.append((exp_arg, tree))
keep_it = True
else:
if not replace_leaf(
neg_arg, sigm_x, sigm_minus_x, lambda x: sigmoid(-x), neg
):
exp_minus_x.append((neg_arg, tree))
keep_it = True
elif var.owner and var.owner.op == sigmoid:
sigm_arg = var.owner.inputs[0]
neg_arg = is_neg(sigm_arg)
if neg_arg is None:
if not replace_leaf(
sigm_arg, exp_minus_x, sigm_minus_x, lambda x: sigmoid(-x), neg
):
sigm_x.append((sigm_arg, tree))
keep_it = True
else:
if not replace_leaf(neg_arg, exp_x, sigm_x, sigmoid, neg):
sigm_minus_x.append((neg_arg, tree))
keep_it = True
else:
# It is not an exponential nor a sigmoid.
keep_it = True
if not keep_it:
# Delete this leaf, i.e. replace it by [False, None] (corresponding
# to a multiplication by 1).
assert parent is not None
parent[1][child_idx] = [False, None]
return not keep_it
@register_stabilize
@local_optimizer([mul])
def local_sigm_times_exp(fgraph, node):
"""
exp(x) * sigm(-x) -> sigm(x)
exp(-x) * sigm(x) -> sigm(-x)
todo: add stack traces to the intermediate variables
"""
# Bail early if it is not a multiplication.
if node.op != mul:
return None
# Obtain tree of multiplications starting at this node.
mul_tree = parse_mul_tree(node.outputs[0])
# Perform core optimization.
did_something = perform_sigm_times_exp(mul_tree)
if not did_something:
# No change.
return None
# The optimization may have introduced multiplications by 1 in the tree:
# get rid of them.
mul_tree = simplify_mul(mul_tree)
# Recompute final output based on the updated tree.
out = compute_mul(mul_tree)
# keep the stack trace
copy_stack_trace(node.outputs[0], out)
return [out]
@register_stabilize
@local_optimizer([inv])
def local_inv_1_plus_exp(fgraph, node):
"""
1/(1+exp(x)) -> sigm(-x)
"""
# this optimization should be done for numerical stability
# so we don't care to check client counts
if node.op == inv:
inv_arg = node.inputs[0]
if inv_arg.owner and inv_arg.owner.op == add:
scalars_, scalar_inputs, nonconsts = scalarconsts_rest(
inv_arg.owner.inputs, only_process_constants=True
)
# scalar_inputs are potentially dimshuffled and fill'd scalars
if len(nonconsts) == 1:
if nonconsts[0].owner and nonconsts[0].owner.op == exp:
if scalars_ and np.allclose(np.sum(scalars_), 1):
out = _fill_chain(
sigmoid(neg(nonconsts[0].owner.inputs[0])),
scalar_inputs,
)
# keep combined stack traces of
# exp(x): nonconsts[0],
# 1 + exp(x): inv_arg,
# 1 / (1 + exp(x)): node.outputs[0]
copy_stack_trace([nonconsts[0], inv_arg, node.outputs[0]], out)
return out
# Registration is below, and conditional.
@local_optimizer([sub])
def local_1msigmoid(fgraph, node):
"""
1-sigm(x) -> sigm(-x)
"""
if node.op == sub:
sub_l, sub_r = node.inputs
if len(fgraph.clients[sub_r]) > 1:
return # graph is using both sigm and 1-sigm
if sub_r.owner and sub_r.owner.op == sigmoid:
try:
val_l = get_scalar_constant_value(sub_l)
except NotScalarConstantError:
return
if np.allclose(np.sum(val_l), 1):
out = sigmoid(-sub_r.owner.inputs[0])
copy_stack_trace([sub_r, node.outputs[0]], out)
return [out]
register_local_1msigmoid = False
# This is False because the Stabilize pattern above
# is looking for 1-sigm. Also AlgebraicCanonizer turns neg into *(-1) and so
# this optimization might set off an unwanted chain of things.
# OTH - this transformation can be seen as pushing normal arithmetic either below or above the
# sigmoidal nonlinearity... so if the canonicalized form had anything to say about that then it
# would be a consideration... anyway leaving False for now.
if register_local_1msigmoid:
register_canonicalize(local_1msigmoid)
...@@ -6,36 +6,16 @@ stability. ...@@ -6,36 +6,16 @@ stability.
""" """
import warnings
import numpy as np
import aesara import aesara
from aesara import printing from aesara import printing
from aesara import scalar as aes from aesara import scalar as aes
from aesara.configdefaults import config from aesara.graph.opt import copy_stack_trace, local_optimizer
from aesara.graph.opt import PatternSub, copy_stack_trace, local_optimizer
from aesara.printing import pprint from aesara.printing import pprint
from aesara.scalar import sigmoid as scalar_sigmoid from aesara.scalar import sigmoid as scalar_sigmoid
from aesara.tensor import basic_opt from aesara.tensor.basic import constant
from aesara.tensor.basic import constant, get_scalar_constant_value
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.math import clip, sigmoid
from aesara.tensor.math import ( from aesara.tensor.type import TensorType
add,
clip,
exp,
inv,
log,
log1p,
mul,
neg,
sigmoid,
softplus,
sub,
true_div,
)
from aesara.tensor.type import TensorType, values_eq_approx_remove_inf
class UltraFastScalarSigmoid(aes.UnaryScalarOp): class UltraFastScalarSigmoid(aes.UnaryScalarOp):
...@@ -188,662 +168,3 @@ def local_hard_sigmoid(fgraph, node): ...@@ -188,662 +168,3 @@ def local_hard_sigmoid(fgraph, node):
aesara.compile.optdb["uncanonicalize"].register( aesara.compile.optdb["uncanonicalize"].register(
"local_hard_sigmoid", local_hard_sigmoid "local_hard_sigmoid", local_hard_sigmoid
) )
def _skip_mul_1(r):
if r.owner and r.owner.op == mul:
not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]
if len(not_is_1) == 1:
return not_is_1[0]
logsigm_to_softplus = PatternSub(
(log, (sigmoid, "x")),
(neg, (softplus, (neg, "x"))),
allow_multiple_clients=True,
values_eq_approx=values_eq_approx_remove_inf,
skip_identities_fn=_skip_mul_1,
)
def _is_1(expr):
"""
Returns
-------
bool
True iff expr is a constant close to 1.
"""
try:
v = get_scalar_constant_value(expr)
return np.allclose(v, 1)
except NotScalarConstantError:
return False
log1msigm_to_softplus = PatternSub(
(log, (sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x"))),
(neg, (softplus, "x")),
allow_multiple_clients=True,
values_eq_approx=values_eq_approx_remove_inf,
skip_identities_fn=_skip_mul_1,
)
log1pexp_to_softplus = PatternSub(
(log1p, (exp, "x")),
(softplus, "x"),
values_eq_approx=values_eq_approx_remove_inf,
allow_multiple_clients=True,
)
log1p_neg_sigmoid = PatternSub(
(log1p, (neg, (sigmoid, "x"))),
(neg, (softplus, "x")),
values_eq_approx=values_eq_approx_remove_inf,
allow_multiple_clients=True,
)
basic_opt.register_stabilize(logsigm_to_softplus, name="logsigm_to_softplus")
basic_opt.register_stabilize(log1msigm_to_softplus, name="log1msigm_to_softplus")
basic_opt.register_stabilize(log1pexp_to_softplus, name="log1pexp_to_softplus")
basic_opt.register_stabilize(log1p_neg_sigmoid, name="log1p_neg_sigmoid,")
def is_1pexp(t, only_process_constants=True):
"""
Returns
-------
object
If 't' is of the form (1+exp(x)), return (False, x).
Else return None.
"""
if t.owner and t.owner.op == add:
scalars, scalar_inputs, nonconsts = basic_opt.scalarconsts_rest(
t.owner.inputs, only_process_constants=only_process_constants
)
# scalar_inputs are potentially dimshuffled and filled with scalars
if len(nonconsts) == 1:
maybe_exp = nonconsts[0]
if maybe_exp.owner and maybe_exp.owner.op == exp:
# Verify that the constant terms sum to 1.
if scalars:
scal_sum = scalars[0]
for s in scalars[1:]:
scal_sum = scal_sum + s
if np.allclose(scal_sum, 1):
return False, maybe_exp.owner.inputs[0]
# Before 7987b51 there used to be a bug where *any* constant
# was considered as if it was equal to 1, and thus this
# function would incorrectly identify it as (1 + exp(x)).
if config.warn__identify_1pexp_bug:
warnings.warn(
"Although your current code is fine, please note that "
"Aesara versions prior to 0.5 (more specifically, "
"prior to commit 7987b51 on 2011-12-18) may have "
"yielded an incorrect result. To remove this warning, "
"either set the `warn__identify_1pexp_bug` config "
"option to False, or `warn__ignore_bug_before` to at "
"least '0.4.1'."
)
return None
def is_exp(var):
"""
Match a variable with either of the `exp(x)` or `-exp(x)` patterns.
Parameters
----------
var
The Variable to analyze.
Returns
-------
tuple
A pair (b, x) with `b` a boolean set to True if `var` is of the
form `-exp(x)` and False if `var` is of the form `exp(x)`. If `var`
cannot be cast into either form, then return `None`.
"""
_neg = False
neg_info = is_neg(var)
if neg_info is not None:
_neg = True
var = neg_info
if var.owner and var.owner.op == exp:
return _neg, var.owner.inputs[0]
def is_mul(var):
"""
Match a variable with `x * y * z * ...`.
Parameters
----------
var
The Variable to analyze.
Returns
-------
object
A list [x, y, z, ...] if `var` is of the form `x * y * z * ...`,
or None if `var` cannot be cast into this form.
"""
if var.owner and var.owner.op == mul:
return var.owner.inputs
else:
return None
def partition_num_or_denom(r, f):
if r.owner and r.owner.op == mul:
a = r.owner.inputs
else:
a = [r]
# ugly 2.4-compatible thing
f_terms = []
_neg = False
rest = []
for t in a:
f_t = f(t)
if f_t is None:
rest.append(t)
else:
neg_t, f_t = f_t
f_terms.append(f_t)
_neg ^= neg_t # bit flip if neg_t is true
return f_terms, rest, _neg
def is_neg(var):
"""
Match a variable with the `-x` pattern.
Parameters
----------
var
The Variable to analyze.
Returns
-------
object
`x` if `var` is of the form `-x`, or None otherwise.
"""
var_node = var.owner
if not var_node:
return None
# First match against `neg`.
if var_node.op == neg:
return var_node.inputs[0]
# Then match against a multiplication by -1.
if var_node.op == mul and len(var_node.inputs) >= 2:
for idx, mul_input in enumerate(var_node.inputs):
try:
constant = get_scalar_constant_value(mul_input)
is_minus_1 = np.allclose(constant, -1)
except NotScalarConstantError:
is_minus_1 = False
if is_minus_1:
# Found a multiplication by -1.
if len(var_node.inputs) == 2:
# Only return the other input.
return var_node.inputs[1 - idx]
else:
# Return the multiplication of all other inputs.
return mul(*(var_node.inputs[0:idx] + var_node.inputs[idx + 1 :]))
# No match.
return None
@basic_opt.register_stabilize
@local_optimizer([true_div])
def local_exp_over_1_plus_exp(fgraph, node):
"""
exp(x)/(1+exp(x)) -> sigm(x)
c/(1+exp(x)) -> c*sigm(-x)
"""
# this optimization should be done for numerical stability
# so we don't care to check client counts
if node.op == true_div:
# find all the exp() terms in the numerator
num, denom = node.inputs
num_exp_x, num_rest, num_neg = partition_num_or_denom(num, is_exp)
denom_1pexp, denom_rest, denom_neg = partition_num_or_denom(denom, is_1pexp)
sigmoids = []
for t in denom_1pexp:
if t in num_exp_x:
# case: exp(x) /(1+exp(x))
sigmoids.append(sigmoid(t))
del num_exp_x[num_exp_x.index(t)]
else:
# case: 1/(1+exp(x))
sigmoids.append(sigmoid(-t))
copy_stack_trace(node.outputs[0], sigmoids[-1])
if not sigmoids: # we didn't find any. abort
return
# put the new numerator together
new_num = sigmoids + [exp(t) for t in num_exp_x] + num_rest
if len(new_num) == 1:
new_num = new_num[0]
else:
new_num = mul(*new_num)
if num_neg ^ denom_neg:
new_num = -new_num
copy_stack_trace(num, new_num)
if len(denom_rest) == 0:
return [new_num]
elif len(denom_rest) == 1:
out = new_num / denom_rest[0]
else:
out = new_num / mul(*denom_rest)
copy_stack_trace(node.outputs[0], out)
return [out]
def parse_mul_tree(root):
"""
Parse a tree of multiplications starting at the given root.
Parameters
----------
root
The variable at the root of the tree.
Returns
-------
object
A tree where each non-leaf node corresponds to a multiplication
in the computation of `root`, represented by the list of its inputs.
Each input is a pair [n, x] with `n` a boolean value indicating whether
sub-tree `x` should be negated.
Examples
--------
x * y -> [False, [[False, x], [False, y]]]
-(x * y) -> [True, [[False, x], [False, y]]]
-x * y -> [False, [[True, x], [False, y]]]
-x -> [True, x]
(x * y) * -z -> [False, [[False, [[False, x], [False, y]]],
[True, z]]]
"""
# Is it a multiplication?
mul_info = is_mul(root)
if mul_info is None:
# Is it a negation?
neg_info = is_neg(root)
if neg_info is None:
# Keep the root "as is".
return [False, root]
else:
# Recurse, inverting the negation.
neg, sub_tree = parse_mul_tree(neg_info)
return [not neg, sub_tree]
else:
# Recurse into inputs.
return [False, list(map(parse_mul_tree, mul_info))]
def replace_leaf(arg, leaves, new_leaves, op, neg):
"""
Attempt to replace a leaf of a multiplication tree.
We search for a leaf in `leaves` whose argument is `arg`, and if we find
one, we remove it from `leaves` and add to `new_leaves` a leaf with
argument `arg` and variable `op(arg)`.
Parameters
----------
arg
The argument of the leaf we are looking for.
leaves
List of leaves to look into. Each leaf should be a pair
(x, l) with `x` the argument of the Op found in the leaf, and `l` the
actual leaf as found in a multiplication tree output by `parse_mul_tree`
(i.e. a pair [boolean, variable]).
new_leaves
If a replacement occurred, then the leaf is removed from `leaves`
and added to the list `new_leaves` (after being modified by `op`).
op
A function that, when applied to `arg`, returns the Variable
we want to replace the original leaf variable with.
neg : bool
If True, then the boolean value associated to the leaf should
be swapped. If False, then this value should remain unchanged.
Returns
-------
bool
True if a replacement occurred, or False otherwise.
"""
for idx, x in enumerate(leaves):
if x[0] == arg:
x[1][0] ^= neg
x[1][1] = op(arg)
leaves.pop(idx)
new_leaves.append(x)
return True
return False
def simplify_mul(tree):
"""
Simplify a multiplication tree.
Parameters
----------
tree
A multiplication tree (as output by `parse_mul_tree`).
Returns
-------
object
A multiplication tree computing the same output as `tree` but without
useless multiplications by 1 nor -1 (identified by leaves of the form
[False, None] or [True, None] respectively). Useless multiplications
(with less than two inputs) are also removed from the tree.
"""
neg, inputs = tree
if isinstance(inputs, list):
# Recurse through inputs.
s_inputs = []
for s_i in map(simplify_mul, inputs):
if s_i[1] is None:
# Multiplication by +/-1.
neg ^= s_i[0]
else:
s_inputs.append(s_i)
if not s_inputs:
# The multiplication is empty.
rval = [neg, None]
elif len(s_inputs) == 1:
# The multiplication has a single input.
s_inputs[0][0] ^= neg
rval = s_inputs[0]
else:
rval = [neg, s_inputs]
else:
rval = tree
# print 'simplify_mul: %s -> %s' % (tree, rval)
return rval
def compute_mul(tree):
"""
Compute the Variable that is the output of a multiplication tree.
This is the inverse of the operation performed by `parse_mul_tree`, i.e.
compute_mul(parse_mul_tree(tree)) == tree.
Parameters
----------
tree
A multiplication tree (as output by `parse_mul_tree`).
Returns
-------
object
A Variable that computes the multiplication represented by the tree.
"""
neg, inputs = tree
if inputs is None:
raise AssertionError(
"Function `compute_mul` found a missing leaf, did you forget to "
"call `simplify_mul` on the tree first?"
)
elif isinstance(inputs, list):
# Recurse through inputs.
rval = mul(*list(map(compute_mul, inputs)))
else:
rval = inputs
if neg:
rval = -rval
return rval
def perform_sigm_times_exp(
tree,
exp_x=None,
exp_minus_x=None,
sigm_x=None,
sigm_minus_x=None,
parent=None,
child_idx=None,
full_tree=None,
):
"""
Core processing of the `local_sigm_times_exp` optimization.
This recursive function operates on a multiplication tree as output by
`parse_mul_tree`. It walks through the tree and modifies it in-place
by replacing matching pairs (exp, sigmoid) with the desired optimized
version.
Parameters
----------
tree
The sub-tree to operate on.
exp_x
List of arguments x so that `exp(x)` exists somewhere in the whole
multiplication tree. Each argument is a pair (x, leaf) with `x` the
argument of the exponential, and `leaf` the corresponding leaf in the
multiplication tree (of the form [n, exp(x)] -- see `parse_mul_tree`).
If None, this argument is initialized to an empty list.
exp_minus_x
Similar to `exp_x`, but for `exp(-x)`.
sigm_x
Similar to `exp_x`, but for `sigmoid(x)`.
sigm_minus_x
Similar to `exp_x`, but for `sigmoid(-x)`.
parent
Parent of `tree` (None if `tree` is the global root).
child_idx
Index of `tree` in its parent's inputs (None if `tree` is the global
root).
full_tree
The global multiplication tree (should not be set except by recursive
calls to this function). Used for debugging only.
Returns
-------
bool
True if a modification was performed somewhere in the whole multiplication
tree, or False otherwise.
"""
if exp_x is None:
exp_x = []
if exp_minus_x is None:
exp_minus_x = []
if sigm_x is None:
sigm_x = []
if sigm_minus_x is None:
sigm_minus_x = []
if full_tree is None:
full_tree = tree
if False: # Debug code.
print("<perform_sigm_times_exp>")
print(f" full_tree = {full_tree}")
print(f" tree = {tree}")
print(f" exp_x = {exp_x}")
print(f" exp_minus_x = {exp_minus_x}")
print(f" sigm_x = {sigm_x}")
print(f" sigm_minus_x= {sigm_minus_x}")
neg, inputs = tree
if isinstance(inputs, list):
# Recurse through inputs of the multiplication.
rval = False
for sub_idx, sub_tree in enumerate(inputs):
rval |= perform_sigm_times_exp(
tree=sub_tree,
parent=tree,
child_idx=sub_idx,
exp_x=exp_x,
exp_minus_x=exp_minus_x,
sigm_x=sigm_x,
sigm_minus_x=sigm_minus_x,
full_tree=full_tree,
)
return rval
else:
# Reached a leaf: if it is an exponential or a sigmoid, then we
# first attempt to find a match in leaves already visited.
# If there is such a match, we modify the already-visited leaf
# accordingly: for instance if we visited a leaf sigmoid(x), then
# find later a -exp(-x), we replace the previous leaf by
# -sigmoid(-x) and remove the -exp(-x) from the tree.
# If no match is found, then we register this leaf so that it can
# be found later while walking the tree.
var = inputs
keep_it = False
exp_info = is_exp(var)
if exp_info is not None:
exp_neg, exp_arg = exp_info
neg ^= exp_neg
neg_arg = is_neg(exp_arg)
if neg_arg is None:
if not replace_leaf(exp_arg, sigm_minus_x, sigm_x, sigmoid, neg):
exp_x.append((exp_arg, tree))
keep_it = True
else:
if not replace_leaf(
neg_arg, sigm_x, sigm_minus_x, lambda x: sigmoid(-x), neg
):
exp_minus_x.append((neg_arg, tree))
keep_it = True
elif var.owner and var.owner.op == sigmoid:
sigm_arg = var.owner.inputs[0]
neg_arg = is_neg(sigm_arg)
if neg_arg is None:
if not replace_leaf(
sigm_arg, exp_minus_x, sigm_minus_x, lambda x: sigmoid(-x), neg
):
sigm_x.append((sigm_arg, tree))
keep_it = True
else:
if not replace_leaf(neg_arg, exp_x, sigm_x, sigmoid, neg):
sigm_minus_x.append((neg_arg, tree))
keep_it = True
else:
# It is not an exponential nor a sigmoid.
keep_it = True
if not keep_it:
# Delete this leaf, i.e. replace it by [False, None] (corresponding
# to a multiplication by 1).
assert parent is not None
parent[1][child_idx] = [False, None]
return not keep_it
@basic_opt.register_stabilize
@local_optimizer([mul])
def local_sigm_times_exp(fgraph, node):
"""
exp(x) * sigm(-x) -> sigm(x)
exp(-x) * sigm(x) -> sigm(-x)
todo: add stack traces to the intermediate variables
"""
# Bail early if it is not a multiplication.
if node.op != mul:
return None
# Obtain tree of multiplications starting at this node.
mul_tree = parse_mul_tree(node.outputs[0])
# Perform core optimization.
did_something = perform_sigm_times_exp(mul_tree)
if not did_something:
# No change.
return None
# The optimization may have introduced multiplications by 1 in the tree:
# get rid of them.
mul_tree = simplify_mul(mul_tree)
# Recompute final output based on the updated tree.
out = compute_mul(mul_tree)
# keep the stack trace
copy_stack_trace(node.outputs[0], out)
return [out]
@basic_opt.register_stabilize
@local_optimizer([inv])
def local_inv_1_plus_exp(fgraph, node):
"""
1/(1+exp(x)) -> sigm(-x)
"""
# this optimization should be done for numerical stability
# so we don't care to check client counts
if node.op == inv:
inv_arg = node.inputs[0]
if inv_arg.owner and inv_arg.owner.op == add:
scalars_, scalar_inputs, nonconsts = basic_opt.scalarconsts_rest(
inv_arg.owner.inputs, only_process_constants=True
)
# scalar_inputs are potentially dimshuffled and fill'd scalars
if len(nonconsts) == 1:
if nonconsts[0].owner and nonconsts[0].owner.op == exp:
if scalars_ and np.allclose(np.sum(scalars_), 1):
out = basic_opt._fill_chain(
sigmoid(neg(nonconsts[0].owner.inputs[0])),
scalar_inputs,
)
# keep combined stack traces of
# exp(x): nonconsts[0],
# 1 + exp(x): inv_arg,
# 1 / (1 + exp(x)): node.outputs[0]
copy_stack_trace([nonconsts[0], inv_arg, node.outputs[0]], out)
return out
# Registration is below, and conditional.
@local_optimizer([sub])
def local_1msigmoid(fgraph, node):
"""
1-sigm(x) -> sigm(-x)
"""
if node.op == sub:
sub_l, sub_r = node.inputs
if len(fgraph.clients[sub_r]) > 1:
return # graph is using both sigm and 1-sigm
if sub_r.owner and sub_r.owner.op == sigmoid:
try:
val_l = get_scalar_constant_value(sub_l)
except NotScalarConstantError:
return
if np.allclose(np.sum(val_l), 1):
out = sigmoid(-sub_r.owner.inputs[0])
copy_stack_trace([sub_r, node.outputs[0]], out)
return [out]
register_local_1msigmoid = False
# This is False because the Stabilize pattern above
# is looking for 1-sigm. Also AlgebraicCanonizer turns neg into *(-1) and so
# this optimization might set off an unwanted chain of things.
# OTH - this transformation can be seen as pushing normal arithmetic either below or above the
# sigmoidal nonlinearity... so if the canonicalized form had anything to say about that then it
# would be a consideration... anyway leaving False for now.
if register_local_1msigmoid:
basic_opt.register_canonicalize(local_1msigmoid)
...@@ -9,16 +9,15 @@ from aesara.scalar import Softplus ...@@ -9,16 +9,15 @@ from aesara.scalar import Softplus
from aesara.tensor import sigmoid, softplus from aesara.tensor import sigmoid, softplus
from aesara.tensor.inplace import neg_inplace, sigmoid_inplace from aesara.tensor.inplace import neg_inplace, sigmoid_inplace
from aesara.tensor.math import clip, exp, log, mul, neg from aesara.tensor.math import clip, exp, log, mul, neg
from aesara.tensor.nnet.sigm import ( from aesara.tensor.math_opt import (
compute_mul, compute_mul,
hard_sigmoid,
is_1pexp, is_1pexp,
parse_mul_tree, parse_mul_tree,
perform_sigm_times_exp, perform_sigm_times_exp,
register_local_1msigmoid, register_local_1msigmoid,
simplify_mul, simplify_mul,
ultra_fast_sigmoid,
) )
from aesara.tensor.nnet.sigm import hard_sigmoid, ultra_fast_sigmoid
from aesara.tensor.shape import Reshape from aesara.tensor.shape import Reshape
from aesara.tensor.type import fmatrix, matrix, scalar, vector, vectors from aesara.tensor.type import fmatrix, matrix, scalar, vector, vectors
from tests import unittest_tools as utt from tests import unittest_tools as utt
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论