提交 c9d69119 authored 作者: abergeron's avatar abergeron

Merge pull request #2852 from kelvinxu/prod_opts

Prod opts [WIP]
...@@ -3849,17 +3849,22 @@ register_canonicalize(local_neg_to_mul) ...@@ -3849,17 +3849,22 @@ register_canonicalize(local_neg_to_mul)
@register_specialize @register_specialize
@gof.local_optimizer([T.Sum]) @gof.local_optimizer([T.Sum, T.elemwise.Prod])
def local_sum_mul_by_scalar(node): def local_sum_prod_mul_by_scalar(node):
"""sum(scalar * smth) -> scalar * sum(smth) """sum(scalar * smth) -> scalar * sum(smth)
sum(-smth) -> -sum(smth) sum(-smth) -> -sum(smth)
or
prod(scalar * smth) -> scalar * prod(smth)
prod(-smth) -> -prod(smth)
""" """
# TODO: if the the thing inside the Sum is a division, # TODO: if the the thing inside the Sum is a division,
# we should get at the numerator.... # we should get at the numerator....
if isinstance(node.op, T.Sum): if isinstance(node.op, T.Sum) or isinstance(node.op, T.elemwise.Prod):
thing_summed, = node.inputs node_inps, = node.inputs
if thing_summed.owner and thing_summed.owner.op == T.mul: if node_inps.owner and node_inps.owner.op == T.mul:
terms = thing_summed.owner.inputs terms = node_inps.owner.inputs
scalars = [t.dimshuffle() for t in terms if scalars = [t.dimshuffle() for t in terms if
numpy.all(t.type.broadcastable)] numpy.all(t.type.broadcastable)]
non_scalars = [t for t in terms if not numpy.all(t.broadcastable)] non_scalars = [t for t in terms if not numpy.all(t.broadcastable)]
...@@ -3881,8 +3886,8 @@ def local_sum_mul_by_scalar(node): ...@@ -3881,8 +3886,8 @@ def local_sum_mul_by_scalar(node):
return [T.mul(scalars[0], node.op(non_scalars[0]))] return [T.mul(scalars[0], node.op(non_scalars[0]))]
else: else:
return [scalars[0]] return [scalars[0]]
if thing_summed.owner and thing_summed.owner.op == T.neg: if isinstance(node.op, T.Sum) and node_inps.owner and node_inps.owner.op == T.neg:
return [T.neg(node.op(thing_summed.owner.inputs[0]))] return [T.neg(node.op(node_inps.owner.inputs[0]))]
@register_specialize @register_specialize
...@@ -3989,64 +3994,68 @@ def local_sum_div_dimshuffle(node): ...@@ -3989,64 +3994,68 @@ def local_sum_div_dimshuffle(node):
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T.Sum]) @gof.local_optimizer([T.Sum, T.elemwise.Prod])
def local_sum_all_to_none(node): def local_sum_prod_all_to_none(node):
"""Sum{0,1,...N} -> Sum{}""" """Sum{0,1,...N} -> Sum{} or
if isinstance(node.op, T.Sum): Prod{0,1,...N} -> Prod{}
"""
if isinstance(node.op, T.Sum) or isinstance(node.op, T.elemwise.Prod):
opt_type = T.Sum if isinstance(node.op, T.Sum) else T.elemwise.Prod
# if all the axes are named, then use None as a shorthand # if all the axes are named, then use None as a shorthand
# this permits more merging # this permits more merging
if node.op.axis is None: if node.op.axis is None:
return return
if set(node.op.axis) == set(range(node.inputs[0].type.ndim)): if set(node.op.axis) == set(range(node.inputs[0].type.ndim)):
return [T.Sum(axis=None, dtype=node.op.dtype)(node.inputs[0])] return [opt_type(axis=None, dtype=node.op.dtype)(node.inputs[0])]
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T.Sum]) @gof.local_optimizer([T.Sum, T.elemwise.Prod])
def local_sum_sum(node): def local_op_of_op(node):
""" """
Sum(Sum()) -> Sum Prod(Prod()) -> single Prod()
or
Sum(Sum()) -> single Sum()
""" """
if isinstance(node.op, T.Sum): if isinstance(node.op, T.elemwise.Prod) or isinstance(node.op, T.Sum):
summed, = node.inputs opt_type = T.Sum if isinstance(node.op, T.Sum) else T.elemwise.Prod
node_inps, = node.inputs
out_dtype = node.op.dtype out_dtype = node.op.dtype
if len(summed.clients) == 1: # We manipulate the graph so this is done to make sure the opt
if (summed.owner and # doesn't affect other computations.
isinstance(summed.owner.op, T.Sum)): if len(node_inps.clients) == 1:
if (node_inps.owner and (isinstance(node_inps.owner.op, T.elemwise.Prod)
if summed.owner.op.axis is None: or isinstance(node_inps.owner.op, T.elemwise.Sum))):
# special case of local_cut_useless_reduce
return [T.Sum(None, dtype=out_dtype)( # check to see either the inner or outer prod is doing a
summed.owner.inputs[0])] # product over all axis, in which case we can remove it
if node.op.axis is None: if node_inps.owner.op.axis is None or node.op.axis is None:
# we're summing up everything anyway so lets return [opt_type(None, dtype=out_dtype)(
# do it all at once node_inps.owner.inputs[0])]
return [T.Sum(None, dtype=out_dtype)(
summed.owner.inputs[0])] # figure out which axes were in the original sum
newaxis = list(tuple(node_inps.owner.op.axis))
newaxis = list(tuple(summed.owner.op.axis))
# figure out which dimensions of the original input
# are preserved
for i in node.op.axis: for i in node.op.axis:
new_i = i new_i = i
for ii in summed.owner.op.axis: for ii in node_inps.owner.op.axis:
if new_i >= ii: if new_i >= ii:
new_i += 1 new_i += 1
assert new_i not in newaxis assert new_i not in newaxis
newaxis.append(new_i) newaxis.append(new_i)
assert len(newaxis) == len(list(summed.owner.op.axis) + assert len(newaxis) == len(list(node_inps.owner.op.axis) +
list(node.op.axis)) list(node.op.axis))
# The old bugged logic. We keep it there to generate a warning # The old bugged logic. We keep it there to generate a warning
# when we generated bad code. # when we generated bad code.
alldims = range(summed.owner.inputs[0].type.ndim) alldims = range(node_inps.owner.inputs[0].type.ndim)
alldims = [d for i, d in enumerate(alldims) if i alldims = [d for i, d in enumerate(alldims) if i
in summed.owner.op.axis] in node_inps.owner.op.axis]
alldims = [d for i, d in enumerate(alldims) alldims = [d for i, d in enumerate(alldims)
if i in node.op.axis] if i in node.op.axis]
newaxis_old = [i for i in newaxis_old = [i for i in
xrange(summed.owner.inputs[0].type.ndim) xrange(node_inps.owner.inputs[0].type.ndim)
if i not in alldims] if i not in alldims]
if (theano.config.warn.sum_sum_bug and if (theano.config.warn.sum_sum_bug and
...@@ -4065,8 +4074,9 @@ def local_sum_sum(node): ...@@ -4065,8 +4074,9 @@ def local_sum_sum(node):
"been fixed) set the theano flag " "been fixed) set the theano flag "
"`warn.sum_sum_bug` to False.") "`warn.sum_sum_bug` to False.")
combined_sum = T.Sum(newaxis, dtype=out_dtype) combined = opt_type(newaxis, dtype=out_dtype)
return [combined_sum(summed.owner.inputs[0])] return [combined(node_inps.owner.inputs[0])]
ALL_REDUCE = [T.elemwise.CAReduce, T.elemwise.All, T.elemwise.Any, ALL_REDUCE = [T.elemwise.CAReduce, T.elemwise.All, T.elemwise.Any,
T.elemwise.Sum, T.elemwise.Prod, T.elemwise.Sum, T.elemwise.Prod,
...@@ -4208,21 +4218,29 @@ def local_reduce_broadcastable(node): ...@@ -4208,21 +4218,29 @@ def local_reduce_broadcastable(node):
@register_specialize @register_specialize
@gof.local_optimizer([T.Sum]) @gof.local_optimizer([T.Sum, T.elemwise.Prod])
def local_sum_alloc(node): def local_opt_alloc(node):
""" sum(alloc(constant,shapes...)) => constant*prod(shapes)""" """ sum(alloc(constant,shapes...)) => constant*prod(shapes)
if isinstance(node.op, T.Sum): or
summed, = node.inputs prod(alloc(constant,shapes...)) => constant**prod(shapes)
if summed.owner and isinstance(summed.owner.op, T.Alloc): """
input = summed.owner.inputs[0] if isinstance(node.op, T.Sum) or isinstance(node.op, T.elemwise.Prod):
shapes = summed.owner.inputs[1:] node_inps, = node.inputs
if node_inps.owner and isinstance(node_inps.owner.op, T.Alloc):
input = node_inps.owner.inputs[0]
shapes = node_inps.owner.inputs[1:]
if (node.op.axis is None or if (node.op.axis is None or
node.op.axis == tuple(range(input.ndim))): node.op.axis == tuple(range(input.ndim))):
try: try:
val = get_scalar_constant_value(input) val = get_scalar_constant_value(input)
assert val.size == 1 assert val.size == 1
val = val.reshape(1)[0] * T.mul(*shapes) # check which type of op
if isinstance(node.op, T.Sum):
val = val.reshape(1)[0] * T.mul(*shapes)
else:
val = val.reshape(1)[0] ** T.mul(*shapes)
return [T.cast(val, dtype=node.outputs[0].dtype)] return [T.cast(val, dtype=node.outputs[0].dtype)]
except NotScalarConstantError: except NotScalarConstantError:
pass pass
else: else:
...@@ -4233,7 +4251,10 @@ def local_sum_alloc(node): ...@@ -4233,7 +4251,10 @@ def local_sum_alloc(node):
to_prod = [shapes[i] for i in xrange(len(shapes)) to_prod = [shapes[i] for i in xrange(len(shapes))
if i in node.op.axis] if i in node.op.axis]
if to_prod: if to_prod:
val *= T.mul(*to_prod) if isinstance(node.op, T.Sum):
val *= T.mul(*to_prod)
else:
val = val ** T.mul(*to_prod)
return [T.alloc(T.cast(val, dtype=node.outputs[0].dtype), return [T.alloc(T.cast(val, dtype=node.outputs[0].dtype),
*[shapes[i] for i in xrange(len(shapes)) *[shapes[i] for i in xrange(len(shapes))
if i not in node.op.axis])] if i not in node.op.axis])]
......
...@@ -4459,21 +4459,33 @@ class test_local_remove_switch_const_cond(unittest.TestCase): ...@@ -4459,21 +4459,33 @@ class test_local_remove_switch_const_cond(unittest.TestCase):
assert numpy.all(f(vx, vy) == vy) assert numpy.all(f(vx, vy) == vy)
class T_local_sum(unittest.TestCase): class T_local_sum_prod(unittest.TestCase):
"""
Test sum/prod opts in opt.py
"""
def setUp(self): def setUp(self):
self.mode = theano.compile.get_default_mode().including('canonicalize', self.mode = theano.compile.get_default_mode().including('canonicalize',
'specialize') 'specialize')
def test_local_sum_all_to_none(self): def test_local_sum_prod_all_to_none(self):
a = T.tensor3() a = T.tensor3()
input = numpy.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5) input = numpy.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
# test sum
f = theano.function([a], a.sum(), mode=self.mode) f = theano.function([a], a.sum(), mode=self.mode)
assert len(f.maker.fgraph.apply_nodes) == 1 assert len(f.maker.fgraph.apply_nodes) == 1
assert numpy.allclose(f(input), input.sum()) assert numpy.allclose(f(input), input.sum())
# test prod
f = theano.function([a], a.prod(), mode=self.mode)
assert len(f.maker.fgraph.apply_nodes) == 1
assert numpy.allclose(f(input), input.prod())
# test sum
f = theano.function([a], a.sum([0, 1, 2]), mode=self.mode) f = theano.function([a], a.sum([0, 1, 2]), mode=self.mode)
assert len(f.maker.fgraph.apply_nodes) == 1 assert len(f.maker.fgraph.apply_nodes) == 1
assert numpy.allclose(f(input), input.sum()) assert numpy.allclose(f(input), input.sum())
# test prod
f = theano.function([a], a.prod([0, 1, 2]), mode=self.mode)
assert len(f.maker.fgraph.apply_nodes) == 1
assert numpy.allclose(f(input), input.prod())
backup = config.warn.sum_sum_bug backup = config.warn.sum_sum_bug
config.warn.sum_sum_bug = False config.warn.sum_sum_bug = False
...@@ -4484,7 +4496,7 @@ class T_local_sum(unittest.TestCase): ...@@ -4484,7 +4496,7 @@ class T_local_sum(unittest.TestCase):
finally: finally:
config.warn.sum_sum_bug = backup config.warn.sum_sum_bug = backup
def test_local_sum_sum(self): def test_local_sum_sum_prod_prod(self):
a = T.tensor3() a = T.tensor3()
input = numpy.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5) input = numpy.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
dims = [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1), dims = [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1),
...@@ -4494,6 +4506,17 @@ class T_local_sum(unittest.TestCase): ...@@ -4494,6 +4506,17 @@ class T_local_sum(unittest.TestCase):
backup = config.warn.sum_sum_bug backup = config.warn.sum_sum_bug
config.warn.sum_sum_bug = False config.warn.sum_sum_bug = False
def my_prod(data, d, dd):
# This prod when d or dd is a tuple of 2 dimensions.
if not isinstance(d, tuple) and not isinstance(dd, tuple):
return data.prod(d).prod(dd)
if isinstance(d, tuple):
d = sorted(d)
return data.prod(d[1]).prod(d[0]).prod(dd)
else:
dd = sorted(dd)
return data.prod(d).prod(dd[1]).prod(dd[0])
def my_sum(data, d, dd): def my_sum(data, d, dd):
# This sum when d or dd is a tuple of 2 dimensions. # This sum when d or dd is a tuple of 2 dimensions.
if not isinstance(d, tuple) and not isinstance(dd, tuple): if not isinstance(d, tuple) and not isinstance(dd, tuple):
...@@ -4526,7 +4549,27 @@ class T_local_sum(unittest.TestCase): ...@@ -4526,7 +4549,27 @@ class T_local_sum(unittest.TestCase):
finally: finally:
config.warn.sum_sum_bug = backup config.warn.sum_sum_bug = backup
def test_local_sum_alloc(self): # test prod
for d, dd in dims:
expected = my_prod(input, d, dd)
f = theano.function([a], a.prod(d).prod(dd), mode=self.mode)
assert numpy.allclose(f(input), expected)
assert len(f.maker.fgraph.apply_nodes) == 1
for d, dd in dims[:6]:
f = theano.function([a], a.prod(d).prod(dd).
prod(0), mode=self.mode)
assert numpy.allclose(f(input), input.prod(d).prod(dd).prod(0))
assert len(f.maker.fgraph.apply_nodes) == 1
for d in [0, 1, 2]:
f = theano.function([a], a.prod(d).prod(None), mode=self.mode)
assert numpy.allclose(f(input), input.prod(d).prod())
assert len(f.maker.fgraph.apply_nodes) == 1
f = theano.function([a], a.prod(None).prod(), mode=self.mode)
assert numpy.allclose(f(input), input.prod())
assert len(f.maker.fgraph.apply_nodes) == 1
def test_local_sum_prod_alloc(self):
a = T.dtensor3() a = T.dtensor3()
input = numpy.asarray(numpy.arange(2 * 3 * 4).reshape(2, 3, 4), input = numpy.asarray(numpy.arange(2 * 3 * 4).reshape(2, 3, 4),
dtype='float64') dtype='float64')
...@@ -4535,6 +4578,7 @@ class T_local_sum(unittest.TestCase): ...@@ -4535,6 +4578,7 @@ class T_local_sum(unittest.TestCase):
for t_like, n_like, nb_nodes in [(tensor.zeros_like, numpy.zeros_like, (1, 3, 3, 2)), for t_like, n_like, nb_nodes in [(tensor.zeros_like, numpy.zeros_like, (1, 3, 3, 2)),
(tensor.ones_like, numpy.ones_like, (5, 5, 5, 6))]: (tensor.ones_like, numpy.ones_like, (5, 5, 5, 6))]:
# test sum
f = theano.function([a], t_like(a).sum(None), mode=mode) f = theano.function([a], t_like(a).sum(None), mode=mode)
assert numpy.allclose(f(input), n_like(input).sum()) assert numpy.allclose(f(input), n_like(input).sum())
assert len(f.maker.fgraph.apply_nodes) == nb_nodes[0] assert len(f.maker.fgraph.apply_nodes) == nb_nodes[0]
...@@ -4558,6 +4602,30 @@ class T_local_sum(unittest.TestCase): ...@@ -4558,6 +4602,30 @@ class T_local_sum(unittest.TestCase):
assert topo[-1].op == T.alloc assert topo[-1].op == T.alloc
assert not any([isinstance(node.op, T.Sum) for node in topo]) assert not any([isinstance(node.op, T.Sum) for node in topo])
# test prod
f = theano.function([a], t_like(a).prod(None), mode=mode)
assert numpy.allclose(f(input), n_like(input).prod())
#assert len(f.maker.fgraph.apply_nodes) == nb_nodes[0]
f = theano.function([a], t_like(a).prod([0, 1, 2]), mode=mode)
assert numpy.allclose(f(input), n_like(input).prod())
#assert len(f.maker.fgraph.apply_nodes) == nb_nodes[0]
for d in range(3):
f = theano.function([a], t_like(a).prod(d), mode=mode)
assert numpy.allclose(f(input), n_like(input).prod(d))
#assert len(f.maker.fgraph.apply_nodes) == nb_nodes[1]
topo = f.maker.fgraph.toposort()
assert topo[-1].op == T.alloc
assert not any([isinstance(node.op, T.elemwise.Prod) for node in topo])
for i in range(3):
f = theano.function([a], t_like(a).prod(i), mode=mode)
assert numpy.allclose(f(input), n_like(input).prod(i))
#assert len(f.maker.fgraph.apply_nodes) == nb_nodes[2]
topo = f.maker.fgraph.toposort()
assert topo[-1].op == T.alloc
assert not any([isinstance(node.op, T.elemwise.Prod) for node in topo])
backup = config.warn.sum_sum_bug backup = config.warn.sum_sum_bug
config.warn.sum_sum_bug = False config.warn.sum_sum_bug = False
try: try:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论