提交 854b8743 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

merged

import itertools
import operator
import StringIO
import sys
......@@ -794,7 +795,113 @@ DotTester = makeTester(name = 'DotTester',
bad_runtime = dict(bad1 = (rand(5, 7), rand(5, 7)),
bad2 = (rand(5, 7), rand(8, 3))))
ClipTester = makeTester(name='ClipTester',
op=clip,
expected=lambda x, y, z: numpy.clip(x, y, z),
good = dict(correct1=((5 * rand(5, 5)).astype('float32'),
-1, 1),
correct2=((5 * rand(5, 5)).astype('float64'),
-1, 1),
correct3=(randint(5, 5).astype('int8'),
-1, 1),
correct4=(randint(5, 5).astype('int16'),
-1, 1),
correct4=(randint(5, 5).astype('int32'),
-1, 1),
correct5=(randint(5, 5).astype('int64'),
-1, 1)),
# These don't build -- is this equivalent to marking
# them as 'known fail'?
bad_build=dict(
bad1=(randcomplex(5, 5).astype('complex64'),
-1, 1),
bad2=(randcomplex(5, 5).astype('complex128'),
-1, 1)),
# I can't think of any way to make this fail at runtime
bad_runtime=dict())
def _numpy_second(x, y):
if x.ndim != y.ndim:
return broadcast_arrays(x, y)[1]
else:
return y
def combinations(iterable, r):
# Borrowed from Python docs - can be removed when we drop
# support for 2.4/2.5
# combinations('ABCD', 2) --> AB AC AD BC BD CD
# combinations(range(4), 3) --> 012 013 023 123
pool = tuple(iterable)
n = len(pool)
if r > n:
return
indices = range(r)
yield tuple(pool[i] for i in indices)
while True:
for i in reversed(range(r)):
if indices[i] != i + n - r:
break
else:
return
indices[i] += 1
for j in range(i+1, r):
indices[j] = indices[j-1] + 1
yield tuple(pool[i] for i in indices)
ALL_DTYPES = ('int8', 'int16', 'int32', 'int64',
'float32', 'float64', 'complex64', 'complex128')
def rand_of_dtype(shape, dtype):
if 'int' in dtype:
return randint(shape).astype(dtype)
elif 'float' in dtype:
return rand(shape).astype(dtype)
elif 'complex' in dtype:
return randcomplex(shape).astype(dtype)
else:
raise TypeError()
def multi_dtype_tests(shape1, shape2, dtypes=ALL_DTYPES, nameprefix=''):
for dtype1, dtype2 in combinations(dtypes, 2):
name1 = '%s_%s' % (nameprefix, dtype1, dtype2)
name2 = '%s_%s' % (nameprefix, dtype2, dtype1)
obj1 = rand_of_dtype(shape1, dtype1)
obj2 = rand_of_dtype(shape2, dtype2)
yield (name, (obj1, obj2))
yield (name, (obj2, obj1))
SecondBroadcastTester = makeTester(
name='SecondBroadcastTester',
op=second,
expected=_numpy_second,
good=dict(itertools.chain(
multi_dtype_tests((4, 5), (5,)),
multi_dtype_tests((2, 3, 2), (3, 2)),
multi_dtype_tests((2, 3, 2), (2,)),
)),
# I can't think of any way to make this fail at
# build time
bad_build=None,
# Just some simple smoke tests
bad_runtime=dict(
fail1=(rand(5, 4), rand(5)),
fail2=(rand(3, 2, 3), rand(6, 9)),
fail3=(randint(6, 2), rand(3, 2)),
)
)
SecondSameRankTester = makeTester(
name='SecondSameRankTester',
op=second,
expected=_numpy_second,
good=dict(itertools.chain(
multi_dtype_tests((4, 5), (4, 5)),
multi_dtype_tests((5, 4), (4, 5)),
multi_dtype_tests((1, 4), (3, 2)),
)),
bad_build=None,
bad_runtime=None
)
#TODO: consider moving this function / functionality to gradient.py
# rationale: it's tricky, and necessary everytime you want to verify
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论