提交 6c06233a authored 作者: Frederic's avatar Frederic

Fix test by not introducing the alloc when not needed

上级 3603fce6
......@@ -85,7 +85,6 @@ def in2out(*local_opts, **kwargs):
else:
local_opts, = local_opts
if not name:
#import pdb;pdb.set_trace()
name = local_opts.__name__
ret = opt.TopoOptimizer(local_opts,
order='in_to_out',
......@@ -1211,7 +1210,7 @@ class ShapeFeature(object):
opy = dy.owner.op
if not (opx.i == opy.i):
return False
# FB I'm not sure is this handle correctly constants.
# FB I'm not sure if this handle correctly constants.
if dx.owner.inputs[0] == dy.owner.inputs[0]:
continue
# To be sure to cover all case, call equal_computation.
......@@ -1698,7 +1697,7 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
assert_op = node.inputs[assert_op_idx]
cmp_op = assert_op
new_i = []
same_shape = node.fgraph.shape_feature.same_shape
for i in node.inputs:
# Remove alloc
if (i.owner and isinstance(i.owner.op, AllocOP)
......@@ -1708,7 +1707,7 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
assert i.type.ndim == cmp_op.ndim
if (theano.config.experimental.local_alloc_elemwise_assert
and not node.fgraph.shape_feature.same_shape(i, cmp_op)):
and not same_shape(i, cmp_op)):
assert_op = assert_(assert_op,
*[T.eq(i.shape[idx], cmp_op.shape[idx])
for idx in xrange(i.type.ndim)
......@@ -1719,7 +1718,9 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
elif i.owner and dimshuffled_alloc(i):
assert i.type.ndim == cmp_op.type.ndim
if (theano.config.experimental.local_alloc_elemwise_assert
and not node.fgraph.shape_feature.same_shape(i, cmp_op)):
and not all([same_shape(i, cmp_op, idx, idx)
for idx in xrange(i.type.ndim)
if not i.type.broadcastable[idx]])):
assert_op = assert_(assert_op,
*[T.eq(i.shape[idx], cmp_op.shape[idx])
for idx in xrange(i.type.ndim)
......
......@@ -2832,7 +2832,11 @@ class Test_local_elemwise_alloc(unittest.TestCase):
self.tens = T.tensor3('tens', dtype=self.dtype)
self.alloc_wo_dep = T.alloc(self.vec, 2, 2)
self.alloc_wo_dep_broad = T.alloc(self.vec, 1, 2)
self.alloc_w_dep = T.alloc(self.vec, *self.mat.shape)
self.alloc_w_dep_broad = T.alloc(self.vec, 1, *self.mat.shape)
self.alloc_w_dep_broad2 = T.alloc(self.vec, self.mat.shape[0],
self.mat.shape[1], 1)
self.alloc_w_dep_tens = T.alloc(
self.vec,
self.tens.shape[0],
......@@ -2879,6 +2883,15 @@ class Test_local_elemwise_alloc(unittest.TestCase):
self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 1)
# Optimization on alloc with assert and broadcast
func = function(
[self.vec, self.mat],
self.alloc_wo_dep_broad + self.mat,
mode=self.fast_run_mode
)
self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 1)
# No optimization on alloc without assert
func = function(
[self.vec, self.mat],
......@@ -2897,6 +2910,24 @@ class Test_local_elemwise_alloc(unittest.TestCase):
self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 0)
# Optimization on alloc without assert and with broadcast
func = function(
[self.vec, self.mat],
self.alloc_w_dep_broad + self. mat,
mode=self.fast_run_mode
)
self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 0)
# Not optimized case on alloc and with broadcast
func = function(
[self.vec, self.mat],
self.alloc_w_dep_broad2 + self. mat,
mode=self.fast_run_mode
)
self._verify_alloc_count(func, 1)
self._verify_assert_count(func, 0)
def test_remove_alloc_w_dimshuffle(self):
# No optimization on dimshuffle with assert
func = function(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论