提交 35d1e5cf authored 作者: Li Yao's avatar Li Yao

local opt: Alloc(Alloc(m, x, 1), x, y) -> Alloc(m, x, y)

Conflicts: theano/tensor/tests/test_opt.py opt support arbitary dim broadcasting proper tests better alloc remove conficts pep8 fix
上级 5653689f
......@@ -5887,3 +5887,32 @@ register_canonicalize(gof.OpRemove(theano.gradient.disconnected_grad_),
def local_grad_clip(node):
if isinstance(node.op, theano.gradient.GradClip):
return node.inputs
@register_canonicalize
@register_stabilize
@register_specialize
@gof.local_optimizer([T.Alloc])
def local_merge_alloc(node):
# This opt takes care of several cases:
# Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
# Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
if not isinstance(node.op, T.Alloc):
return False
if not node.inputs[0].owner or not isinstance(
node.inputs[0].owner.op, T.Alloc):
return False
inputs_outer = node.inputs
inputs_inner = node.inputs[0].owner.inputs
dims_outer = inputs_outer[1:]
dims_inner = inputs_inner[1:]
dims_outer_rev = dims_outer[::-1]
dims_inner_rev = dims_inner[::-1]
# check if the pattern of broadcasting is matched, in the reversed ordering
for dim_inner, dim_outer in zip(dims_inner_rev, dims_outer_rev):
if dim_inner != dim_outer:
if isinstance(dim_inner, Constant) and dim_inner.data == 1:
pass
else:
return False
return [T.alloc(inputs_inner[0], *dims_outer)]
......@@ -5443,6 +5443,32 @@ class TestIntDivByOne(unittest.TestCase):
assert len(divs) == 0
def test_local_merge_alloc():
x = T.iscalar('x')
y = T.iscalar('y')
z = T.iscalar('z')
w = T.iscalar('w')
m = T.fscalar('m')
# case 1
# Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
output = T.alloc(T.alloc(m, 1, y, 1, 1), x, y, z, w)
f = theano.function([m, x, y, z, w], output)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, T.Alloc)
o = f(0., 1, 2, 3, 4)
assert o.shape == (1, 2, 3, 4)
# case 2
# Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
output = T.alloc(T.alloc(m, y, 1, 1), x, y, z, w)
f = theano.function([m, x, y, z, w], output)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, T.Alloc)
o = f(0., 1, 2, 3, 4)
assert o.shape == (1, 2, 3, 4)
if __name__ == '__main__':
t = TestMakeVector('setUp')
t.setUp()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论