提交 cd21e2d0 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4906 from olimastro/master

Issue 4825
...@@ -1796,6 +1796,12 @@ def local_useless_alloc(node): ...@@ -1796,6 +1796,12 @@ def local_useless_alloc(node):
# We don't need to copy over any stack traces here # We don't need to copy over any stack traces here
return [input] return [input]
# Allow local_merge_alloc to do its work first
clients = getattr(output, 'clients', [])
for client, i in clients:
if client != "output" and isinstance(client.op, Alloc):
return
# Check if alloc adds a broadcastable dimension with shape 1. # Check if alloc adds a broadcastable dimension with shape 1.
output_shape = node.inputs[1:] output_shape = node.inputs[1:]
num_dims_with_size_1_added_to_left = 0 num_dims_with_size_1_added_to_left = 0
......
...@@ -34,6 +34,7 @@ from theano.tensor.opt import ( ...@@ -34,6 +34,7 @@ from theano.tensor.opt import (
local_dimshuffle_lift, local_dimshuffle_lift,
local_useless_dimshuffle_in_reshape, local_useless_dimshuffle_in_reshape,
local_useless_alloc, local_useless_alloc,
local_merge_alloc,
local_greedy_distributor, local_greedy_distributor,
local_useless_reshape, local_useless_reshape,
local_reshape_to_dimshuffle, local_reshape_to_dimshuffle,
...@@ -6595,6 +6596,61 @@ def test_local_merge_alloc(): ...@@ -6595,6 +6596,61 @@ def test_local_merge_alloc():
assert_raises((AssertionError, ValueError), f, 0., 1, 2, 5, 3, 4) assert_raises((AssertionError, ValueError), f, 0., 1, 2, 5, 3, 4)
def test_local_useless_alloc():
useless_alloc = out2in(local_useless_alloc)
merge_alloc = out2in(local_merge_alloc)
x = T.iscalar('x')
y = T.iscalar('y')
y2 = T.iscalar('y2')
z = T.iscalar('z')
w = T.iscalar('w')
m = T.fscalar('m')
# case 1
# Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
output = T.alloc(T.alloc(m, 1, y, 1, 1), x, y, z, w)
g = FunctionGraph([m, x, y, z, w], [output])
useless_alloc.optimize(g)
merge_alloc.optimize(g)
useless_alloc.optimize(g)
topo = g.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, T.Alloc)
# case 2
# Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
output = T.alloc(T.alloc(m, y, 1, 1), x, y, z, w)
g = FunctionGraph([m, x, y, z, w], [output])
useless_alloc.optimize(g)
merge_alloc.optimize(g)
useless_alloc.optimize(g)
topo = g.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, T.Alloc)
# case 3
# Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) ->
# Alloc(m, x, assert(y1, y1==y2), z, w)
output = T.alloc(T.alloc(m, y, 1, 1), x, y2, z, w)
g = FunctionGraph([m, x, y, y2, z, w], [output])
useless_alloc.optimize(g)
merge_alloc.optimize(g)
useless_alloc.optimize(g)
topo = g.toposort()
assert len(topo) == 3
assert isinstance(topo[-2].op, T.opt.Assert)
assert isinstance(topo[-1].op, T.Alloc)
if __name__ == '__main__': if __name__ == '__main__':
t = TestMakeVector('setUp') t = TestMakeVector('setUp')
t.setUp() t.setUp()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论