提交 f40a8a25 authored 作者: nouiz's avatar nouiz

Merge pull request #233 from delallea/sigm_opt_fix

Fixed optimization for exp(x) * sigmoid(-x)
......@@ -2,6 +2,9 @@
These functions implement special cases of exp and log to improve numerical stability.
"""
from itertools import imap
import numpy
from theano import gof
......@@ -115,6 +118,7 @@ logsigm_to_softplus = gof.PatternSub(
allow_multiple_clients = True,
skip_identities_fn=_skip_mul_1)
def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1
"""
......@@ -156,15 +160,40 @@ def is_1pexp(t):
return False, maybe_exp.owner.inputs[0]
return None
def is_exp(var):
    """
    Match a variable with either of the `exp(x)` or `-exp(x)` patterns.

    :param var: The Variable to analyze.

    :return: A pair (b, x) with `b` a boolean set to True if `var` is of the
    form `-exp(x)` and False if `var` is of the form `exp(x)`. If `var` cannot
    be cast into either form, then return `None`.
    """
    neg = False
    # Peel off one level of negation, as recognized by `is_neg`, before
    # looking for the exponential itself.
    neg_info = is_neg(var)
    if neg_info is not None:
        neg = True
        var = neg_info
    if var.owner and var.owner.op == tensor.exp:
        return neg, var.owner.inputs[0]
    # Neither `exp(x)` nor `-exp(x)`.
    return None
def is_mul(var):
    """
    Tell whether a variable is the output of a multiplication.

    :param var: The Variable to analyze.

    :return: The inputs [x, y, z, ...] of the multiplication when `var` is of
    the form `x * y * z * ...`, or None when `var` is not produced by
    `tensor.mul`.
    """
    node = var.owner
    if node and node.op == tensor.mul:
        return node.inputs
    return None
def partition_num_or_denom(r, f):
if r.owner and r.owner.op == tensor.mul:
......@@ -187,6 +216,41 @@ def partition_num_or_denom(r, f):
return f_terms, rest, neg
def is_neg(var):
    """
    Recognize a variable of the form `-x`.

    Two spellings are accepted: an explicit `tensor.neg`, and a `tensor.mul`
    in which one of the inputs is a constant close to -1.

    :param var: The Variable to analyze.

    :return: `x` if `var` is of the form `-x`, or None otherwise. When the
    match is a multiplication with more than two inputs, `x` is a new
    multiplication of all the inputs except the -1 factor.
    """
    node = var.owner
    if not node:
        return None
    # Simplest case: an explicit negation.
    if node.op == tensor.neg:
        return node.inputs[0]
    # The only other accepted form is a product holding a -1 factor.
    if node.op != tensor.mul or len(node.inputs) < 2:
        return None
    for pos, factor in enumerate(node.inputs):
        try:
            value = opt.get_constant_value(factor)
            found_minus_1 = numpy.allclose(value, -1)
        except TypeError:
            # A TypeError raised while checking the factor means "not -1".
            found_minus_1 = False
        if not found_minus_1:
            continue
        if len(node.inputs) == 2:
            # Binary product: the answer is simply the other operand.
            return node.inputs[1 - pos]
        # Larger product: multiply together everything but the -1 factor.
        return tensor.mul(*(node.inputs[0:pos] + node.inputs[pos + 1:]))
    # No factor matched.
    return None
@opt.register_stabilize
@gof.local_optimizer([tensor.true_div])
def local_exp_over_1_plus_exp(node):
......@@ -231,71 +295,277 @@ def local_exp_over_1_plus_exp(node):
else:
return [new_num / tensor.mul(*denom_rest)]
@opt.register_stabilize
@gof.local_optimizer([tensor.mul])
def local_sigm_times_exp(node):
def parse_mul_tree(root):
    """
    Parse a tree of multiplications starting at the given root.

    :param root: The variable at the root of the tree.

    :return: A tree where each non-leaf node corresponds to a multiplication
    in the computation of `root`, represented by the list of its inputs. Each
    input is a pair [n, x] with `n` a boolean value indicating whether
    sub-tree `x` should be negated.

    Examples:
        x * y               -> [False, [[False, x], [False, y]]]
        -(x * y)            -> [True, [[False, x], [False, y]]]
        -x * y              -> [False, [[True, x], [False, y]]]
        -x                  -> [True, x]
        (x * y) * -z        -> [False, [[False, [[False, x], [False, y]]],
                                        [True, z]]]
    """
    # Is it a multiplication?
    mul_info = is_mul(root)
    if mul_info is not None:
        # Recurse into inputs. A real list is built (rather than relying on
        # `map`) because the other tree helpers identify multiplication nodes
        # with `isinstance(inputs, list)`.
        return [False, [parse_mul_tree(mul_input) for mul_input in mul_info]]
    # Is it a negation?
    neg_info = is_neg(root)
    if neg_info is None:
        # Keep the root "as is".
        return [False, root]
    # Recurse, inverting the negation.
    neg, sub_tree = parse_mul_tree(neg_info)
    return [not neg, sub_tree]
def replace_leaf(arg, leaves, new_leaves, op, neg):
    """
    Try to swap one registered leaf of a multiplication tree for `op(arg)`.

    The first entry of `leaves` whose argument equals `arg` is taken out of
    `leaves`, its leaf is rewritten in place (so the tree holding that leaf
    sees the change), and the entry is appended to `new_leaves`.

    :param arg: The argument of the leaf we are looking for.

    :param leaves: List of candidate entries. Each entry is a pair (x, l)
    with `x` the argument of the Op found in the leaf, and `l` the leaf itself
    as found in a tree output by `parse_mul_tree` (i.e. a mutable pair
    [boolean, variable]).

    :param new_leaves: List receiving the matched entry once its leaf has
    been rewritten.

    :param op: A function that, applied to `arg`, returns the Variable that
    must replace the variable of the matched leaf.

    :param neg: If True, the boolean of the matched leaf is flipped. If
    False, it is left unchanged.

    :return: True if a replacement occurred, or False otherwise.
    """
    for position, candidate in enumerate(leaves):
        if not (candidate[0] == arg):
            continue
        leaf = candidate[1]
        # Update the sign flag first, then the variable, directly on the
        # leaf object shared with the tree.
        leaf[0] ^= neg
        leaf[1] = op(arg)
        leaves.pop(position)
        new_leaves.append(candidate)
        return True
    return False
def simplify_mul(tree):
    """
    Simplify a multiplication tree.

    :param tree: A multiplication tree (as output by `parse_mul_tree`).

    :return: A multiplication tree computing the same output as `tree` but
    without useless multiplications by 1 nor -1 (identified by leaves of the
    form [False, None] or [True, None] respectively). Useless multiplications
    (with less than two inputs) are also removed from the tree.

    Note that the sign flag of a node of the input tree may be updated in
    place when that node ends up being the only remaining input of its parent.
    """
    neg, inputs = tree
    if not isinstance(inputs, list):
        # A leaf is already as simple as it can be.
        return tree
    # Recurse through inputs, folding the sign of every +/-1 factor into
    # `neg` and keeping only the meaningful inputs.
    s_inputs = []
    for sub_tree in inputs:
        s_i = simplify_mul(sub_tree)
        if s_i[1] is None:
            # Multiplication by +/-1.
            neg ^= s_i[0]
        else:
            s_inputs.append(s_i)
    if not s_inputs:
        # The multiplication is empty: it reduces to +/-1.
        return [neg, None]
    if len(s_inputs) == 1:
        # The multiplication has a single input: push the sign down onto it
        # and return it directly.
        s_inputs[0][0] ^= neg
        return s_inputs[0]
    return [neg, s_inputs]
def compute_mul(tree):
    """
    Build the Variable represented by a multiplication tree.

    This goes in the opposite direction of `parse_mul_tree`: it turns a tree
    back into a Variable.

    :param tree: A multiplication tree (as output by `parse_mul_tree`).

    :return: A Variable that computes the multiplication represented by the
    tree.

    :raise AssertionError: If a leaf of the tree is None, i.e. the tree still
    holds a placeholder that `simplify_mul` would have removed.
    """
    neg, inputs = tree
    if inputs is None:
        raise AssertionError(
            'Function `compute_mul` found a missing leaf, did you forget to '
            'call `simplify_mul` on the tree first?')
    if isinstance(inputs, list):
        # Inner node: multiply the Variables computed for each input.
        result = tensor.mul(*[compute_mul(sub_tree) for sub_tree in inputs])
    else:
        # Leaf: the Variable is stored directly.
        result = inputs
    return -result if neg else result
