提交 6477a500 authored 作者: Frederic Bastien's avatar Frederic Bastien

added optimization: now optimize sum(alloc(x,shps)

上级 9a694d42
......@@ -1789,6 +1789,37 @@ def local_cut_useless_reduce(node):
if summed.type == node.outputs[0].type:
return [summed]
@register_specialize
@gof.local_optimizer([])
def local_sum_alloc(node):
""" sum(alloc(constant,shapes...)) => constant*prod(shapes)"""
if isinstance(node.op, T.Sum):
summed, = node.inputs
if summed.owner and isinstance(summed.owner.op, T.Alloc):
input = summed.owner.inputs[0]
shapes = summed.owner.inputs[1:]
#import pdb;pdb.set_trace()
if node.op.axis is None or node.op.axis == tuple(range(input.ndim)):
try:
val = get_constant_value(input)
assert val.size == 1
val = val.reshape(1)[0]*T.mul(*shapes)
return [T.cast(val, dtype=node.outputs[0].dtype)]
except TypeError, e:
pass
else:
try:
val = get_constant_value(input)
assert val.size == 1
val = val.reshape(1)[0]
to_prod = [shapes[i] for i in range(len(shapes)) if i in node.op.axis]
if to_prod:
val *= T.mul(*to_prod)
return [T.alloc(T.cast(val, dtype=node.outputs[0].dtype),
*[shapes[i] for i in range(len(shapes)) if i not in node.op.axis])]
except TypeError, e:
pass
@gof.local_optimizer([T.mul])
def local_mul_to_neg(node):
if node.op == T.mul and N.all(local_mul_canonizer.get_constant(node.inputs[0]) == -1.0):
......
......@@ -1622,7 +1622,42 @@ class T_local_sum(unittest.TestCase):
f = theano.function([a],a.sum(None).sum(),mode=self.mode)
assert numpy.allclose(f(input),input.sum())
assert len(f.maker.env.nodes)==1
def test_local_sum_alloc(self):
a=T.dtensor3()
input=numpy.asarray(numpy.arange(2*3*4).reshape(2,3,4),dtype='float64')
mode = self.mode.including('specialize').excluding('fusion')
for t_like,n_like,nb_nodes in [(zeros_like,numpy.zeros_like,(0,3,3,2)),
(ones_like,numpy.ones_like,(5,5,5,6))]:
f = theano.function([a],t_like(a).sum(None),mode=mode)
assert numpy.allclose(f(input),n_like(input).sum())
assert len(f.maker.env.nodes)==nb_nodes[0]
f = theano.function([a],t_like(a).sum([0,1,2]),mode=mode)
assert numpy.allclose(f(input),n_like(input).sum())
assert len(f.maker.env.nodes)==nb_nodes[0]
for d in range(3):
f = theano.function([a],t_like(a).sum(d),mode=mode)
assert numpy.allclose(f(input),n_like(input).sum(d))
assert len(f.maker.env.nodes)==nb_nodes[1]
assert f.maker.env.toposort()[-1].op==T.alloc
for i in range(3):
f = theano.function([a],t_like(a).sum(i),mode=mode)
assert numpy.allclose(f(input),n_like(input).sum(i))
assert len(f.maker.env.nodes)==nb_nodes[2]
assert f.maker.env.toposort()[-1].op==T.alloc
for d, dd in [(0,0),(1,0),(2,0),(0,1),(1,1),(2,1)]:
f = theano.function([a],t_like(a).sum(d).sum(dd),mode=mode)
print f.maker.env.toposort()
assert numpy.allclose(f(input),n_like(input).sum(d).sum(dd))
assert len(f.maker.env.nodes)==nb_nodes[3]
assert f.maker.env.toposort()[-1].op==T.alloc
class T_local_sum_dimshuffle(unittest.TestCase):
def setUp(self):
self.mode = theano.compile.get_default_mode().including('canonicalize')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论