提交 23d1b9e3 authored 作者: David Warde-Farley's avatar David Warde-Farley

Merge.

......@@ -806,9 +806,9 @@ ClipTester = makeTester(name='ClipTester',
-1, 1),
correct4=(randint(5, 5).astype('int16'),
-1, 1),
correct4=(randint(5, 5).astype('int32'),
correct5=(randint(5, 5).astype('int32'),
-1, 1),
correct5=(randint(5, 5).astype('int64'),
correct6=(randint(5, 5).astype('int64'),
-1, 1)),
# These don't build -- is this equivalent to marking
# them as 'known fail'?
......@@ -853,23 +853,31 @@ ALL_DTYPES = ('int8', 'int16', 'int32', 'int64',
def rand_of_dtype(shape, dtype):
if 'int' in dtype:
return randint(shape).astype(dtype)
return randint(*shape).astype(dtype)
elif 'float' in dtype:
return rand(shape).astype(dtype)
return rand(*shape).astype(dtype)
elif 'complex' in dtype:
return randcomplex(shape).astype(dtype)
return randcomplex(*shape).astype(dtype)
else:
raise TypeError()
def multi_dtype_tests(shape1, shape2, dtypes=ALL_DTYPES, nameprefix=''):
for dtype1, dtype2 in combinations(dtypes, 2):
name1 = '%s_%s' % (nameprefix, dtype1, dtype2)
name2 = '%s_%s' % (nameprefix, dtype2, dtype1)
name1 = '%s_%s_%s' % (nameprefix, dtype1, dtype2)
name2 = '%s_%s_%s' % (nameprefix, dtype2, dtype1)
obj1 = rand_of_dtype(shape1, dtype1)
obj2 = rand_of_dtype(shape2, dtype2)
yield (name, (obj1, obj2))
yield (name, (obj2, obj1))
def multi_dtype_cast_tests(shape, dtypes=ALL_DTYPES, nameprefix=''):
for dtype1, dtype2 in combinations(dtypes, 2):
name1 = '%s_%s_%s' % (nameprefix, dtype1, dtype2)
name2 = '%s_%s_%s' % (nameprefix, dtype2, dtype1)
obj1 = rand_of_dtype(shape, dtype1)
yield (name, (obj1, dtype2))
yield (name, (obj2, dtype1))
SecondBroadcastTester = makeTester(
name='SecondBroadcastTester',
op=second,
......@@ -903,6 +911,22 @@ SecondSameRankTester = makeTester(
bad_runtime=None
)
CastTester = makeTester(
name='CastTester',
op=cast,
expected=lambda x, y: x.astype(y),
good=dict(itertools.chain(
multi_dtype_cast_tests((2,)),
[('%s_%s' % (dtype, dtype), rand_of_dtype((2,), dtype), dtype)
for dtype in ALL_DTYPES]
)),
bad_build=dict(
fail_not_a_real_dtype=((2,), 'blah')
),
bad_runtime=None
)
#TODO: consider moving this function / functionality to gradient.py
# rationale: it's tricky, and necessary everytime you want to verify
# gradient numerically
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论