提交 1e52f7ea authored 作者: David Warde-Farley's avatar David Warde-Farley

Port forward my cast tester that I actually quashed in a merge.

上级 1d827d8e
......@@ -819,6 +819,8 @@ def combinations(iterable, r):
ALL_DTYPES = ('int8', 'int16', 'int32', 'int64',
'float32', 'float64', 'complex64', 'complex128')
REAL_DTYPES = ALL_DTYPES[:-2]
COMPLEX_DTYPES = ALL_DTYPES[-2:]
def rand_of_dtype(shape, dtype):
if 'int' in dtype:
......@@ -882,18 +884,35 @@ SecondSameRankTester = makeTester(
))
)
CastTester = makeTester(
name='CastTester',
op=cast,
expected=lambda x, y: x.astype(y),
good=dict(itertools.chain(
multi_dtype_cast_checks((2,)),
class CastTester(unittest.TestCase):
def test_good_between_real_types(self):
expected = lambda x, y: x.astype(y),
good = itertools.chain(
multi_dtype_cast_checks((2,), dtypes=REAL_DTYPES),
# Casts from foo to foo
[('%s_%s' % (rand_of_dtype((2,), dtype), dtype),
(rand_of_dtype((2,), dtype), dtype))
for dtype in ALL_DTYPES]
)),
)
for dtype in ALL_DTYPES])
for testname, (obj, dtype) in good:
inp = tensor.vector(dtype=obj.dtype)
out = tensor.cast(inp, dtype=dtype)
f = function([inp], out)
assert f(obj).dtype == numpy.dtype(dtype)
def test_cast_from_real_to_complex(self):
for real_dtype in REAL_DTYPES:
for complex_dtype in COMPLEX_DTYPES:
inp = tensor.vector(dtype=real_dtype)
out = tensor.cast(inp, dtype=complex_dtype)
f = function([inp], out)
obj = rand_of_dtype((2, ), real_dtype)
assert f(obj).dtype == numpy.dtype(complex_dtype)
def test_cast_from_complex_to_real_raises_error(self):
for real_dtype in REAL_DTYPES:
for complex_dtype in COMPLEX_DTYPES:
inp = tensor.vector(dtype=real_dtype)
self.assertRaises(TypeError, tensor.cast(inp, dtype=complex_dtype))
ClipTester = makeTester(name='ClipTester',
op=clip,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论