提交 d8a60f88 authored 作者: Frederic's avatar Frederic

Fix opt error when trying to access a python list with a Scalar Type object.

fixes gh-1288
上级 dc5617c1
......@@ -609,6 +609,7 @@ def get_scalar_constant_value(v):
ret = get_scalar_constant_value(ret)
# join can cast implicitly its input in some case.
return theano._asarray(ret, dtype=v.type.dtype)
if (v.owner.inputs[0].owner and
isinstance(v.owner.inputs[0].owner.op,
theano.tensor.opt.MakeVector) and
......@@ -616,8 +617,9 @@ def get_scalar_constant_value(v):
# We put this check in case there is change in the future
python_all(var.ndim == 0 for var in
v.owner.inputs[0].owner.inputs) and
len(v.owner.op.idx_list) == 1):
len(v.owner.op.idx_list) == 1 and
#idx_list can contain Scalar Type object.
isinstance(v.owner.op.idx_list[0], (int, long))):
ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]]
ret = get_scalar_constant_value(ret)
# MakeVector can cast implicitly its input in some case.
......
......@@ -6664,6 +6664,21 @@ class T_get_scalar_constant_value(unittest.TestCase):
get_scalar_constant_value,
numpy.array([]))
def test_make_vector(self):
mv = opt.make_vector(1, 2, 3)
self.assertRaises(
tensor.NotScalarConstantError,
get_scalar_constant_value,
mv)
assert get_scalar_constant_value(mv[0])
assert get_scalar_constant_value(mv[1])
assert get_scalar_constant_value(mv[2])
t = theano.scalar.Scalar('int64')
self.assertRaises(
tensor.NotScalarConstantError,
get_scalar_constant_value,
mv[t()])
class T_as_tensor_variable(unittest.TestCase):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论