def perform_sigm_times_exp(tree, exp_x=None, exp_minus_x=None, sigm_x=None,
                           sigm_minus_x=None, parent=None, child_idx=None,
                           full_tree=None):
    """
    Core processing of the `local_sigm_times_exp` optimization.

    This recursive function operates on a multiplication tree as output by
    `parse_mul_tree`. It walks through the tree and modifies it in-place
    by replacing matching pairs (exp, sigmoid) with the desired optimized
    version.

    :param tree: The sub-tree to operate on.

    :param exp_x: List of arguments x so that `exp(x)` exists somewhere in the
    whole multiplication tree. Each argument is a pair (x, leaf) with `x` the
    argument of the exponential, and `leaf` the corresponding leaf in the
    multiplication tree (of the form [n, exp(x)] -- see `parse_mul_tree`).
    If None, this argument is initialized to an empty list.

    :param exp_minus_x: Similar to `exp_x`, but for `exp(-x)`.

    :param sigm_x: Similar to `exp_x`, but for `sigmoid(x)`.

    :param sigm_minus_x: Similar to `exp_x`, but for `sigmoid(-x)`.

    :param parent: Parent of `tree` (None if `tree` is the global root).

    :param child_idx: Index of `tree` in its parent's inputs (None if `tree` is
    the global root).

    :param full_tree: The global multiplication tree (should not be set except
    by recursive calls to this function). Used for debugging only.

    :return: True if a modification was performed somewhere in the whole
    multiplication tree, or False otherwise.
    """
    # The four registries are shared by the whole recursion: they are created
    # once at the root and then passed down explicitly.
    if exp_x is None:
        exp_x = []
    if exp_minus_x is None:
        exp_minus_x = []
    if sigm_x is None:
        sigm_x = []
    if sigm_minus_x is None:
        sigm_minus_x = []
    if full_tree is None:
        full_tree = tree
    neg, inputs = tree
    if isinstance(inputs, list):
        # Recurse through inputs of the multiplication.
        rval = False
        for sub_idx, sub_tree in enumerate(inputs):
            rval |= perform_sigm_times_exp(
                tree=sub_tree, parent=tree, child_idx=sub_idx,
                exp_x=exp_x, exp_minus_x=exp_minus_x, sigm_x=sigm_x,
                sigm_minus_x=sigm_minus_x, full_tree=full_tree)
        return rval
    # Reached a leaf: if it is an exponential or a sigmoid, then we
    # first attempt to find a match in leaves already visited.
    # If there is such a match, we modify the already-visited leaf
    # accordingly: for instance if we visited a leaf sigmoid(x), then
    # find later a -exp(-x), we replace the previous leaf by
    # -sigmoid(-x) and remove the -exp(-x) from the tree.
    # If no match is found, then we register this leaf so that it can
    # be found later while walking the tree.
    var = inputs
    keep_it = False
    exp_info = is_exp(var)
    if exp_info is not None:
        exp_neg, exp_arg = exp_info
        # Total sign carried by this leaf: tree flag combined with the sign
        # found by `is_exp`.
        neg ^= exp_neg
        neg_arg = is_neg(exp_arg)
        if neg_arg is None:
            # exp(x): pairs with a previously seen sigmoid(-x).
            if not replace_leaf(exp_arg, sigm_minus_x, sigm_x,
                                sigmoid, neg):
                exp_x.append((exp_arg, tree))
                keep_it = True
        else:
            # exp(-x): pairs with a previously seen sigmoid(x).
            if not replace_leaf(neg_arg, sigm_x, sigm_minus_x,
                                lambda x: sigmoid(-x), neg):
                exp_minus_x.append((neg_arg, tree))
                keep_it = True
    elif var.owner and var.owner.op == sigmoid:
        sigm_arg = var.owner.inputs[0]
        neg_arg = is_neg(sigm_arg)
        if neg_arg is None:
            # sigmoid(x): pairs with a previously seen exp(-x).
            if not replace_leaf(sigm_arg, exp_minus_x, sigm_minus_x,
                                lambda x: sigmoid(-x), neg):
                sigm_x.append((sigm_arg, tree))
                keep_it = True
        else:
            # sigmoid(-x): pairs with a previously seen exp(x).
            if not replace_leaf(neg_arg, exp_x, sigm_x, sigmoid, neg):
                sigm_minus_x.append((neg_arg, tree))
                keep_it = True
    else:
        # It is not an exponential nor a sigmoid.
        keep_it = True
    if not keep_it:
        # Delete this leaf, i.e. replace it by [False, None] (corresponding
        # to a multiplication by 1).
        assert parent is not None
        parent[1][child_idx] = [False, None]
    return not keep_it
