提交 2688f9f8 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #1644 from nouiz/gh-1122

Gh 1122
...@@ -566,6 +566,9 @@ def get_scalar_constant_value(v): ...@@ -566,6 +566,9 @@ def get_scalar_constant_value(v):
if isinstance(v.owner.op, scal.Second): if isinstance(v.owner.op, scal.Second):
x, y = v.owner.inputs x, y = v.owner.inputs
return get_scalar_constant_value(y) return get_scalar_constant_value(y)
if (isinstance(v.owner.op, theano.compile.ops.Shape_i) and
isinstance(v.owner.inputs[0], Constant)):
return v.owner.inputs[0].data.shape[v.owner.op.i]
# Don't act as the constant_folding optimization here as this # Don't act as the constant_folding optimization here as this
# fct is used too early in the optimization phase. This would # fct is used too early in the optimization phase. This would
# mess with the stabilization optimization. # mess with the stabilization optimization.
......
...@@ -731,7 +731,12 @@ class ShapeFeature(object): ...@@ -731,7 +731,12 @@ class ShapeFeature(object):
return self.lscalar_one return self.lscalar_one
else: else:
# Do not call make_node for test_value # Do not call make_node for test_value
return Shape_i(i)(r) s = Shape_i(i)(r)
try:
s = get_scalar_constant_value(s)
except NotScalarConstantError:
pass
return s
def shape_tuple(self, r): def shape_tuple(self, r):
"""Return a tuple of symbolic shape vars for tensor variable r""" """Return a tuple of symbolic shape vars for tensor variable r"""
......
...@@ -5921,6 +5921,13 @@ class T_get_scalar_constant_value(unittest.TestCase): ...@@ -5921,6 +5921,13 @@ class T_get_scalar_constant_value(unittest.TestCase):
get_scalar_constant_value, get_scalar_constant_value,
mv[t()]) mv[t()])
def test_shape_i(self):
    """get_scalar_constant_value should resolve Shape_i applied to a
    tensor constant, returning the constant's actual dimension size.
    """
    c = theano.tensor.constant(numpy.random.rand(3, 4))
    # Check each axis of the (3, 4) constant in turn.
    for axis, expected in enumerate((3, 4)):
        shape_var = opt.Shape_i(axis)(c)
        assert get_scalar_constant_value(shape_var) == expected
class T_as_tensor_variable(unittest.TestCase): class T_as_tensor_variable(unittest.TestCase):
""" """
......
...@@ -236,12 +236,12 @@ class test_canonize(unittest.TestCase): ...@@ -236,12 +236,12 @@ class test_canonize(unittest.TestCase):
fyv = theano._asarray(numpy.random.rand(*shp), dtype='float32') fyv = theano._asarray(numpy.random.rand(*shp), dtype='float32')
fzv = theano._asarray(numpy.random.rand(*shp), dtype='float32') fzv = theano._asarray(numpy.random.rand(*shp), dtype='float32')
fvv = theano._asarray(numpy.random.rand(shp[0]), dtype= fvv = theano._asarray(numpy.random.rand(shp[0]), dtype=
'float32').reshape(1, shp[0]) 'float32').reshape(1, shp[0])
dxv = theano._asarray(numpy.random.rand(*shp), dtype='float64') dxv = theano._asarray(numpy.random.rand(*shp), dtype='float64')
dyv = theano._asarray(numpy.random.rand(*shp), dtype='float64') dyv = theano._asarray(numpy.random.rand(*shp), dtype='float64')
dzv = theano._asarray(numpy.random.rand(*shp), dtype='float64') dzv = theano._asarray(numpy.random.rand(*shp), dtype='float64')
dvv = theano._asarray(numpy.random.rand(shp[0]), dtype= dvv = theano._asarray(numpy.random.rand(shp[0]), dtype=
'float64').reshape(1, shp[0]) 'float64').reshape(1, shp[0])
cases = [ cases = [
(fx + fy, (fx, fy), (fxv, fyv), 1, 'float32'), (fx + fy, (fx, fy), (fxv, fyv), 1, 'float32'),
(fx * fy, (fx, fy), (fxv, fyv), 1, 'float32'), (fx * fy, (fx, fy), (fxv, fyv), 1, 'float32'),
...@@ -2571,6 +2571,15 @@ class test_shapeoptimizer(unittest.TestCase): ...@@ -2571,6 +2571,15 @@ class test_shapeoptimizer(unittest.TestCase):
f = theano.function([], out, mode=mode) f = theano.function([], out, mode=mode)
f() f()
def test_constant_merge(self):
    """Regression test for the error in gh-1122, which was caused by the
    combination of the merge optimizer and ShapeFeature.
    """
    const = tensor.constant([0, 0])
    tail = const[1:]
    # Evaluating this graph used to fail; success (no exception) is the test.
    result = const - tensor.join(0, tail, tail)
    result.eval()
def test_local_track_shape_i(self): def test_local_track_shape_i(self):
class IdentityNoShape(gof.Op): class IdentityNoShape(gof.Op):
'''Op that does not infer the output shape from the input one''' '''Op that does not infer the output shape from the input one'''
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论