提交 36e04bca authored 作者: David Warde-Farley's avatar David Warde-Farley

Tests for cast.

上级 3f03e996
......@@ -870,6 +870,14 @@ def multi_dtype_tests(shape1, shape2, dtypes=ALL_DTYPES, nameprefix=''):
yield (name, (obj1, obj2))
yield (name, (obj2, obj1))
def multi_dtype_cast_tests(shape, dtypes=ALL_DTYPES, nameprefix=''):
for dtype1, dtype2 in combinations(dtypes, 2):
name1 = '%s_%s_%s' % (nameprefix, dtype1, dtype2)
name2 = '%s_%s_%s' % (nameprefix, dtype2, dtype1)
obj1 = rand_of_dtype(shape, dtype1)
yield (name, (obj1, dtype2))
yield (name, (obj2, dtype1))
SecondBroadcastTester = makeTester(
name='SecondBroadcastTester',
op=second,
......@@ -903,6 +911,22 @@ SecondSameRankTester = makeTester(
bad_runtime=None
)
CastTester = makeTester(
name='CastTester',
op=cast,
expected=lambda x, y: x.astype(y),
good=dict(itertools.chain(
multi_dtype_cast_tests((2,)),
[('%s_%s' % (dtype, dtype), rand_of_dtype((2,), dtype), dtype)
for dtype in ALL_DTYPES]
)),
bad_build=dict(
fail_not_a_real_dtype=((2,), 'blah')
),
bad_runtime=None
)
#TODO: consider moving this function / functionality to gradient.py
# rationale: it's tricky, and necessary everytime you want to verify
# gradient numerically
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论