提交 3b0f4cdb authored 作者: Frederic Bastien's avatar Frederic Bastien

Make local_merge_alloc apply in more case

上级 36d1eca8
...@@ -5897,6 +5897,7 @@ def local_merge_alloc(node): ...@@ -5897,6 +5897,7 @@ def local_merge_alloc(node):
# This opt takes care of several cases: # This opt takes care of several cases:
# Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) # Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
# Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
# Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) -> Alloc(m, x, assert(y1, y1==y2), z, w)
if not isinstance(node.op, T.Alloc): if not isinstance(node.op, T.Alloc):
return False return False
if not node.inputs[0].owner or not isinstance( if not node.inputs[0].owner or not isinstance(
...@@ -5912,10 +5913,13 @@ def local_merge_alloc(node): ...@@ -5912,10 +5913,13 @@ def local_merge_alloc(node):
# The reverse ordering is needed when an Alloc add an implicit new # The reverse ordering is needed when an Alloc add an implicit new
# broadcasted dimensions to its inputs[0]. Eg: # broadcasted dimensions to its inputs[0]. Eg:
# Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
i = 0
for dim_inner, dim_outer in zip(dims_inner_rev, dims_outer_rev): for dim_inner, dim_outer in zip(dims_inner_rev, dims_outer_rev):
if dim_inner != dim_outer: if dim_inner != dim_outer:
if isinstance(dim_inner, Constant) and dim_inner.data == 1: if isinstance(dim_inner, Constant) and dim_inner.data == 1:
pass pass
else: else:
return False dims_outer[-1 - i] = assert_op(dim_outer,
T.eq(dim_outer, dim_inner))
i += 1
return [T.alloc(inputs_inner[0], *dims_outer)] return [T.alloc(inputs_inner[0], *dims_outer)]
...@@ -13,6 +13,7 @@ import numpy ...@@ -13,6 +13,7 @@ import numpy
from six.moves import xrange from six.moves import xrange
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr from nose.plugins.attrib import attr
from nose.tools import assert_raises
from numpy.testing import dec from numpy.testing import dec
from numpy.testing.noseclasses import KnownFailureTest from numpy.testing.noseclasses import KnownFailureTest
...@@ -5451,6 +5452,7 @@ def test_local_merge_alloc(): ...@@ -5451,6 +5452,7 @@ def test_local_merge_alloc():
x = T.iscalar('x') x = T.iscalar('x')
y = T.iscalar('y') y = T.iscalar('y')
y2 = T.iscalar('y2')
z = T.iscalar('z') z = T.iscalar('z')
w = T.iscalar('w') w = T.iscalar('w')
m = T.fscalar('m') m = T.fscalar('m')
...@@ -5474,6 +5476,20 @@ def test_local_merge_alloc(): ...@@ -5474,6 +5476,20 @@ def test_local_merge_alloc():
o = f(0., 1, 2, 3, 4) o = f(0., 1, 2, 3, 4)
assert o.shape == (1, 2, 3, 4) assert o.shape == (1, 2, 3, 4)
# case 3
# Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) ->
# Alloc(m, x, assert(y1, y1==y2), z, w)
output = T.alloc(T.alloc(m, y, 1, 1), x, y2, z, w)
f = theano.function([m, x, y, y2, z, w], output, mode=opt_mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 3
assert isinstance(topo[-2].op, T.opt.Assert)
assert isinstance(topo[-1].op, T.Alloc)
o = f(0., 1, 2, 2, 3, 4)
assert o.shape == (1, 2, 3, 4)
assert_raises((AssertionError, ValueError), f, 0., 1, 2, 5, 3, 4)
if __name__ == '__main__': if __name__ == '__main__':
t = TestMakeVector('setUp') t = TestMakeVector('setUp')
t.setUp() t.setUp()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论