提交 8867a720 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix handling of numpy.dtype input in TensorType

上级 813d0409
...@@ -92,7 +92,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -92,7 +92,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
) )
shape = broadcastable shape = broadcastable
if dtype == "floatX": if str(dtype) == "floatX":
self.dtype = config.floatX self.dtype = config.floatX
else: else:
if np.obj2sctype(dtype) is None: if np.obj2sctype(dtype) is None:
......
...@@ -10,9 +10,18 @@ from aesara.tensor.shape import SpecifyShape ...@@ -10,9 +10,18 @@ from aesara.tensor.shape import SpecifyShape
from aesara.tensor.type import TensorType from aesara.tensor.type import TensorType
def test_numpy_dtype(): @pytest.mark.parametrize(
test_type = TensorType(np.int32, []) "dtype, exp_dtype",
assert test_type.dtype == "int32" [
(np.int32, "int32"),
(np.dtype(np.int32), "int32"),
("int32", "int32"),
("floatX", config.floatX),
],
)
def test_numpy_dtype(dtype, exp_dtype):
test_type = TensorType(dtype, [])
assert test_type.dtype == exp_dtype
def test_in_same_class(): def test_in_same_class():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论