@opt.register_stabilize
@gof.local_optimizer([tensor.mul])
def local_sigm_times_exp(node):
    """
    exp(x) * sigm(-x) -> sigm(x)
    exp(-x) * sigm(x) -> sigm(-x)

    :param node: The node to optimize. Only `tensor.mul` nodes are handled.

    :return: A one-element list holding the rewritten output, or None when
    the node is not a multiplication or nothing could be rewritten.
    """
    # Only multiplications are of interest.
    if node.op != tensor.mul:
        return None
    # View the output as a tree of multiplications, then let the core routine
    # rewrite that tree in place.
    tree = parse_mul_tree(node.outputs[0])
    if perform_sigm_times_exp(tree):
        # Matched leaves were replaced by multiplications by 1: drop them
        # before rebuilding the output Variable.
        return [compute_mul(simplify_mul(tree))]
    # No change.
    return None
@opt.register_stabilize
......@@ -391,4 +661,3 @@ if 0:
return []
else:
return [node.outputs[0]]
......@@ -7,7 +7,9 @@ from theano import tensor as T
from theano import config
from theano.tests import unittest_tools as utt
from theano.tensor.nnet import sigmoid, sigmoid_inplace, softplus, tensor
from theano.tensor.nnet.sigm import register_local_1msigmoid
from theano.tensor.nnet.sigm import (
compute_mul, parse_mul_tree, perform_sigm_times_exp,
register_local_1msigmoid, simplify_mul)
class T_sigmoid(unittest.TestCase):
......@@ -23,12 +25,29 @@ class T_softplus(unittest.TestCase):
utt.verify_grad(softplus, [numpy.random.rand(3,4)])
class T_sigmoid_opts(unittest.TestCase):
def test_exp_over_1_plus_exp(self):
def get_mode(self, excluding=None):
    """
    Return appropriate mode for the tests.

    :param excluding: List of optimizations to exclude. None or an empty
    list means that nothing is excluded.

    :return: The current default mode unless the `config.mode` option is
    set to 'FAST_COMPILE' (in which case it is replaced by the 'FAST_RUN'
    mode), without the optimizations specified in `excluding`.
    """
    if theano.config.mode == 'FAST_COMPILE':
        mode = theano.compile.mode.get_mode('FAST_RUN')
    else:
        mode = theano.compile.mode.get_default_mode()
    if excluding:
        return mode.excluding(*excluding)
    return mode
def test_exp_over_1_plus_exp(self):
m = self.get_mode(excluding=['local_elemwise_fusion'])
x = T.dvector()
......@@ -60,10 +79,7 @@ class T_sigmoid_opts(unittest.TestCase):
if not register_local_1msigmoid:
return
m = theano.config.mode
if m == 'FAST_COMPILE':
m = 'FAST_RUN'
m = self.get_mode()
x = T.fmatrix()
# tests exp_over_1_plus_exp
......@@ -77,6 +93,80 @@ class T_sigmoid_opts(unittest.TestCase):
assert [node.op for node in f.maker.env.toposort()] == [tensor.neg,
sigmoid_inplace]
def test_local_sigm_times_exp(self):
    """
    Test the `local_sigm_times_exp` optimization.
    exp(x) * sigm(-x) -> sigm(x)
    exp(-x) * sigm(x) -> sigm(-x)
    """
    mode = self.get_mode(excluding=['local_elemwise_fusion', 'inplace'])
    x, y = tensor.vectors('x', 'y')

    def check(inputs, expr, expected_ops):
        # Compile `expr`, show the resulting graph, and compare the ops
        # found in topological order with the expected ones.
        func = theano.function(inputs, expr, mode=mode)
        theano.printing.debugprint(func)
        found_ops = [apply_node.op
                     for apply_node in func.maker.env.toposort()]
        assert found_ops == expected_ops

    check([x], sigmoid(-x) * tensor.exp(x), [sigmoid])

    check([x], sigmoid(x) * tensor.exp(-x), [tensor.neg, sigmoid])

    check([x], -(-(-(sigmoid(x)))) * tensor.exp(-x),
          [tensor.neg, sigmoid, tensor.neg])

    check([x, y],
          (sigmoid(x) * sigmoid(-y) * -tensor.exp(-x) * tensor.exp(x * y) *
           tensor.exp(y)),
          [sigmoid, tensor.mul, tensor.neg, tensor.exp, sigmoid,
           tensor.mul, tensor.neg])
