提交 f40a8a25 authored 作者: nouiz's avatar nouiz

Merge pull request #233 from delallea/sigm_opt_fix

Fixed optimization for exp(x) * sigmoid(-x)
...@@ -2,6 +2,9 @@ ...@@ -2,6 +2,9 @@
These functions implement special cases of exp and log to improve numerical stability. These functions implement special cases of exp and log to improve numerical stability.
""" """
from itertools import imap
import numpy import numpy
from theano import gof from theano import gof
...@@ -115,6 +118,7 @@ logsigm_to_softplus = gof.PatternSub( ...@@ -115,6 +118,7 @@ logsigm_to_softplus = gof.PatternSub(
allow_multiple_clients = True, allow_multiple_clients = True,
skip_identities_fn=_skip_mul_1) skip_identities_fn=_skip_mul_1)
def _is_1(expr): def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1 """rtype bool. True iff expr is a constant close to 1
""" """
...@@ -156,15 +160,40 @@ def is_1pexp(t): ...@@ -156,15 +160,40 @@ def is_1pexp(t):
return False, maybe_exp.owner.inputs[0] return False, maybe_exp.owner.inputs[0]
return None return None
def is_exp(t):
# if t is of form (exp(x)) then return x def is_exp(var):
# else return None """
Match a variable with either of the `exp(x)` or `-exp(x)` patterns.
:param var: The Variable to analyze.
:return: A pair (b, x) with `b` a boolean set to True if `var` is of the
form `-exp(x)` and False if `var` is of the form `exp(x)`. If `var` cannot
be cast into either form, then return `None`.
"""
neg = False neg = False
if t.owner and t.owner.op == tensor.neg: neg_info = is_neg(var)
t = t.owner.inputs[0] if neg_info is not None:
neg = True neg = True
if t.owner and t.owner.op == tensor.exp: var = neg_info
return neg, t.owner.inputs[0] if var.owner and var.owner.op == tensor.exp:
return neg, var.owner.inputs[0]
def is_mul(var):
    """
    Match a variable with `x * y * z * ...`.

    :param var: The Variable to analyze.

    :return: The list of multiplied inputs [x, y, z, ...] when `var` is the
    output of a `tensor.mul` Apply node, or None if `var` cannot be cast
    into this form.
    """
    apply_node = var.owner
    if apply_node and apply_node.op == tensor.mul:
        return apply_node.inputs
    return None
def partition_num_or_denom(r, f): def partition_num_or_denom(r, f):
if r.owner and r.owner.op == tensor.mul: if r.owner and r.owner.op == tensor.mul:
...@@ -187,6 +216,41 @@ def partition_num_or_denom(r, f): ...@@ -187,6 +216,41 @@ def partition_num_or_denom(r, f):
return f_terms, rest, neg return f_terms, rest, neg
def is_neg(var):
    """
    Match a variable with the `-x` pattern.

    :param var: The Variable to analyze.

    :return: `x` if `var` is of the form `-x`, or None otherwise.
    """
    var_node = var.owner
    if not var_node:
        return None
    # Direct negation through `tensor.neg`.
    if var_node.op == tensor.neg:
        return var_node.inputs[0]
    # Otherwise, look for a constant -1 factor in a multiplication.
    if var_node.op == tensor.mul and len(var_node.inputs) >= 2:
        for idx, mul_input in enumerate(var_node.inputs):
            try:
                is_minus_1 = numpy.allclose(
                        opt.get_constant_value(mul_input), -1)
            except TypeError:
                # `mul_input` is not a constant.
                is_minus_1 = False
            if not is_minus_1:
                continue
            # Found a multiplication by -1: return the product of the
            # remaining inputs (or the single remaining input itself).
            others = var_node.inputs[0:idx] + var_node.inputs[idx + 1:]
            if len(others) == 1:
                return others[0]
            return tensor.mul(*others)
    # No match.
    return None
@opt.register_stabilize @opt.register_stabilize
@gof.local_optimizer([tensor.true_div]) @gof.local_optimizer([tensor.true_div])
def local_exp_over_1_plus_exp(node): def local_exp_over_1_plus_exp(node):
...@@ -231,71 +295,277 @@ def local_exp_over_1_plus_exp(node): ...@@ -231,71 +295,277 @@ def local_exp_over_1_plus_exp(node):
else: else:
return [new_num / tensor.mul(*denom_rest)] return [new_num / tensor.mul(*denom_rest)]
@opt.register_stabilize def parse_mul_tree(root):
@gof.local_optimizer([tensor.mul])
def local_sigm_times_exp(node):
""" """
exp(x) * sigm(-x) -> sigm(x) Parse a tree of multiplications starting at the given root.
exp(-x) * sigm(x) -> sigm(-x)
:param root: The variable at the root of the tree.
:return: A tree where each non-leaf node corresponds to a multiplication
in the computation of `root`, represented by the list of its inputs. Each
input is a pair [n, x] with `n` a boolean value indicating whether
sub-tree `x` should be negated.
Examples:
x * y -> [False, [[False, x], [False, y]]]
-(x * y) -> [True, [[False, x], [False, y]]]
-x * y -> [False, [[True, x], [False, y]]]
-x -> [True, x]
(x * y) * -z -> [False, [[False, [[False, x], [False, y]]],
[True, z]]]
""" """
# this is a numerical stability thing, so we dont check clients # Is it a multiplication?
if node.op == tensor.mul: mul_info = is_mul(root)
if mul_info is None:
# Is it a negation?
neg_info = is_neg(root)
if neg_info is None:
# Keep the root "as is".
return [False, root]
else:
# Recurse, inverting the negation.
neg, sub_tree = parse_mul_tree(neg_info)
return [not neg, sub_tree]
else:
# Recurse into inputs.
return [False, map(parse_mul_tree, mul_info)]
def replace_leaf(arg, leaves, new_leaves, op, neg):
    """
    Attempt to replace a leaf of a multiplication tree.

    We search for a leaf in `leaves` whose argument is `arg`, and if we find
    one, we remove it from `leaves` and add to `new_leaves` a leaf with
    argument `arg` and variable `op(arg)`.

    :param arg: The argument of the leaf we are looking for.

    :param leaves: List of leaves to look into. Each leaf should be a pair
    (x, l) with `x` the argument of the Op found in the leaf, and `l` the
    actual leaf as found in a multiplication tree output by `parse_mul_tree`
    (i.e. a pair [boolean, variable]).

    :param new_leaves: If a replacement occurred, then the leaf is removed
    from `leaves` and added to the list `new_leaves` (after being modified
    by `op`).

    :param op: A function that, when applied to `arg`, returns the Variable
    we want to replace the original leaf variable with.

    :param neg: If True, then the boolean value associated to the leaf
    should be swapped. If False, then this value should remain unchanged.

    :return: True if a replacement occurred, or False otherwise.
    """
    for position, leaf_pair in enumerate(leaves):
        leaf_arg, leaf = leaf_pair
        if leaf_arg == arg:
            # Swap the negation flag if requested, and substitute the
            # variable produced by `op` in place of the original one.
            leaf[0] ^= neg
            leaf[1] = op(arg)
            new_leaves.append(leaves.pop(position))
            return True
    return False
def simplify_mul(tree):
    """
    Simplify a multiplication tree.

    :param tree: A multiplication tree (as output by `parse_mul_tree`).

    :return: A multiplication tree computing the same output as `tree` but
    without useless multiplications by 1 nor -1 (identified by leaves of the
    form [False, None] or [True, None] respectively). Useless
    multiplications (with less than two inputs) are also removed from the
    tree.
    """
    neg, inputs = tree
    if not isinstance(inputs, list):
        # A leaf: nothing to simplify.
        return tree
    # Recursively simplify each input, dropping the +/-1 placeholders and
    # accumulating their sign into `neg`.
    kept = []
    for sub_tree in inputs:
        simplified = simplify_mul(sub_tree)
        if simplified[1] is None:
            # Multiplication by +/-1: only its sign matters.
            neg ^= simplified[0]
        else:
            kept.append(simplified)
    if len(kept) == 0:
        # The multiplication is empty.
        return [neg, None]
    if len(kept) == 1:
        # A single input remains: fold the accumulated sign into it.
        kept[0][0] ^= neg
        return kept[0]
    return [neg, kept]
def compute_mul(tree):
    """
    Compute the Variable that is the output of a multiplication tree.

    This is the inverse of the operation performed by `parse_mul_tree`, i.e.
    compute_mul(parse_mul_tree(tree)) == tree

    :param tree: A multiplication tree (as output by `parse_mul_tree`).

    :return: A Variable that computes the multiplication represented by the
    tree.
    """
    neg, inputs = tree
    if inputs is None:
        raise AssertionError(
            'Function `compute_mul` found a missing leaf, did you forget to '
            'call `simplify_mul` on the tree first?')
    if isinstance(inputs, list):
        # Internal node: multiply all recursively computed inputs.
        result = [compute_mul(sub_tree) for sub_tree in inputs]
        result = tensor.mul(*result)
    else:
        # Leaf: the variable itself.
        result = inputs
    if neg:
        result = -result
    return result
def perform_sigm_times_exp(tree, exp_x=None, exp_minus_x=None, sigm_x=None,
sigm_minus_x=None, parent=None, child_idx=None,
full_tree=None):
"""
Core processing of the `local_sigm_times_exp` optimization.
This recursive function operates on a multiplication tree as output by
`parse_mul_tree`. It walks through the tree and modifies it in-place
by replacing matching pairs (exp, sigmoid) with the desired optimized
version.
:param tree: The sub-tree to operate on.
:exp_x: List of arguments x so that `exp(x)` exists somewhere in the whole
multiplication tree. Each argument is a pair (x, leaf) with `x` the
argument of the exponential, and `leaf` the corresponding leaf in the
multiplication tree (of the form [n, exp(x)] -- see `parse_mul_tree`).
If None, this argument is initialized to an empty list.
:param exp_minus_x: Similar to `exp_x`, but for `exp(-x)`.
:param sigm_x: Similar to `exp_x`, but for `sigmoid(x)`.
:param sigm_minus_x: Similar to `exp_x`, but for `sigmoid(-x)`.
:param parent: Parent of `tree` (None if `tree` is the global root).
:param child_idx: Index of `tree` in its parent's inputs (None if `tree` is
the global root).
:param full_tree: The global multiplication tree (should not be set except
by recursive calls to this function). Used for debugging only.
:return: True if a modification was performed somewhere in the whole
multiplication tree, or False otherwise.
"""
if exp_x is None:
exp_x = [] exp_x = []
if exp_minus_x is None:
exp_minus_x = [] exp_minus_x = []
if sigm_x is None:
sigm_x = [] sigm_x = []
if sigm_minus_x is None:
sigm_minus_x = [] sigm_minus_x = []
other = [] if full_tree is None:
neg = False full_tree = tree
for i in node.inputs: if False: # Debug code.
while i.owner and i.owner.op == tensor.neg: print '<perform_sigm_times_exp>'
neg ^= True print ' full_tree = %s' % full_tree
i = i.owner.inputs[0] print ' tree = %s' % tree
if i.owner and i.owner.op == tensor.exp: print ' exp_x = %s' % exp_x
exp_arg = i.owner.inputs[0] print ' exp_minus_x = %s' % exp_minus_x
if exp_arg.owner and exp_arg.owner.op == tensor.neg: print ' sigm_x = %s' % sigm_x
exp_minus_x.append(exp_arg.owner.inputs[0]) print ' sigm_minus_x= %s' % sigm_minus_x
else: neg, inputs = tree
exp_x.append(exp_arg) if isinstance(inputs, list):
elif i.owner and i.owner.op == sigmoid: # Recurse through inputs of the multiplication.
sigm_arg = i.owner.inputs[0] rval = False
if sigm_arg.owner and sigm_arg.owner.op == tensor.neg: for sub_idx, sub_tree in enumerate(inputs):
sigm_minus_x.append(sigm_arg.owner.inputs[0]) rval |= perform_sigm_times_exp(
else: tree=sub_tree, parent=tree, child_idx=sub_idx,
sigm_x.append(sigm_arg) exp_x=exp_x, exp_minus_x=exp_minus_x, sigm_x=sigm_x,
sigm_minus_x=sigm_minus_x, full_tree=full_tree)
return rval
else: else:
other.append(i) # Reached a leaf: if it is an exponential or a sigmoid, then we
# first attempt to find a match in leaves already visited.
# remove matched pairs in exp_x and sigm_minus_x # If there is such a match, we modify the already-visited leaf
did_something = False # accordingly: for instance if we visited a leaf sigmoid(x), then
for i in exp_x: # find later a -exp(-x), we replace the previous leaf by
if i in sigm_minus_x: # -sigmoid(-x) and remove the -exp(-x) from the tree.
del sigm_minus_x[sigm_minus_x.index(i)] # If no match is found, then we register this leaf so that it can
other.append(sigmoid(i)) # be found later while walking the tree.
did_something = True var = inputs
keep_it = False
exp_info = is_exp(var)
if exp_info is not None:
exp_neg, exp_arg = exp_info
neg ^= exp_neg
neg_arg = is_neg(exp_arg)
if neg_arg is None:
if not replace_leaf(exp_arg, sigm_minus_x, sigm_x,
sigmoid, neg):
exp_x.append((exp_arg, tree))
keep_it = True
else: else:
other.append(i) if not replace_leaf(neg_arg, sigm_x, sigm_minus_x,
lambda x: sigmoid(-x), neg):
# remove matched pairs in exp_minus_x and sigm_x exp_minus_x.append((neg_arg, tree))
for i in exp_minus_x: keep_it = True
if i in sigm_x: elif var.owner and var.owner.op == sigmoid:
del sigm_x[sigm_x.index(i)] sigm_arg = var.owner.inputs[0]
other.append(sigm(-i)) neg_arg = is_neg(sigm_arg)
did_something = True if neg_arg is None:
if not replace_leaf(sigm_arg, exp_minus_x, sigm_minus_x,
lambda x: sigmoid(-x), neg):
sigm_x.append((sigm_arg, tree))
keep_it = True
else: else:
other.append(i) if not replace_leaf(neg_arg, exp_x, sigm_x, sigmoid, neg):
if did_something: sigm_minus_x.append((neg_arg, tree))
terms = (other + keep_it = True
[sigmoid(x) for x in sigm_x] +
[sigmoid(-x) for x in sigm_minus_x])
if len(terms) > 1:
rval = tensor.mul(*terms)
else: else:
rval = terms[0] # It is not an exponential nor a sigmoid.
keep_it = True
if not keep_it:
# Delete this leaf, i.e. replace it by [False, None] (corresponding
# to a multiplication by 1).
assert parent is not None
parent[1][child_idx] = [False, None]
return not keep_it
@opt.register_stabilize
@gof.local_optimizer([tensor.mul])
def local_sigm_times_exp(node):
    """
    exp(x) * sigm(-x) -> sigm(x)
    exp(-x) * sigm(x) -> sigm(-x)
    """
    # NOTE(review): this span of the diff page interleaved old and new code;
    # the body below is the reconstructed new-side definition.
    # Bail early if it is not a multiplication.
    if node.op != tensor.mul:
        return None
    # Obtain tree of multiplications starting at this node.
    mul_tree = parse_mul_tree(node.outputs[0])
    # Perform core optimization.
    did_something = perform_sigm_times_exp(mul_tree)
    if not did_something:
        # No change.
        return None
    # The optimization may have introduced multiplications by 1 in the tree:
    # get rid of them.
    mul_tree = simplify_mul(mul_tree)
    # Recompute final output based on the updated tree.
    return [compute_mul(mul_tree)]
@opt.register_stabilize @opt.register_stabilize
...@@ -391,4 +661,3 @@ if 0: ...@@ -391,4 +661,3 @@ if 0:
return [] return []
else: else:
return [node.outputs[0]] return [node.outputs[0]]
...@@ -7,7 +7,9 @@ from theano import tensor as T ...@@ -7,7 +7,9 @@ from theano import tensor as T
from theano import config from theano import config
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.tensor.nnet import sigmoid, sigmoid_inplace, softplus, tensor from theano.tensor.nnet import sigmoid, sigmoid_inplace, softplus, tensor
from theano.tensor.nnet.sigm import register_local_1msigmoid from theano.tensor.nnet.sigm import (
compute_mul, parse_mul_tree, perform_sigm_times_exp,
register_local_1msigmoid, simplify_mul)
class T_sigmoid(unittest.TestCase): class T_sigmoid(unittest.TestCase):
...@@ -23,12 +25,29 @@ class T_softplus(unittest.TestCase): ...@@ -23,12 +25,29 @@ class T_softplus(unittest.TestCase):
utt.verify_grad(softplus, [numpy.random.rand(3,4)]) utt.verify_grad(softplus, [numpy.random.rand(3,4)])
class T_sigmoid_opts(unittest.TestCase): class T_sigmoid_opts(unittest.TestCase):
def test_exp_over_1_plus_exp(self):
def get_mode(self, excluding=[]):
"""
Return appropriate mode for the tests.
:param excluding: List of optimizations to exclude.
:return: The current default mode unless the `config.mode` option is
set to 'FAST_COMPILE' (in which case it is replaced by the 'FAST_RUN'
mode), without the optimizations specified in `excluding`.
"""
m = theano.config.mode m = theano.config.mode
if m == 'FAST_COMPILE': if m == 'FAST_COMPILE':
m = 'FAST_RUN' mode = theano.compile.mode.get_mode('FAST_RUN')
m = theano.compile.mode.get_mode(m) else:
m = m.excluding('local_elemwise_fusion') mode = theano.compile.mode.get_default_mode()
if excluding:
return mode.excluding(*excluding)
else:
return mode
def test_exp_over_1_plus_exp(self):
m = self.get_mode(excluding=['local_elemwise_fusion'])
x = T.dvector() x = T.dvector()
...@@ -60,10 +79,7 @@ class T_sigmoid_opts(unittest.TestCase): ...@@ -60,10 +79,7 @@ class T_sigmoid_opts(unittest.TestCase):
if not register_local_1msigmoid: if not register_local_1msigmoid:
return return
m = theano.config.mode m = self.get_mode()
if m == 'FAST_COMPILE':
m = 'FAST_RUN'
x = T.fmatrix() x = T.fmatrix()
# tests exp_over_1_plus_exp # tests exp_over_1_plus_exp
...@@ -77,6 +93,80 @@ class T_sigmoid_opts(unittest.TestCase): ...@@ -77,6 +93,80 @@ class T_sigmoid_opts(unittest.TestCase):
assert [node.op for node in f.maker.env.toposort()] == [tensor.neg, assert [node.op for node in f.maker.env.toposort()] == [tensor.neg,
sigmoid_inplace] sigmoid_inplace]
def test_local_sigm_times_exp(self):
    """
    Test the `local_sigm_times_exp` optimization.
    exp(x) * sigm(-x) -> sigm(x)
    exp(-x) * sigm(x) -> sigm(-x)
    """
    def match(func, ops):
        # Assert that the optimized graph is exactly the expected sequence
        # of Ops (in topological order).
        #print [node.op.scalar_op for node in func.maker.env.toposort()]
        assert [node.op for node in func.maker.env.toposort()] == ops
    # Exclude fusion/inplace optimizations so the resulting graphs keep one
    # Op per node and can be compared directly against expected Op lists.
    m = self.get_mode(excluding=['local_elemwise_fusion', 'inplace'])
    x, y = tensor.vectors('x', 'y')

    # Basic case: exp(x) * sigm(-x) should collapse to a single sigmoid.
    f = theano.function([x], sigmoid(-x) * tensor.exp(x), mode=m)
    theano.printing.debugprint(f)
    match(f, [sigmoid])

    # Mirror case: exp(-x) * sigm(x) -> sigm(-x).
    f = theano.function([x], sigmoid(x) * tensor.exp(-x), mode=m)
    theano.printing.debugprint(f)
    match(f, [tensor.neg, sigmoid])

    # Nested negations should be folded into a single overall sign.
    f = theano.function([x], -(-(-(sigmoid(x)))) * tensor.exp(-x), mode=m)
    theano.printing.debugprint(f)
    match(f, [tensor.neg, sigmoid, tensor.neg])

    # Mixed product with multiple sigmoids/exps: only matching (exp, sigm)
    # pairs are rewritten, the other factors are left in the product.
    f = theano.function(
            [x, y],
            (sigmoid(x) * sigmoid(-y) * -tensor.exp(-x) * tensor.exp(x * y) *
             tensor.exp(y)),
            mode=m)
    theano.printing.debugprint(f)
    match(f, [sigmoid, tensor.mul, tensor.neg, tensor.exp, sigmoid,
              tensor.mul, tensor.neg])
def test_perform_sigm_times_exp(self):
    """
    Test the core function doing the `sigm_times_exp` optimization.

    It is easier to test different graph scenarios this way than by
    compiling a theano function.
    """
    x, y, z, t = tensor.vectors('x', 'y', 'z', 't')
    exp = tensor.exp

    def ok(expr1, expr2):
        # Run the optimization on `expr1`'s multiplication tree and check
        # that the simplified result matches `expr2`'s tree.
        trees = [parse_mul_tree(e) for e in (expr1, expr2)]
        perform_sigm_times_exp(trees[0])
        trees[0] = simplify_mul(trees[0])
        # TODO Ideally we would do a full comparison without `str`. However
        # the implementation of `__eq__` in Variables is not currently
        # appropriate for this. So for now we use this limited technique,
        # but it could be improved on.
        good = str(trees[0]) == str(trees[1])
        if not good:
            # Dump both trees and graphs to help debug the mismatch.
            print trees[0]
            print trees[1]
            print '***'
            theano.printing.debugprint(compute_mul(trees[0]))
            print '***'
            theano.printing.debugprint(compute_mul(trees[1]))
        assert good
    # Simple match: sigm(x) * exp(-x) -> sigm(-x).
    ok(sigmoid(x) * exp(-x), sigmoid(-x))
    # The matching exp may be nested deeper in the product.
    ok(-x * sigmoid(x) * (y * (-1 * z) * exp(-x)),
       -x * sigmoid(-x) * (y * (-1 * z)))
    # Stress case: several (exp, sigm) pairs at different depths, with
    # signs and unrelated factors mixed in.
    ok(-sigmoid(-x) *
       (exp(y) * (-exp(-z) * 3 * -exp(x)) *
        (y * 2 * (-sigmoid(-y) * (z + t) * exp(z)) * sigmoid(z))) *
       -sigmoid(x),
       sigmoid(x) *
       (-sigmoid(y) * (-sigmoid(-z) * 3) * (y * 2 * ((z + t) * exp(z)))) *
       -sigmoid(x))
    # Duplicate factors: each exp pairs with exactly one sigmoid.
    ok(exp(-x) * -exp(-x) * (-sigmoid(x) * -sigmoid(x)),
       -sigmoid(-x) * sigmoid(-x))
    # Signs carried by both the exp and the sigmoid factors.
    ok(-exp(x) * -sigmoid(-x) * -exp(-x),
       -sigmoid(-x))
class T_softplus_opts(unittest.TestCase): class T_softplus_opts(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -123,3 +213,28 @@ class T_softplus_opts(unittest.TestCase): ...@@ -123,3 +213,28 @@ class T_softplus_opts(unittest.TestCase):
assert len(topo)==1 assert len(topo)==1
assert isinstance(topo[0].op.scalar_op,theano.tensor.nnet.sigm.ScalarSoftplus) assert isinstance(topo[0].op.scalar_op,theano.tensor.nnet.sigm.ScalarSoftplus)
f(numpy.random.rand(54).astype(config.floatX)) f(numpy.random.rand(54).astype(config.floatX))
class T_sigmoid_utils(unittest.TestCase):
    """
    Test utility functions found in 'sigm.py'.
    """

    def test_parse_mul_tree(self):
        x, y, z = tensor.vectors('x', 'y', 'z')
        # Flat product of two variables.
        assert parse_mul_tree(x * y) == [False, [[False, x], [False, y]]]
        # Negation of a whole product vs. negation of a single factor.
        assert parse_mul_tree(-(x * y)) == [True, [[False, x], [False, y]]]
        assert parse_mul_tree(-x * y) == [False, [[True, x], [False, y]]]
        # A lone negated variable is a negated leaf.
        assert parse_mul_tree(-x) == [True, x]
        # Nested product with a negated factor.
        assert parse_mul_tree((x * y) * -z) == [
                False, [[False, [[False, x], [False, y]]], [True, z]]]

    def test_compute_mul(self):
        x, y, z = tensor.vectors('x', 'y', 'z')
        expr = (x * y) * -z
        expr_tree = parse_mul_tree(expr)
        # Note that we do not test the reverse identity, i.e.
        #   compute_mul(parse_mul_tree(expr)) == expr
        # because Theano currently lacks an easy way to compare variables.
        assert parse_mul_tree(compute_mul(expr_tree)) == expr_tree
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论