提交 bdcad1da authored 作者: Frederic Bastien's avatar Frederic Bastien

fix a bug in theano.get_constant_value and add test for it.

上级 43c6f89e
...@@ -372,7 +372,7 @@ def get_constant_value(v): ...@@ -372,7 +372,7 @@ def get_constant_value(v):
return ret[0][0] return ret[0][0]
if isinstance(v.owner.op, Subtensor) and v.ndim==0: if isinstance(v.owner.op, Subtensor) and v.ndim==0:
if isinstance(v.owner.inputs[0], TensorConstant): if isinstance(v.owner.inputs[0], TensorConstant):
return v.owner.inputs[0].data[v.owner.op.idx_list[0]] return v.owner.inputs[0].data.__getitem__(tuple(v.owner.op.idx_list))
# The index list 'idx_list' should have length one # The index list 'idx_list' should have length one
# since joining scalar variables results in a 1D vector. # since joining scalar variables results in a 1D vector.
......
...@@ -3575,6 +3575,15 @@ class T_get_constant_value(unittest.TestCase): ...@@ -3575,6 +3575,15 @@ class T_get_constant_value(unittest.TestCase):
v = tensor.row() v = tensor.row()
assert get_constant_value(v.shape[0])==1 assert get_constant_value(v.shape[0])==1
def test_subtensor_of_constant(self):
c = constant(numpy.random.rand(5))
for i in range(c.value.shape[0]):
assert get_constant_value(c[i]) == c.value[i]
c = constant(numpy.random.rand(5,5))
for i in range(c.value.shape[0]):
for j in range(c.value.shape[1]):
assert get_constant_value(c[i,j]) == c.value[i,j]
if __name__ == '__main__': if __name__ == '__main__':
if 1: if 1:
unittest.main() unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论