def test_perform_sigm_times_exp(self):
    """
    Test the core function doing the `sigm_times_exp` optimization.

    It is easier to test different graph scenarios this way than by
    compiling a theano function.
    """
    x, y, z, t = tensor.vectors('x', 'y', 'z', 't')
    exp = tensor.exp

    def ok(expr1, expr2):
        # Optimize the tree of `expr1` and check that it ends up identical
        # to the (untouched) tree of `expr2`.
        trees = [parse_mul_tree(e) for e in (expr1, expr2)]
        perform_sigm_times_exp(trees[0])
        trees[0] = simplify_mul(trees[0])
        # TODO Ideally we would do a full comparison without `str`. However
        # the implementation of `__eq__` in Variables is not currently
        # appropriate for this. So for now we use this limited technique,
        # but it could be improved on.
        good = str(trees[0]) == str(trees[1])
        if not good:
            # Single-argument calls: same output whether `print` is a
            # statement or a function.
            print(trees[0])
            print(trees[1])
            print('***')
            theano.printing.debugprint(compute_mul(trees[0]))
            print('***')
            theano.printing.debugprint(compute_mul(trees[1]))
        assert good

    ok(sigmoid(x) * exp(-x), sigmoid(-x))
    ok(-x * sigmoid(x) * (y * (-1 * z) * exp(-x)),
       -x * sigmoid(-x) * (y * (-1 * z)))
    ok(-sigmoid(-x) *
       (exp(y) * (-exp(-z) * 3 * -exp(x)) *
        (y * 2 * (-sigmoid(-y) * (z + t) * exp(z)) * sigmoid(z))) *
       -sigmoid(x),
       sigmoid(x) *
       (-sigmoid(y) * (-sigmoid(-z) * 3) * (y * 2 * ((z + t) * exp(z)))) *
       -sigmoid(x))
    ok(exp(-x) * -exp(-x) * (-sigmoid(x) * -sigmoid(x)),
       -sigmoid(-x) * sigmoid(-x))
    ok(-exp(x) * -sigmoid(-x) * -exp(-x),
       -sigmoid(-x))
class T_softplus_opts(unittest.TestCase):
def setUp(self):
......@@ -123,3 +213,28 @@ class T_softplus_opts(unittest.TestCase):
assert len(topo)==1
assert isinstance(topo[0].op.scalar_op,theano.tensor.nnet.sigm.ScalarSoftplus)
f(numpy.random.rand(54).astype(config.floatX))
class T_sigmoid_utils(unittest.TestCase):
    """
    Test utility functions found in 'sigm.py'.
    """

    def test_compute_mul(self):
        x, y, z = tensor.vectors('x', 'y', 'z')
        expected_tree = parse_mul_tree((x * y) * -z)
        # Only the round trip tree -> Variable -> tree is verified. The
        # other direction, i.e.
        #   compute_mul(parse_mul_tree(tree)) == tree
        # is not tested because Theano currently lacks an easy way to
        # compare variables.
        rebuilt_tree = parse_mul_tree(compute_mul(expected_tree))
        assert rebuilt_tree == expected_tree

    def test_parse_mul_tree(self):
        x, y, z = tensor.vectors('x', 'y', 'z')
        # Pairs (expression, tree expected from `parse_mul_tree`).
        cases = [
            (x * y, [False, [[False, x], [False, y]]]),
            (-(x * y), [True, [[False, x], [False, y]]]),
            (-x * y, [False, [[True, x], [False, y]]]),
            (-x, [True, x]),
            ((x * y) * -z,
             [False, [[False, [[False, x], [False, y]]], [True, z]]]),
        ]
        for expr, expected_tree in cases:
            assert parse_mul_tree(expr) == expected_tree
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论