提交 d8a60f88 authored 作者: Frederic's avatar Frederic

Fix opt error when trying to access a python list with a Scalar Type object.

fixes gh-1288
上级 dc5617c1
...@@ -609,6 +609,7 @@ def get_scalar_constant_value(v): ...@@ -609,6 +609,7 @@ def get_scalar_constant_value(v):
ret = get_scalar_constant_value(ret) ret = get_scalar_constant_value(ret)
# join can cast implicitly its input in some case. # join can cast implicitly its input in some case.
return theano._asarray(ret, dtype=v.type.dtype) return theano._asarray(ret, dtype=v.type.dtype)
if (v.owner.inputs[0].owner and if (v.owner.inputs[0].owner and
isinstance(v.owner.inputs[0].owner.op, isinstance(v.owner.inputs[0].owner.op,
theano.tensor.opt.MakeVector) and theano.tensor.opt.MakeVector) and
...@@ -616,8 +617,9 @@ def get_scalar_constant_value(v): ...@@ -616,8 +617,9 @@ def get_scalar_constant_value(v):
# We put this check in case there is change in the future # We put this check in case there is change in the future
python_all(var.ndim == 0 for var in python_all(var.ndim == 0 for var in
v.owner.inputs[0].owner.inputs) and v.owner.inputs[0].owner.inputs) and
len(v.owner.op.idx_list) == 1): len(v.owner.op.idx_list) == 1 and
#idx_list can contain Scalar Type object.
isinstance(v.owner.op.idx_list[0], (int, long))):
ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]] ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]]
ret = get_scalar_constant_value(ret) ret = get_scalar_constant_value(ret)
# MakeVector can cast implicitly its input in some case. # MakeVector can cast implicitly its input in some case.
......
...@@ -6664,6 +6664,21 @@ class T_get_scalar_constant_value(unittest.TestCase): ...@@ -6664,6 +6664,21 @@ class T_get_scalar_constant_value(unittest.TestCase):
get_scalar_constant_value, get_scalar_constant_value,
numpy.array([])) numpy.array([]))
def test_make_vector(self):
mv = opt.make_vector(1, 2, 3)
self.assertRaises(
tensor.NotScalarConstantError,
get_scalar_constant_value,
mv)
assert get_scalar_constant_value(mv[0])
assert get_scalar_constant_value(mv[1])
assert get_scalar_constant_value(mv[2])
t = theano.scalar.Scalar('int64')
self.assertRaises(
tensor.NotScalarConstantError,
get_scalar_constant_value,
mv[t()])
class T_as_tensor_variable(unittest.TestCase): class T_as_tensor_variable(unittest.TestCase):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论