提交 bdcad1da authored 作者: Frederic Bastien's avatar Frederic Bastien

fix a bug in theano.get_constant_value and add test for it.

上级 43c6f89e
......@@ -372,7 +372,7 @@ def get_constant_value(v):
return ret[0][0]
if isinstance(v.owner.op, Subtensor) and v.ndim==0:
if isinstance(v.owner.inputs[0], TensorConstant):
return v.owner.inputs[0].data[v.owner.op.idx_list[0]]
return v.owner.inputs[0].data.__getitem__(tuple(v.owner.op.idx_list))
# The index list 'idx_list' should have length one
# since joining scalar variables results in a 1D vector.
......
......@@ -3575,6 +3575,15 @@ class T_get_constant_value(unittest.TestCase):
v = tensor.row()
assert get_constant_value(v.shape[0])==1
def test_subtensor_of_constant(self):
c = constant(numpy.random.rand(5))
for i in range(c.value.shape[0]):
assert get_constant_value(c[i]) == c.value[i]
c = constant(numpy.random.rand(5,5))
for i in range(c.value.shape[0]):
for j in range(c.value.shape[1]):
assert get_constant_value(c[i,j]) == c.value[i,j]
if __name__ == '__main__':
if 1:
unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论