提交 ca8d85de authored 作者: carriepl's avatar carriepl

Merge pull request #3290 from mohammadpz/prod_dimshuffle_opt

Prod dimshuffle opt
...@@ -4130,78 +4130,95 @@ def local_elemwise_sub_zeros(node): ...@@ -4130,78 +4130,95 @@ def local_elemwise_sub_zeros(node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([T.Sum]) @gof.local_optimizer([T.Sum, T.elemwise.Prod])
def local_sum_div_dimshuffle(node): def local_sum_prod_div_dimshuffle(node):
""" """
sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b, sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b,
if dimension l of the DimShuffle is 'x'. if dimension l of the DimShuffle is 'x'
or
prod(a / dimshuffle{...}(b), axis=l) ->
prod(a, axis={...}) / b ** a.shape[l],
if dimension l of the DimShuffle is 'x'
""" """
# TODO: extend it to product, and quotient of products
# It does not make much sense now to extend it to the case where the # It does not make much sense now to extend it to the case where the
# dimshuffle is in the numerator, since elemwise inversion of the # dimshuffle is in the numerator, since elemwise inversion of the
# denominator would still be needed before the summation. # denominator would still be needed before the summation or product.
if isinstance(node.op, T.Sum): if isinstance(node.op, (T.Sum, T.elemwise.Prod)):
axis = node.op.axis axis = node.op.axis
if axis is None: if axis is None:
axis = list(range(node.inputs[0].ndim)) axis = list(range(node.inputs[0].ndim))
# print 'axis =', axis node_input = node.inputs[0]
thing_summed = node.inputs[0] if node_input.owner and node_input.owner.op == T.true_div:
if thing_summed.owner and thing_summed.owner.op == T.true_div: numerator, denominator = node_input.owner.inputs
numerator, denominator = thing_summed.owner.inputs
# Old, bugged logic, reproduced here only to warn users # Old, bugged logic, reproduced here only to warn users
if config.warn.sum_div_dimshuffle_bug: if (config.warn.sum_div_dimshuffle_bug and
if numerator.owner and isinstance(numerator.owner.op, isinstance(node.op, T.Sum) and
T.DimShuffle): numerator.owner and
new_order = numerator.owner.op.new_order isinstance(numerator.owner.op, T.DimShuffle)):
compatible_dims = True # Check compatibility
for ax in axis: new_order = numerator.owner.op.new_order
if len(new_order) <= ax or new_order[ax] != 'x':
compatible_dims = False
break
if compatible_dims:
_logger.warn('WARNING: Your current code is fine, but'
' Theano versions between '
'rev. 3bd9b789f5e8 (2010-06-16) and'
' cfc6322e5ad4 (2010-08-03) would '
'have given an incorrect result. '
'To disable this warning, set the Theano'
' flag warn.sum_div_dimshuffle_bug to'
' False.')
if denominator.owner and isinstance(denominator.owner.op,
T.DimShuffle):
thing_dimshuffled = denominator.owner.inputs[0]
new_order = denominator.owner.op.new_order
# print 'new_order =', new_order
# check compatibility
compatible_dims = True compatible_dims = True
for ax in axis: for ax in axis:
# print 'ax =', ax
# print 'len(new_order) =', len(new_order)
# print 'new_order[ax] =', new_order[ax]
if len(new_order) <= ax or new_order[ax] != 'x': if len(new_order) <= ax or new_order[ax] != 'x':
compatible_dims = False compatible_dims = False
break break
if compatible_dims: if compatible_dims:
# print 'getting denom out' _logger.warn('WARNING: Your current code is fine, but'
# Keep needed dimensions for new dimshuffle ' Theano versions between '
new_new_order = list(ax for i, ax in enumerate(new_order) 'rev. 3bd9b789f5e8 (2010-06-16) and'
if i not in axis or ax != 'x') ' cfc6322e5ad4 (2010-08-03) would '
# print 'new_new_order =', new_new_order 'have given an incorrect result. '
# Remove useless rebroadcast axes 'To disable this warning, set the Theano'
while len(new_new_order) > 0 and new_new_order[0] == 'x': ' flag warn.sum_div_dimshuffle_bug to'
del new_new_order[0] ' False.')
# print 'new_new_order =', new_new_order
if denominator.owner and isinstance(denominator.owner.op,
if all(i == e for i, e in enumerate(new_new_order)): T.DimShuffle):
new_denom = thing_dimshuffled dimshuffle_input = denominator.owner.inputs[0]
dimshuffle_order = denominator.owner.op.new_order
compatible_dims = []
incompatible_dims = []
for ax in axis:
if (ax < len(dimshuffle_order) and
dimshuffle_order[ax] == 'x'):
compatible_dims.append(ax)
else: else:
if config.warn.sum_div_dimshuffle_bug: incompatible_dims.append(ax)
reordered_incompatible_dims = []
for ic_ax in incompatible_dims:
reordered_incompatible_dims.append(
ic_ax - sum(
[1 for c_ax in compatible_dims if c_ax < ic_ax]))
if len(compatible_dims) > 0:
optimized_dimshuffle_order = list(
ax for i, ax in enumerate(dimshuffle_order)
if (i not in axis) or (ax != 'x'))
# Removing leading 'x' (since it will be done automatically)
while (len(optimized_dimshuffle_order) > 0 and
optimized_dimshuffle_order[0] == 'x'):
del optimized_dimshuffle_order[0]
# If optimized_dimshuffle_order is the identity order (0, 1, ... with
# no 'x'), the dimshuffle is a no-op and its input is used directly.
if all(i == e for i, e in
enumerate(optimized_dimshuffle_order)):
optimized_dimshuffle = dimshuffle_input
else:
optimized_dimshuffle = T.DimShuffle(
dimshuffle_input.type.broadcastable,
optimized_dimshuffle_order)(dimshuffle_input)
if (config.warn.sum_div_dimshuffle_bug and
isinstance(node.op, T.Sum)):
_logger.warn('WARNING: Your current code is fine,' _logger.warn('WARNING: Your current code is fine,'
' but Theano versions between ' ' but Theano versions between '
'rev. 3bd9b789f5e8 (2010-06-16) and' 'rev. 3bd9b789f5e8 (2010-06-16) and'
...@@ -4212,12 +4229,28 @@ def local_sum_div_dimshuffle(node): ...@@ -4212,12 +4229,28 @@ def local_sum_div_dimshuffle(node):
'warn.sum_div_dimshuffle_bug' 'warn.sum_div_dimshuffle_bug'
' to False.') ' to False.')
new_denom = T.DimShuffle( if isinstance(node.op, T.Sum):
thing_dimshuffled.type.broadcastable, op_on_compatible_dims = T.sum(
new_new_order)(thing_dimshuffled) numerator, axis=compatible_dims)
return [T.true_div(node.op(numerator), new_denom)] div_op = T.true_div(
# else: op_on_compatible_dims,
# print 'incompatible dims:', axis, new_order optimized_dimshuffle)
op_on_incompatible_dims = T.sum(
div_op,
axis=reordered_incompatible_dims)
elif isinstance(node.op, T.elemwise.Prod):
op_on_compatible_dims = T.prod(
numerator, axis=compatible_dims)
dtype = numerator.dtype
div_op = T.true_div(
op_on_compatible_dims,
(optimized_dimshuffle **
T.prod([numerator.shape[ax].astype(dtype)
for ax in compatible_dims])))
op_on_incompatible_dims = T.prod(
div_op,
axis=reordered_incompatible_dims)
return [op_on_incompatible_dims]
@register_canonicalize @register_canonicalize
......
...@@ -59,6 +59,7 @@ from theano.tensor import ( ...@@ -59,6 +59,7 @@ from theano.tensor import (
from theano.tensor.elemwise import DimShuffle from theano.tensor.elemwise import DimShuffle
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.compile.mode import optdb from theano.compile.mode import optdb
from theano.compile import Mode
mode_opt = theano.config.mode mode_opt = theano.config.mode
if mode_opt == 'FAST_COMPILE': if mode_opt == 'FAST_COMPILE':
...@@ -4919,7 +4920,7 @@ class T_local_reduce(unittest.TestCase): ...@@ -4919,7 +4920,7 @@ class T_local_reduce(unittest.TestCase):
theano.config.warn.reduce_join = old theano.config.warn.reduce_join = old
class T_local_sum_dimshuffle(unittest.TestCase): class T_local_sum_prod_dimshuffle(unittest.TestCase):
def setUp(self): def setUp(self):
self.mode = theano.compile.get_default_mode().including('canonicalize') self.mode = theano.compile.get_default_mode().including('canonicalize')
...@@ -4987,6 +4988,104 @@ class T_local_sum_dimshuffle(unittest.TestCase): ...@@ -4987,6 +4988,104 @@ class T_local_sum_dimshuffle(unittest.TestCase):
config.warn.sum_sum_bug, config.warn.sum_div_dimshuffle_bug =\ config.warn.sum_sum_bug, config.warn.sum_div_dimshuffle_bug =\
backup backup
def test_local_prod_div_dimshuffle(self):
    """Exercise ``local_sum_prod_div_dimshuffle`` on ``prod(x / y)`` graphs.

    Two kinds of checks are performed:

    * Numerical: every expression is compiled once with the optimization
      excluded and once with it included, and both compiled functions
      must return (nearly) equal values on the same random inputs.
    * Structural: for a smaller set of expressions, the scalar op of the
      last node in the optimized graph's toposort must be an instance of
      the expected class.
    """
    a = T.matrix('a')
    b = T.vector('b')
    c = T.tensor3('c')
    e = T.matrix('e')
    d = T.scalar('d')

    prod = T.prod
    prods = [
        prod(a / d),
        prod(a / d.dimshuffle('x', 'x')),
        prod(a / d.dimshuffle('x', 'x'), axis=0),
        prod(a / d.dimshuffle('x', 'x'), axis=1),
        prod(b / d),
        prod(b / d.dimshuffle('x')),
        prod(c / d),
        prod(c / d.dimshuffle('x', 'x', 'x')),
        prod(c / d.dimshuffle('x', 'x', 'x'), axis=0),
        prod(c / d.dimshuffle('x', 'x', 'x'), axis=1),
        prod(c / d.dimshuffle('x', 'x', 'x'), axis=2),

        prod(a / b, axis=0),
        prod(a / b.dimshuffle(0, 'x'), axis=1),
        prod(a.dimshuffle(0, 1) / b.dimshuffle(0, 'x'), axis=1),
        prod(a.dimshuffle(1, 0) / b.dimshuffle(0, 'x'), axis=1),

        prod(c / a, axis=0),
        prod(c / a.dimshuffle(1, 0), axis=0),
        prod(c / a.dimshuffle(0, 'x', 1), axis=1),
        prod(c / a.dimshuffle(1, 'x', 0), axis=1),
        prod(c / a.dimshuffle(0, 1, 'x'), axis=2),
        prod(c / a.dimshuffle(1, 0, 'x'), axis=2),

        prod(c / b, axis=0),
        prod(c / b, axis=1),
        prod(c / b, axis=(0, 1)),
        prod(c / b.dimshuffle(0, 'x'), axis=0),
        prod(c / b.dimshuffle(0, 'x'), axis=2),
        prod(c / b.dimshuffle(0, 'x'), axis=(0, 2)),
        prod(c / b.dimshuffle(0, 'x', 'x'), axis=1),
        prod(c / b.dimshuffle(0, 'x', 'x'), axis=2),
        prod(c / b.dimshuffle(0, 'x', 'x'), axis=(1, 2)),
        prod(c / b.dimshuffle(0, 'x', 'x'), axis=(0, 1)),
        prod(c / b.dimshuffle(0, 'x', 'x'), axis=(1, 0)),
        prod(prod(c, axis=0) / b, axis=0),
        prod(prod(c, axis=1) / b, axis=0)]

    rng = numpy.random.RandomState(utt.fetch_seed())
    a_val = rng.randn(2, 2).astype(config.floatX)
    b_val = rng.randn(2).astype(config.floatX)
    c_val = rng.randn(2, 2, 2).astype(config.floatX)
    d_val = numpy.asarray(rng.randn(), config.floatX)

    default_mode = theano.compile.mode.get_default_mode()
    # FusionOptimizer is included to make sure that expected_outer_operator
    # remains the same for all optimization modes.
    mode_with_opt = default_mode.including('local_sum_prod_div_dimshuffle',
                                           'FusionOptimizer')
    mode_without_opt = default_mode.excluding('local_sum_prod_div_dimshuffle')

    # Numerical tests: tests whether the numerical values with and without
    # optimizer are equal or not.
    for s in prods:
        f = theano.function([a, b, c, d], s,
                            on_unused_input='ignore',
                            mode=mode_without_opt)
        g = theano.function([a, b, c, d], s,
                            on_unused_input='ignore',
                            mode=mode_with_opt)

        utt.assert_allclose(f(a_val, b_val, c_val, d_val),
                            g(a_val, b_val, c_val, d_val))

    # Logical tests: tests whether the optimizer has been applied or not
    # by checking graph structure.
    prods = [
        prod(a / e),
        prod(a / d),
        prod(a / d.dimshuffle('x', 'x')),
        prod(c / d.dimshuffle('x', 'x', 'x'), axis=1),
        prod(a.dimshuffle(1, 0) / b.dimshuffle(0, 'x'), axis=1),
        prod(c / b.dimshuffle(0, 'x', 'x'), axis=(1, 0)),
        prod(prod(c, axis=1) / b, axis=0),
        prod(prod(c, axis=(1, 2)) / b, axis=0)]

    expected_outer_operator = [theano.scalar.basic.Mul,
                               theano.scalar.basic.Composite,
                               theano.scalar.basic.Composite,
                               theano.scalar.basic.TrueDiv,
                               theano.scalar.basic.Composite,
                               theano.scalar.basic.Mul,
                               theano.scalar.basic.Composite,
                               theano.scalar.basic.Mul]

    # zip() silently truncates to the shorter sequence, so make sure the
    # two parallel lists stay in sync when cases are added or removed.
    assert len(prods) == len(expected_outer_operator)

    for s, expected_op in zip(prods, expected_outer_operator):
        g = theano.function([a, b, c, d, e], s,
                            on_unused_input='ignore',
                            mode=mode_with_opt)
        outer_scalar_op = g.maker.fgraph.toposort()[-1].op.scalar_op
        # Report both the actual and the expected op so a failure
        # identifies the offending case.
        assert isinstance(outer_scalar_op, expected_op), (
            outer_scalar_op, expected_op)
# TODO: # TODO:
# test_local_sum_prod_dimshuffle (a * b * c) # test_local_sum_prod_dimshuffle (a * b * c)
# test_local_sum_divprod_dimshuffle ((a * b) / (c * d)) # test_local_sum_divprod_dimshuffle ((a * b) / (c * d))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论