Commit c74da33d authored by Pascal Lamblin

Merge pull request #3151 from carriepl/local_sum_mul_by_scalar

[BUG] Fix invalid optimization for product of multiplication
......@@ -3874,8 +3874,8 @@ def local_sum_prod_mul_by_scalar(node):
# NOTE(review): this is a scraped diff of Theano's tensor/opt.py.  The hunks
# below interleave the REMOVED (pre-fix) and ADDED (post-fix) text of the
# `local_sum_prod_mul_by_scalar` local optimizer; the `def` line itself is
# only visible in the @@ hunk headers.  Markers below say which side is which.
#
# --- docstring, OLD (removed) rewrite rules.  These are invalid for Prod:
#     prod(scalar * x) == scalar ** x.size * prod(x), NOT scalar * prod(x),
#     which is exactly the bug this commit fixes. ---
or
prod(scalar * smth) -> scalar * prod(smth)
prod(-smth) -> -prod(smth)
# --- docstring, NEW (added) rewrite rules: the scalar is raised to the
#     number of elements of the reduced input. ---
prod(scalar * smth) -> scalar ** size(smth) * prod(smth)
prod(-smth) -> -1 ** size(smth) * prod(smth)
"""
# TODO: if the thing inside the Sum is a division,
# we should get at the numerator....
......@@ -3886,24 +3886,39 @@ def local_sum_prod_mul_by_scalar(node):
# Split the mul's inputs into scalar terms (every dim broadcastable; the
# empty dimshuffle() collapses them to true 0-d scalars) and non-scalar terms.
scalars = [t.dimshuffle() for t in terms if
numpy.all(t.type.broadcastable)]
non_scalars = [t for t in terms if not numpy.all(t.broadcastable)]
# --- OLD (removed) implementation: pulled the scalar factors straight out
#     of the reduction.  Correct for Sum (linearity), but WRONG for Prod —
#     it drops the `scalar ** size(smth)` exponent documented above. ---
if scalars:
if len(scalars) > 1:
if len(non_scalars) > 1:
return [T.mul(T.mul(*scalars),
node.op(T.mul(*non_scalars)))]
elif len(non_scalars) == 1:
return [T.mul(T.mul(*scalars),
node.op(non_scalars[0]))]
else:
return [T.mul(*scalars)]
else:
if len(non_scalars) > 1:
return [T.mul(scalars[0],
node.op(T.mul(*non_scalars)))]
elif len(non_scalars) == 1:
return [T.mul(scalars[0], node.op(non_scalars[0]))]
else:
return [scalars[0]]
# --- NEW (added) implementation ---
if len(scalars) == 0:
# Nothing to optimize here
return
# Perform the op only on the non-scalar inputs, if applicable
if len(non_scalars) == 0:
# Degenerate case: the mul had only scalar factors; the reduction of a
# scalar is the scalar itself, so nb_elements is the Python int 1.
new_op_input_nb_elements = 1
new_op_output = 1
elif len(non_scalars) == 1:
# Symbolic element count of the reduced input (note: this T.prod over
# the shape is itself a Prod node — see the node counts in the tests).
new_op_input_nb_elements = T.prod(non_scalars[0].shape)
new_op_output = node.op(non_scalars[0])
else:
new_op_input = T.mul(*non_scalars)
new_op_input_nb_elements = T.prod(new_op_input.shape)
new_op_output = node.op(new_op_input)
# If node.op is a T.elemwise.Prod, then the scalars need to be
# raised to the power of the number of elements in the input
# to the Prod
# NOTE(review): when nb_elements is symbolic, `!= 1` is plain object
# inequality (True), so the exponentiation is applied in every
# non-degenerate case; only the all-scalar case (int 1) skips it.
if (isinstance(node.op, T.elemwise.Prod) and
new_op_input_nb_elements != 1):
scalars = [s ** new_op_input_nb_elements for s in scalars]
# Scale the output of the op by the scalars and return as
# replacement for the original output
mul_inputs = scalars
if new_op_input_nb_elements != 1:
mul_inputs.append(new_op_output)
return [T.mul(*mul_inputs)]
# Shared tail (both versions): Sum(-x) -> -Sum(x).
if isinstance(node.op, T.Sum) and node_inps.owner and node_inps.owner.op == T.neg:
return [T.neg(node.op(node_inps.owner.inputs[0]))]
......
......@@ -4493,12 +4493,100 @@ class test_local_remove_switch_const_cond(unittest.TestCase):
# Test-file hunk: regression tests for the sum/prod optimizers in opt.py.
class T_local_sum_prod(unittest.TestCase):
"""
Test sum/prod opts in opt.py
"""
def setUp(self):
# Compile with the optimization phases under test enabled, so the
# local rewrites in opt.py actually fire in the compiled functions.
self.mode = theano.compile.get_default_mode().including('canonicalize',
'specialize')
def test_local_sum_prod_mul_by_scalar(self):
# Test the optimization local_sum_prod_mul_by_scalar for both Sum and
# Prod ops in six cases each :
# 1-the inputs to the mul contain a scalar and no non-scalar
# 2-the inputs to the mul contain a scalar and one non-scalar
# 3-the inputs to the mul contain a scalar and two non-scalars
# 4-the inputs to the mul contain two scalars and no non-scalar
# 5-the inputs to the mul contain two scalars and one non-scalar
# 6-the inputs to the mul contain two scalars and two non-scalars
vect = T.dvector()
mat = T.dmatrix()
scalar1 = T.dscalar()
scalar2 = T.dscalar()
v_val = numpy.random.rand(2)
m_val = numpy.random.rand(2, 2)
s1_val = numpy.random.rand()
s2_val = numpy.random.rand()
# Helper: compile reduction_op(mul(*inputs)), check the numerical result
# against expected_output, and check that exactly
# nb_expected_sum_nodes apply nodes of reduction_op remain in the
# optimized graph.
def test_reduction_opt(inputs, inputs_val, reduction_op,
expected_output, nb_expected_sum_nodes):
mul_out = T.mul(*inputs)
f = theano.function(inputs, reduction_op()(mul_out),
mode=self.mode)
out = f(*inputs_val)
utt.assert_allclose(out, expected_output)
# Ensure that the optimization has been applied properly by
# ensuring that the optimized graph contains the expected number
# of apply nodes for the sum op
prod_nodes = [n for n in f.maker.fgraph.toposort()
if isinstance(n.op, reduction_op)]
assert len(prod_nodes) == nb_expected_sum_nodes
# Test sum
# Case 1
test_reduction_opt([scalar1], [s1_val], T.Sum, s1_val, 0)
# Case 2
test_reduction_opt([vect, scalar1], [v_val, s1_val], T.Sum,
s1_val * v_val.sum(), 1)
# Case 3
test_reduction_opt([vect, mat, scalar1], [v_val, m_val, s1_val], T.Sum,
s1_val * (v_val * m_val).sum(), 1)
# Case 4
test_reduction_opt([scalar1, scalar2], [s1_val, s2_val], T.Sum,
s1_val * s2_val, 0)
# Case 5
test_reduction_opt([vect, scalar1, scalar2], [v_val, s1_val, s2_val],
T.Sum, s1_val * s2_val * v_val.sum(), 1)
# Case 6
test_reduction_opt([vect, mat, scalar1, scalar2],
[v_val, m_val, s1_val, s2_val], T.Sum,
s1_val * s2_val * (v_val * m_val).sum(), 1)
# Test prod
# NOTE(review): the prod cases with non-scalars expect 2 Prod nodes —
# presumably the reduction itself plus the `T.prod(input.shape)`
# element count introduced by the fixed optimizer; confirm against
# local_sum_prod_mul_by_scalar in opt.py.
# Case 1
test_reduction_opt([scalar1], [s1_val], T.elemwise.Prod, s1_val, 0)
# Case 2
test_reduction_opt([vect, scalar1], [v_val, s1_val], T.elemwise.Prod,
(s1_val * v_val).prod(), 2)
# Case 3
test_reduction_opt([vect, mat, scalar1], [v_val, m_val, s1_val],
T.elemwise.Prod, (s1_val * v_val * m_val).prod(), 2)
# Case 4
test_reduction_opt([scalar1, scalar2], [s1_val, s2_val],
T.elemwise.Prod, s1_val * s2_val, 0)
# Case 5
test_reduction_opt([vect, scalar1, scalar2], [v_val, s1_val, s2_val],
T.elemwise.Prod, (s1_val * s2_val * v_val).prod(),
2)
# Case 6
test_reduction_opt([vect, mat, scalar1, scalar2],
[v_val, m_val, s1_val, s2_val], T.elemwise.Prod,
(s1_val * s2_val * v_val * m_val).prod(), 2)
def test_local_sum_prod_all_to_none(self):
a = T.tensor3()
input = numpy.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
......@@ -5420,7 +5508,7 @@ class TestIntDivByOne(unittest.TestCase):
# NOTE(review): partial hunk — the lines below are the tail of a check that
# no integer-division Elemwise nodes remain in the optimized graph, followed
# by the (truncated) start of test2.
if isinstance(node.op, T.elemwise.Elemwise) and
isinstance(node.op.scalar_op, theano.scalar.IntDiv)]
assert len(divs) == 0
def test2(self):
"""Simple test case for removing dividing by 1"""
y = T.tensor4('y')
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment