Commit 2688f9f8 authored by Pascal Lamblin

Merge pull request #1644 from nouiz/gh-1122

Gh 1122
......@@ -566,6 +566,9 @@ def get_scalar_constant_value(v):
if isinstance(v.owner.op, scal.Second):
x, y = v.owner.inputs
return get_scalar_constant_value(y)
if (isinstance(v.owner.op, theano.compile.ops.Shape_i) and
isinstance(v.owner.inputs[0], Constant)):
return v.owner.inputs[0].data.shape[v.owner.op.i]
# Don't act as the constant_folding optimization here as this
# fct is used too early in the optimization phase. This would
# mess with the stabilization optimization.
......
......@@ -731,7 +731,12 @@ class ShapeFeature(object):
return self.lscalar_one
else:
# Do not call make_node for test_value
return Shape_i(i)(r)
s = Shape_i(i)(r)
try:
s = get_scalar_constant_value(s)
except NotScalarConstantError:
pass
return s
def shape_tuple(self, r):
"""Return a tuple of symbolic shape vars for tensor variable r"""
......
......@@ -5921,6 +5921,13 @@ class T_get_scalar_constant_value(unittest.TestCase):
get_scalar_constant_value,
mv[t()])
def test_shape_i(self):
    """Shape_i applied to a constant tensor must be resolvable by
    get_scalar_constant_value (each axis yields its static length)."""
    const_mat = theano.tensor.constant(numpy.random.rand(3, 4))
    for axis, expected_len in ((0, 3), (1, 4)):
        shape_var = opt.Shape_i(axis)(const_mat)
        assert get_scalar_constant_value(shape_var) == expected_len
class T_as_tensor_variable(unittest.TestCase):
"""
......
......@@ -2571,6 +2571,15 @@ class test_shapeoptimizer(unittest.TestCase):
f = theano.function([], out, mode=mode)
f()
def test_constant_merge(self):
    """Regression test for gh-1122: the combination of the merge
    optimizer and ShapeFeature used to raise an error.
    """
    const_vec = tensor.constant([0, 0])
    tail = const_vec[1:]
    result = const_vec - tensor.join(0, tail, tail)
    # Compiling and evaluating must simply not raise.
    result.eval()
def test_local_track_shape_i(self):
class IdentityNoShape(gof.Op):
'''Op that does not infer the output shape from the input one'''
......
Markdown is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment