提交 3b0f4cdb authored 作者: Frederic Bastien's avatar Frederic Bastien

Make local_merge_alloc apply in more case

上级 36d1eca8
......@@ -5897,6 +5897,7 @@ def local_merge_alloc(node):
# This opt takes care of several cases:
# Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
# Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
# Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) -> Alloc(m, x, assert(y1, y1==y2), z, w)
if not isinstance(node.op, T.Alloc):
return False
if not node.inputs[0].owner or not isinstance(
......@@ -5912,10 +5913,13 @@ def local_merge_alloc(node):
# The reverse ordering is needed when an Alloc add an implicit new
# broadcasted dimensions to its inputs[0]. Eg:
# Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
i = 0
for dim_inner, dim_outer in zip(dims_inner_rev, dims_outer_rev):
if dim_inner != dim_outer:
if isinstance(dim_inner, Constant) and dim_inner.data == 1:
pass
else:
return False
dims_outer[-1 - i] = assert_op(dim_outer,
T.eq(dim_outer, dim_inner))
i += 1
return [T.alloc(inputs_inner[0], *dims_outer)]
......@@ -13,6 +13,7 @@ import numpy
from six.moves import xrange
from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr
from nose.tools import assert_raises
from numpy.testing import dec
from numpy.testing.noseclasses import KnownFailureTest
......@@ -5448,9 +5449,10 @@ def test_local_merge_alloc():
# otherwise, FAST_COMPILE fails.
default_mode = theano.compile.mode.get_default_mode()
opt_mode = default_mode.including("local_merge_alloc")
x = T.iscalar('x')
y = T.iscalar('y')
y2 = T.iscalar('y2')
z = T.iscalar('z')
w = T.iscalar('w')
m = T.fscalar('m')
......@@ -5474,6 +5476,20 @@ def test_local_merge_alloc():
o = f(0., 1, 2, 3, 4)
assert o.shape == (1, 2, 3, 4)
# case 3
# Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) ->
# Alloc(m, x, assert(y1, y1==y2), z, w)
output = T.alloc(T.alloc(m, y, 1, 1), x, y2, z, w)
f = theano.function([m, x, y, y2, z, w], output, mode=opt_mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 3
assert isinstance(topo[-2].op, T.opt.Assert)
assert isinstance(topo[-1].op, T.Alloc)
o = f(0., 1, 2, 2, 3, 4)
assert o.shape == (1, 2, 3, 4)
assert_raises((AssertionError, ValueError), f, 0., 1, 2, 5, 3, 4)
if __name__ == '__main__':
t = TestMakeVector('setUp')
t.setUp()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论