提交 2093634f authored 作者: Frederic's avatar Frederic

also support numpy.integer

上级 ed199370
...@@ -619,7 +619,8 @@ def get_scalar_constant_value(v): ...@@ -619,7 +619,8 @@ def get_scalar_constant_value(v):
v.owner.inputs[0].owner.inputs) and v.owner.inputs[0].owner.inputs) and
len(v.owner.op.idx_list) == 1 and len(v.owner.op.idx_list) == 1 and
#idx_list can contain Scalar Type object. #idx_list can contain Scalar Type object.
isinstance(v.owner.op.idx_list[0], (int, long))): isinstance(v.owner.op.idx_list[0], (int, long,
numpy.integer))):
ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]] ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]]
ret = get_scalar_constant_value(ret) ret = get_scalar_constant_value(ret)
# MakeVector can cast implicitly its input in some case. # MakeVector can cast implicitly its input in some case.
......
...@@ -6736,6 +6736,9 @@ class T_get_scalar_constant_value(unittest.TestCase): ...@@ -6736,6 +6736,9 @@ class T_get_scalar_constant_value(unittest.TestCase):
assert get_scalar_constant_value(mv[0]) == 1 assert get_scalar_constant_value(mv[0]) == 1
assert get_scalar_constant_value(mv[1]) == 2 assert get_scalar_constant_value(mv[1]) == 2
assert get_scalar_constant_value(mv[2]) == 3 assert get_scalar_constant_value(mv[2]) == 3
assert get_scalar_constant_value(mv[numpy.int8(0)]) == 1
assert get_scalar_constant_value(mv[numpy.int64(1)]) == 2
assert get_scalar_constant_value(mv[numpy.uint(2)]) == 3
t = theano.scalar.Scalar('int64') t = theano.scalar.Scalar('int64')
self.assertRaises( self.assertRaises(
tensor.NotScalarConstantError, tensor.NotScalarConstantError,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论