提交 faa267f1 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add local_rebroadcast_lift, modelled after local_dimshuffle_lift

上级 09df3d2d
......@@ -711,6 +711,35 @@ def local_useless_rebroadcast(node):
if numpy.all(x.broadcastable == node.outputs[0].broadcastable):
return [x]
@register_canonicalize
@register_specialize
@gof.local_optimizer([T.Rebroadcast])
def local_rebroadcast_lift(node):
"""
"Lifts Rebroadcast through unary Elemwise operations,
and merges consecutive Rebroadcasts.
Rebroadcast(Elemwise(x)) => Elemwise(Rebroadcast(x))
Rebroadcast(Rebroadcast(x)) => Rebroadcast(x)
"""
op = node.op
if not isinstance(op, T.Rebroadcast):
return False
input = node.inputs[0]
inode = input.owner
if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1:
if len(input.clients)==1:
rval = inode.op.make_node(T.Rebroadcast(*op.axis.items())(inode.inputs[0])).outputs
return rval
if inode and isinstance(inode.op, T.Rebroadcast):
# the "axis" specification in the outer Rebroadcast overrides
# the axis of the inner one
axis = inode.op.axis.copy()
axis.update(op.axis)
iinput = inode.inputs[0]
rval = [T.Rebroadcast(*axis.items())(iinput)]
return rval
##################
......
......@@ -1126,14 +1126,25 @@ def test_local_mul_specialize():
assert nodes == [T.mul]
def test_local_useless_rebroadcast():
v1 = T.vector()
v2 = T.vector()
j = T.join(0, v1, v2)
f = theano.function([v1, v2], j)
f([1,2], [3,4,5])
e = f.maker.env.toposort()
assert len([n for n in e if isinstance(n.op, T.Rebroadcast)]) == 0
class T_Rebroadcast(unittest.TestCase):
def test_local_useless_rebroadcast(self):
v1 = T.vector()
v2 = T.vector()
j = T.join(0, v1, v2)
f = theano.function([v1, v2], j)
f([1,2], [3,4,5])
e = f.maker.env.toposort()
assert len([n for n in e if isinstance(n.op, T.Rebroadcast)]) == 0
def test_rebroadcast_rebroadcast(self):
m = T.matrix()
s = T.addbroadcast(m, 0, 1)
v = T.unbroadcast(s, 1)
f = theano.function([m], v)
f([[76]])
e = f.maker.env.toposort()
assert len([n for n in e if isinstance]) == 1
if __name__ == '__main__':
# unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论