提交 db14207b authored 作者: David Warde-Farley's avatar David Warde-Farley

Tests for second(), and some support functions to make that happen.

上级 151a5192
import itertools
import operator
import StringIO
import sys
......@@ -819,6 +820,88 @@ ClipTester = makeTester(name='ClipTester',
# I can't think of any way to make this fail at runtime
bad_runtime=dict())
def _numpy_second(x, y):
if x.ndim != y.ndim:
return broadcast_arrays(x, y)[1]
else:
return y
def combinations(iterable, r):
# Borrowed from Python docs - can be removed when we drop
# support for 2.4/2.5
# combinations('ABCD', 2) --> AB AC AD BC BD CD
# combinations(range(4), 3) --> 012 013 023 123
pool = tuple(iterable)
n = len(pool)
if r > n:
return
indices = range(r)
yield tuple(pool[i] for i in indices)
while True:
for i in reversed(range(r)):
if indices[i] != i + n - r:
break
else:
return
indices[i] += 1
for j in range(i+1, r):
indices[j] = indices[j-1] + 1
yield tuple(pool[i] for i in indices)
ALL_DTYPES = ('int8', 'int16', 'int32', 'int64',
'float32', 'float64', 'complex64', 'complex128')
def rand_of_dtype(shape, dtype):
if 'int' in dtype:
return randint(shape).astype(dtype)
elif 'float' in dtype:
return rand(shape).astype(dtype)
elif 'complex' in dtype:
return randcomplex(shape).astype(dtype)
else:
raise TypeError()
def multi_dtype_tests(shape1, shape2, dtypes=ALL_DTYPES, nameprefix=''):
for dtype1, dtype2 in combinations(dtypes, 2):
name1 = '%s_%s' % (nameprefix, dtype1, dtype2)
name2 = '%s_%s' % (nameprefix, dtype2, dtype1)
obj1 = rand_of_dtype(shape1, dtype1)
obj2 = rand_of_dtype(shape2, dtype2)
yield (name, (obj1, obj2))
yield (name, (obj2, obj1))
SecondBroadcastTester = makeTester(
name='SecondBroadcastTester',
op=second,
expected=_numpy_second,
good=dict(itertools.chain(
multi_dtype_tests((4, 5), (5,)),
multi_dtype_tests((2, 3, 2), (3, 2)),
multi_dtype_tests((2, 3, 2), (2,)),
)),
# I can't think of any way to make this fail at
# build time
bad_build=None,
# Just some simple smoke tests
bad_runtime=dict(
fail1=(rand(5, 4), rand(5)),
fail2=(rand(3, 2, 3), rand(6, 9)),
fail3=(randint(6, 2), rand(3, 2)),
)
)
SecondSameRankTester = makeTester(
name='SecondSameRankTester',
op=second,
expected=_numpy_second,
good=dict(itertools.chain(
multi_dtype_tests((4, 5), (4, 5)),
multi_dtype_tests((5, 4), (4, 5)),
multi_dtype_tests((1, 4), (3, 2)),
)),
bad_build=None,
bad_runtime=None
)
#TODO: consider moving this function / functionality to gradient.py
# rationale: it's tricky, and necessary everytime you want to verify
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论