提交 8a0ec6fc authored 作者: Frederic's avatar Frederic

Make get_scalar_constant_value() support Shape_i()(Constant)

上级 8d790f5c
...@@ -566,6 +566,9 @@ def get_scalar_constant_value(v): ...@@ -566,6 +566,9 @@ def get_scalar_constant_value(v):
if isinstance(v.owner.op, scal.Second): if isinstance(v.owner.op, scal.Second):
x, y = v.owner.inputs x, y = v.owner.inputs
return get_scalar_constant_value(y) return get_scalar_constant_value(y)
if (isinstance(v.owner.op, theano.compile.ops.Shape_i) and
isinstance(v.owner.inputs[0], Constant)):
return v.owner.inputs[0].data.shape[v.owner.op.i]
# Don't act as the constant_folding optimization here as this # Don't act as the constant_folding optimization here as this
# fct is used too early in the optimization phase. This would # fct is used too early in the optimization phase. This would
# mess with the stabilization optimization. # mess with the stabilization optimization.
......
...@@ -5921,6 +5921,13 @@ class T_get_scalar_constant_value(unittest.TestCase): ...@@ -5921,6 +5921,13 @@ class T_get_scalar_constant_value(unittest.TestCase):
get_scalar_constant_value, get_scalar_constant_value,
mv[t()]) mv[t()])
def test_shape_i(self):
c = theano.tensor.constant(numpy.random.rand(3, 4))
s = opt.Shape_i(0)(c)
assert get_scalar_constant_value(s) == 3
s = opt.Shape_i(1)(c)
assert get_scalar_constant_value(s) == 4
class T_as_tensor_variable(unittest.TestCase): class T_as_tensor_variable(unittest.TestCase):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论