提交 c3f09c73 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2586 from nouiz/same_shape

Better same_shape implementation
...@@ -85,7 +85,6 @@ def in2out(*local_opts, **kwargs): ...@@ -85,7 +85,6 @@ def in2out(*local_opts, **kwargs):
else: else:
local_opts, = local_opts local_opts, = local_opts
if not name: if not name:
#import pdb;pdb.set_trace()
name = local_opts.__name__ name = local_opts.__name__
ret = opt.TopoOptimizer(local_opts, ret = opt.TopoOptimizer(local_opts,
order='in_to_out', order='in_to_out',
...@@ -1180,7 +1179,7 @@ class ShapeFeature(object): ...@@ -1180,7 +1179,7 @@ class ShapeFeature(object):
same shape. same shape.
dim_x and dim_y are optional. If used, they should be an index dim_x and dim_y are optional. If used, they should be an index
to compare only 1 shape of x or y. to compare only 1 dimension of x and y.
""" """
sx = self.shape_of[x] sx = self.shape_of[x]
...@@ -1193,6 +1192,9 @@ class ShapeFeature(object): ...@@ -1193,6 +1192,9 @@ class ShapeFeature(object):
sy = [sy[dim_y]] sy = [sy[dim_y]]
assert len(sx) == len(sy) assert len(sx) == len(sy)
# We look on each dimensions we want to compare.
# If any of them can't be asserted to be equal, return False.
# Otherwise, we return True at the end.
for dx, dy in zip(sx, sy): for dx, dy in zip(sx, sy):
if dx is dy: if dx is dy:
continue continue
...@@ -1208,14 +1210,16 @@ class ShapeFeature(object): ...@@ -1208,14 +1210,16 @@ class ShapeFeature(object):
opy = dy.owner.op opy = dy.owner.op
if not (opx.i == opy.i): if not (opx.i == opy.i):
return False return False
# FB I'm not sure is this handle correctly constants. # FB I'm not sure if this handle correctly constants.
if dx.owner.inputs[0] == dy.owner.inputs[0]: if dx.owner.inputs[0] == dy.owner.inputs[0]:
return True continue
# To be sure to cover all case, call equal_computation. # To be sure to cover all case, call equal_computation.
# Can't use theano.gof.graph.is_same_graph(dx, dy) # Can't use theano.gof.graph.is_same_graph(dx, dy)
# As it currently expect that dx and dy aren't in a FunctionGraph # As it currently expect that dx and dy aren't in a FunctionGraph
from theano.scan_module.scan_utils import equal_computations from theano.scan_module.scan_utils import equal_computations
return equal_computations([dx], [dy]) if not equal_computations([dx], [dy]):
return False
return True
class ShapeOptimizer(Optimizer): class ShapeOptimizer(Optimizer):
...@@ -1431,18 +1435,18 @@ def local_useless_elemwise(node): ...@@ -1431,18 +1435,18 @@ def local_useless_elemwise(node):
return [T.fill(node.inputs[0], return [T.fill(node.inputs[0],
T.constant(1.0, T.constant(1.0,
dtype=node.outputs[0].type.dtype))] dtype=node.outputs[0].type.dtype))]
if node.op.scalar_op == theano.scalar.neq and len(node.inputs) == 2: elif node.op.scalar_op == theano.scalar.neq and len(node.inputs) == 2:
if node.inputs[0] == node.inputs[1]: if node.inputs[0] == node.inputs[1]:
# it is the same var in the graph. That will always be false # it is the same var in the graph. That will always be false
return [T.fill(node.inputs[0], return [T.fill(node.inputs[0],
T.constant(0.0, T.constant(0.0,
dtype=node.outputs[0].type.dtype))] dtype=node.outputs[0].type.dtype))]
if node.op.scalar_op == theano.scalar.mul and len(node.inputs) == 1: elif node.op.scalar_op == theano.scalar.mul and len(node.inputs) == 1:
return [node.inputs[0]] return [node.inputs[0]]
if node.op.scalar_op == theano.scalar.add and len(node.inputs) == 1: elif node.op.scalar_op == theano.scalar.add and len(node.inputs) == 1:
return [node.inputs[0]] return [node.inputs[0]]
if (node.op.scalar_op == theano.scalar.identity elif (node.op.scalar_op == theano.scalar.identity
and len(node.inputs) == 1): and len(node.inputs) == 1):
return [node.inputs[0]] return [node.inputs[0]]
...@@ -1693,7 +1697,7 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1693,7 +1697,7 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
assert_op = node.inputs[assert_op_idx] assert_op = node.inputs[assert_op_idx]
cmp_op = assert_op cmp_op = assert_op
new_i = [] new_i = []
same_shape = node.fgraph.shape_feature.same_shape
for i in node.inputs: for i in node.inputs:
# Remove alloc # Remove alloc
if (i.owner and isinstance(i.owner.op, AllocOP) if (i.owner and isinstance(i.owner.op, AllocOP)
...@@ -1703,7 +1707,7 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1703,7 +1707,7 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
assert i.type.ndim == cmp_op.ndim assert i.type.ndim == cmp_op.ndim
if (theano.config.experimental.local_alloc_elemwise_assert if (theano.config.experimental.local_alloc_elemwise_assert
and not node.fgraph.shape_feature.same_shape(i, cmp_op)): and not same_shape(i, cmp_op)):
assert_op = assert_(assert_op, assert_op = assert_(assert_op,
*[T.eq(i.shape[idx], cmp_op.shape[idx]) *[T.eq(i.shape[idx], cmp_op.shape[idx])
for idx in xrange(i.type.ndim) for idx in xrange(i.type.ndim)
...@@ -1713,12 +1717,13 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1713,12 +1717,13 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# Remove Alloc in DimShuffle # Remove Alloc in DimShuffle
elif i.owner and dimshuffled_alloc(i): elif i.owner and dimshuffled_alloc(i):
assert i.type.ndim == cmp_op.type.ndim assert i.type.ndim == cmp_op.type.ndim
if (theano.config.experimental.local_alloc_elemwise_assert if theano.config.experimental.local_alloc_elemwise_assert:
and not node.fgraph.shape_feature.same_shape(i, cmp_op)): assert_cond = [T.eq(i.shape[idx], cmp_op.shape[idx])
assert_op = assert_(assert_op,
*[T.eq(i.shape[idx], cmp_op.shape[idx])
for idx in xrange(i.type.ndim) for idx in xrange(i.type.ndim)
if not i.type.broadcastable[idx]]) if not i.type.broadcastable[idx] and
not same_shape(i, cmp_op, idx, idx)]
if assert_cond:
assert_op = assert_(assert_op, *assert_cond)
alloc_input = i.owner.inputs[0].owner.inputs[0] alloc_input = i.owner.inputs[0].owner.inputs[0]
if alloc_input.ndim != i.owner.inputs[0].ndim: if alloc_input.ndim != i.owner.inputs[0].ndim:
# The alloc can add dimension to the value # The alloc can add dimension to the value
......
...@@ -2832,7 +2832,11 @@ class Test_local_elemwise_alloc(unittest.TestCase): ...@@ -2832,7 +2832,11 @@ class Test_local_elemwise_alloc(unittest.TestCase):
self.tens = T.tensor3('tens', dtype=self.dtype) self.tens = T.tensor3('tens', dtype=self.dtype)
self.alloc_wo_dep = T.alloc(self.vec, 2, 2) self.alloc_wo_dep = T.alloc(self.vec, 2, 2)
self.alloc_wo_dep_broad = T.alloc(self.vec, 1, 2)
self.alloc_w_dep = T.alloc(self.vec, *self.mat.shape) self.alloc_w_dep = T.alloc(self.vec, *self.mat.shape)
self.alloc_w_dep_broad = T.alloc(self.vec, 1, *self.mat.shape)
self.alloc_w_dep_broad2 = T.alloc(self.vec, self.mat.shape[0],
self.mat.shape[1], 1)
self.alloc_w_dep_tens = T.alloc( self.alloc_w_dep_tens = T.alloc(
self.vec, self.vec,
self.tens.shape[0], self.tens.shape[0],
...@@ -2879,6 +2883,15 @@ class Test_local_elemwise_alloc(unittest.TestCase): ...@@ -2879,6 +2883,15 @@ class Test_local_elemwise_alloc(unittest.TestCase):
self._verify_alloc_count(func, 0) self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 1) self._verify_assert_count(func, 1)
# Optimization on alloc with assert and broadcast
func = function(
[self.vec, self.mat],
self.alloc_wo_dep_broad + self.mat,
mode=self.fast_run_mode
)
self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 1)
# No optimization on alloc without assert # No optimization on alloc without assert
func = function( func = function(
[self.vec, self.mat], [self.vec, self.mat],
...@@ -2897,6 +2910,24 @@ class Test_local_elemwise_alloc(unittest.TestCase): ...@@ -2897,6 +2910,24 @@ class Test_local_elemwise_alloc(unittest.TestCase):
self._verify_alloc_count(func, 0) self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 0) self._verify_assert_count(func, 0)
# Optimization on alloc without assert and with broadcast
func = function(
[self.vec, self.mat],
self.alloc_w_dep_broad + self. mat,
mode=self.fast_run_mode
)
self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 0)
# Not optimized case on alloc and with broadcast
func = function(
[self.vec, self.mat],
self.alloc_w_dep_broad2 + self. mat,
mode=self.fast_run_mode
)
self._verify_alloc_count(func, 1)
self._verify_assert_count(func, 0)
def test_remove_alloc_w_dimshuffle(self): def test_remove_alloc_w_dimshuffle(self):
# No optimization on dimshuffle with assert # No optimization on dimshuffle with assert
func = function( func = function(
...@@ -5016,6 +5047,57 @@ class TestShape_i(utt.InferShapeTester): ...@@ -5016,6 +5047,57 @@ class TestShape_i(utt.InferShapeTester):
[admat_val], Shape_i) [admat_val], Shape_i)
class TestShapeFeature(unittest.TestCase):
    """Unit tests for opt.ShapeFeature.same_shape.

    Each test builds a small graph, attaches a fresh ShapeFeature to a
    FunctionGraph over it, and checks which shape equalities the feature
    can (and cannot) prove.
    """

    def _attach_shape_feature(self, inputs, outputs):
        # Build a FunctionGraph over (inputs, outputs) without cloning,
        # attach a new ShapeFeature to it, and return that feature.
        fgraph = FunctionGraph(inputs, outputs, clone=False)
        feature = opt.ShapeFeature()
        fgraph.attach_feature(feature)
        return feature

    def test_scalar(self):
        # Adding a constant to a scalar preserves its (empty) shape.
        x = scalar()
        o = x + T.constant(1).clone()
        feature = self._attach_shape_feature([x], [o])
        assert feature.same_shape(x, o)

    def test_vector(self):
        # Adding a broadcasted constant to a vector preserves its shape.
        x = vector()
        o = x + T.constant(1).clone()
        feature = self._attach_shape_feature([x], [o])
        assert feature.same_shape(x, o)

    def test_vector2(self):
        x = vector()
        y = vector()
        o = x + y
        feature = self._attach_shape_feature([x, y], [o])
        assert feature.same_shape(x, o)
        # The following case isn't implemented
        assert not feature.same_shape(y, o)

    def test_vector_dim(self):
        # same_shape restricted to a single dimension of each variable.
        x = vector()
        y = vector()
        o = x + y
        feature = self._attach_shape_feature([x, y], [o])
        assert feature.same_shape(x, o, 0, 0)
        # The following case isn't implemented
        assert not feature.same_shape(y, o, 0, 0)

    def test_vector_dim_err(self):
        # Out-of-range dimension indices must raise IndexError.
        x = vector()
        y = vector()
        o = x + y
        feature = self._attach_shape_feature([x, y], [o])
        self.assertRaises(IndexError, feature.same_shape, x, o, 1, 0)
        self.assertRaises(IndexError, feature.same_shape, x, o, 0, 1)
if __name__ == '__main__': if __name__ == '__main__':
t = TestMakeVector('setUp') t = TestMakeVector('setUp')
t.setUp() t.setUp()
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论