提交 f16f8559 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix TensorType dtype validation issues and refactor tests

上级 12765623
...@@ -92,9 +92,12 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -92,9 +92,12 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
) )
shape = broadcastable shape = broadcastable
if isinstance(dtype, str) and dtype == "floatX": if dtype == "floatX":
self.dtype = config.floatX self.dtype = config.floatX
else: else:
if np.obj2sctype(dtype) is None:
raise TypeError(f"Invalid dtype: {dtype}")
self.dtype = np.dtype(dtype).name self.dtype = np.dtype(dtype).name
def parse_bcast_and_shape(s): def parse_bcast_and_shape(s):
......
...@@ -1069,8 +1069,9 @@ class TestCast: ...@@ -1069,8 +1069,9 @@ class TestCast:
f = function([x], y) f = function([x], y)
assert f(np.array([1, 2], dtype=np.int32)).dtype == np.int64 assert f(np.array([1, 2], dtype=np.int32)).dtype == np.int64
def test_good_between_real_types(self): @pytest.mark.parametrize(
good = itertools.chain( "test_name, obj_dtype",
itertools.chain(
multi_dtype_cast_checks((2,), dtypes=REAL_DTYPES), multi_dtype_cast_checks((2,), dtypes=REAL_DTYPES),
# Casts from foo to foo # Casts from foo to foo
[ [
...@@ -1080,8 +1081,10 @@ class TestCast: ...@@ -1080,8 +1081,10 @@ class TestCast:
) )
for dtype in ALL_DTYPES for dtype in ALL_DTYPES
], ],
),
) )
for testname, (obj, dtype) in good: def test_good_between_real_types(self, test_name, obj_dtype):
(obj, dtype) = obj_dtype
inp = vector(dtype=obj.dtype) inp = vector(dtype=obj.dtype)
out = cast(inp, dtype=dtype) out = cast(inp, dtype=dtype)
f = function([inp], out) f = function([inp], out)
...@@ -1091,21 +1094,21 @@ class TestCast: ...@@ -1091,21 +1094,21 @@ class TestCast:
out2 = inp.astype(dtype=dtype) out2 = inp.astype(dtype=dtype)
assert out2.type == out.type assert out2.type == out.type
def test_cast_from_real_to_complex(self): @pytest.mark.parametrize("real_dtype", REAL_DTYPES)
for real_dtype in REAL_DTYPES: @pytest.mark.parametrize("complex_dtype", COMPLEX_DTYPES)
for complex_dtype in COMPLEX_DTYPES: def test_cast_from_real_to_complex(self, real_dtype, complex_dtype):
inp = vector(dtype=real_dtype) inp = vector(dtype=real_dtype)
out = cast(inp, dtype=complex_dtype) out = cast(inp, dtype=complex_dtype)
f = function([inp], out) f = function([inp], out)
obj = random_of_dtype((2,), real_dtype) obj = random_of_dtype((2,), real_dtype)
assert f(obj).dtype == np.dtype(complex_dtype) assert f(obj).dtype == np.dtype(complex_dtype)
def test_cast_from_complex_to_real_raises_error(self): @pytest.mark.parametrize("real_dtype", REAL_DTYPES)
for real_dtype in REAL_DTYPES: @pytest.mark.parametrize("complex_dtype", COMPLEX_DTYPES)
for complex_dtype in COMPLEX_DTYPES: def test_cast_from_complex_to_real_raises_error(self, real_dtype, complex_dtype):
inp = vector(dtype=real_dtype) inp = vector(dtype=complex_dtype)
with pytest.raises(TypeError): with pytest.raises(TypeError):
tensor(cast(inp, dtype=complex_dtype)) cast(inp, dtype=real_dtype)
# TODO: consider moving this function / functionality to gradient.py # TODO: consider moving this function / functionality to gradient.py
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论