提交 7bb18f3a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix numpy integer check in TensorType

Bug introduced in 6834740a
上级 6834740a
...@@ -109,7 +109,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -109,7 +109,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
def parse_bcast_and_shape(s): def parse_bcast_and_shape(s):
if isinstance(s, (bool, np.bool_)): if isinstance(s, (bool, np.bool_)):
return 1 if s else None return 1 if s else None
elif isinstance(s, (int, np.int_)): elif isinstance(s, (int, np.integer)):
return int(s) return int(s)
elif s is None: elif s is None:
return s return s
......
...@@ -274,6 +274,12 @@ def test_shape_type_conversion(): ...@@ -274,6 +274,12 @@ def test_shape_type_conversion():
assert t1.broadcastable == (False,) assert t1.broadcastable == (False,)
assert isinstance(t1.broadcastable[0], bool) assert isinstance(t1.broadcastable[0], bool)
t1 = TensorType("float64", shape=np.array([3], dtype=np.int32))
assert t1.shape == (3,)
assert isinstance(t1.shape[0], int)
assert t1.broadcastable == (False,)
assert isinstance(t1.broadcastable[0], bool)
t2 = TensorType("float64", broadcastable=np.array([True, False], dtype="bool")) t2 = TensorType("float64", broadcastable=np.array([True, False], dtype="bool"))
assert t2.shape == (1, None) assert t2.shape == (1, None)
assert isinstance(t2.shape[0], int) assert isinstance(t2.shape[0], int)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论