提交 6c06233a authored 作者: Frederic's avatar Frederic

Fix test by not introducing the alloc when not needed

上级 3603fce6
...@@ -85,7 +85,6 @@ def in2out(*local_opts, **kwargs): ...@@ -85,7 +85,6 @@ def in2out(*local_opts, **kwargs):
else: else:
local_opts, = local_opts local_opts, = local_opts
if not name: if not name:
#import pdb;pdb.set_trace()
name = local_opts.__name__ name = local_opts.__name__
ret = opt.TopoOptimizer(local_opts, ret = opt.TopoOptimizer(local_opts,
order='in_to_out', order='in_to_out',
...@@ -1211,7 +1210,7 @@ class ShapeFeature(object): ...@@ -1211,7 +1210,7 @@ class ShapeFeature(object):
opy = dy.owner.op opy = dy.owner.op
if not (opx.i == opy.i): if not (opx.i == opy.i):
return False return False
# FB I'm not sure is this handle correctly constants. # FB I'm not sure if this handle correctly constants.
if dx.owner.inputs[0] == dy.owner.inputs[0]: if dx.owner.inputs[0] == dy.owner.inputs[0]:
continue continue
# To be sure to cover all case, call equal_computation. # To be sure to cover all case, call equal_computation.
...@@ -1698,7 +1697,7 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1698,7 +1697,7 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
assert_op = node.inputs[assert_op_idx] assert_op = node.inputs[assert_op_idx]
cmp_op = assert_op cmp_op = assert_op
new_i = [] new_i = []
same_shape = node.fgraph.shape_feature.same_shape
for i in node.inputs: for i in node.inputs:
# Remove alloc # Remove alloc
if (i.owner and isinstance(i.owner.op, AllocOP) if (i.owner and isinstance(i.owner.op, AllocOP)
...@@ -1708,7 +1707,7 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1708,7 +1707,7 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
assert i.type.ndim == cmp_op.ndim assert i.type.ndim == cmp_op.ndim
if (theano.config.experimental.local_alloc_elemwise_assert if (theano.config.experimental.local_alloc_elemwise_assert
and not node.fgraph.shape_feature.same_shape(i, cmp_op)): and not same_shape(i, cmp_op)):
assert_op = assert_(assert_op, assert_op = assert_(assert_op,
*[T.eq(i.shape[idx], cmp_op.shape[idx]) *[T.eq(i.shape[idx], cmp_op.shape[idx])
for idx in xrange(i.type.ndim) for idx in xrange(i.type.ndim)
...@@ -1719,7 +1718,9 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1719,7 +1718,9 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
elif i.owner and dimshuffled_alloc(i): elif i.owner and dimshuffled_alloc(i):
assert i.type.ndim == cmp_op.type.ndim assert i.type.ndim == cmp_op.type.ndim
if (theano.config.experimental.local_alloc_elemwise_assert if (theano.config.experimental.local_alloc_elemwise_assert
and not node.fgraph.shape_feature.same_shape(i, cmp_op)): and not all([same_shape(i, cmp_op, idx, idx)
for idx in xrange(i.type.ndim)
if not i.type.broadcastable[idx]])):
assert_op = assert_(assert_op, assert_op = assert_(assert_op,
*[T.eq(i.shape[idx], cmp_op.shape[idx]) *[T.eq(i.shape[idx], cmp_op.shape[idx])
for idx in xrange(i.type.ndim) for idx in xrange(i.type.ndim)
......
...@@ -2832,7 +2832,11 @@ class Test_local_elemwise_alloc(unittest.TestCase): ...@@ -2832,7 +2832,11 @@ class Test_local_elemwise_alloc(unittest.TestCase):
self.tens = T.tensor3('tens', dtype=self.dtype) self.tens = T.tensor3('tens', dtype=self.dtype)
self.alloc_wo_dep = T.alloc(self.vec, 2, 2) self.alloc_wo_dep = T.alloc(self.vec, 2, 2)
self.alloc_wo_dep_broad = T.alloc(self.vec, 1, 2)
self.alloc_w_dep = T.alloc(self.vec, *self.mat.shape) self.alloc_w_dep = T.alloc(self.vec, *self.mat.shape)
self.alloc_w_dep_broad = T.alloc(self.vec, 1, *self.mat.shape)
self.alloc_w_dep_broad2 = T.alloc(self.vec, self.mat.shape[0],
self.mat.shape[1], 1)
self.alloc_w_dep_tens = T.alloc( self.alloc_w_dep_tens = T.alloc(
self.vec, self.vec,
self.tens.shape[0], self.tens.shape[0],
...@@ -2879,6 +2883,15 @@ class Test_local_elemwise_alloc(unittest.TestCase): ...@@ -2879,6 +2883,15 @@ class Test_local_elemwise_alloc(unittest.TestCase):
self._verify_alloc_count(func, 0) self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 1) self._verify_assert_count(func, 1)
# Optimization on alloc with assert and broadcast
func = function(
[self.vec, self.mat],
self.alloc_wo_dep_broad + self.mat,
mode=self.fast_run_mode
)
self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 1)
# No optimization on alloc without assert # No optimization on alloc without assert
func = function( func = function(
[self.vec, self.mat], [self.vec, self.mat],
...@@ -2897,6 +2910,24 @@ class Test_local_elemwise_alloc(unittest.TestCase): ...@@ -2897,6 +2910,24 @@ class Test_local_elemwise_alloc(unittest.TestCase):
self._verify_alloc_count(func, 0) self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 0) self._verify_assert_count(func, 0)
# Optimization on alloc without assert and with broadcast
func = function(
[self.vec, self.mat],
self.alloc_w_dep_broad + self. mat,
mode=self.fast_run_mode
)
self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 0)
# Not optimized case on alloc and with broadcast
func = function(
[self.vec, self.mat],
self.alloc_w_dep_broad2 + self. mat,
mode=self.fast_run_mode
)
self._verify_alloc_count(func, 1)
self._verify_assert_count(func, 0)
def test_remove_alloc_w_dimshuffle(self): def test_remove_alloc_w_dimshuffle(self):
# No optimization on dimshuffle with assert # No optimization on dimshuffle with assert
func = function( func = function(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论