提交 5555f7ab authored 作者: David Warde-Farley's avatar David Warde-Farley

Rewrite CastTester as a custom Tester class.

上级 39368a61
...@@ -819,6 +819,8 @@ def combinations(iterable, r): ...@@ -819,6 +819,8 @@ def combinations(iterable, r):
ALL_DTYPES = ('int8', 'int16', 'int32', 'int64', ALL_DTYPES = ('int8', 'int16', 'int32', 'int64',
'float32', 'float64', 'complex64', 'complex128') 'float32', 'float64', 'complex64', 'complex128')
REAL_DTYPES = ALL_DTYPES[:-2]
COMPLEX_DTYPES = ALL_DTYPES[-2:]
def rand_of_dtype(shape, dtype): def rand_of_dtype(shape, dtype):
if 'int' in dtype: if 'int' in dtype:
...@@ -882,18 +884,35 @@ SecondSameRankTester = makeTester( ...@@ -882,18 +884,35 @@ SecondSameRankTester = makeTester(
)) ))
) )
CastTester = makeTester( class CastTester(unittest.TestCase):
name='CastTester', def test_good_between_real_types(self):
op=cast, expected = lambda x, y: x.astype(y),
expected=lambda x, y: x.astype(y), good = itertools.chain(
good=dict(itertools.chain( multi_dtype_cast_checks((2,), dtypes=REAL_DTYPES),
multi_dtype_cast_checks((2,)),
# Casts from foo to foo # Casts from foo to foo
[('%s_%s' % (rand_of_dtype((2,), dtype), dtype), [('%s_%s' % (rand_of_dtype((2,), dtype), dtype),
(rand_of_dtype((2,), dtype), dtype)) (rand_of_dtype((2,), dtype), dtype))
for dtype in ALL_DTYPES] for dtype in ALL_DTYPES])
)), for testname, (obj, dtype) in good:
) inp = tensor.vector(dtype=obj.dtype)
out = tensor.cast(inp, dtype=dtype)
f = function([inp], out)
assert f(obj).dtype == numpy.dtype(dtype)
def test_cast_from_real_to_complex(self):
for real_dtype in REAL_DTYPES:
for complex_dtype in COMPLEX_DTYPES:
inp = tensor.vector(dtype=real_dtype)
out = tensor.cast(inp, dtype=complex_dtype)
f = function([inp], out)
obj = rand_of_dtype((2, ), real_dtype)
assert f(obj).dtype == numpy.dtype(complex_dtype)
def test_cast_from_complex_to_real_raises_error(self):
for real_dtype in REAL_DTYPES:
for complex_dtype in COMPLEX_DTYPES:
inp = tensor.vector(dtype=real_dtype)
self.assertRaises(TypeError, tensor.cast(inp, dtype=complex_dtype))
ClipTester = makeTester(name='ClipTester', ClipTester = makeTester(name='ClipTester',
op=clip, op=clip,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论