提交 b9f3daec authored 作者: Frederic Bastien's avatar Frederic Bastien

moved 2 functions around.

上级 a7453163
......@@ -69,3 +69,27 @@ else:
partial = functools.partial
defaultdict = collections.defaultdict
__all__ = ['all', 'any']
if sys.version_info[:2] < (2,6):
# Borrowed from Python docs
def combinations(iterable, r):
# combinations('ABCD', 2) --> AB AC AD BC BD CD
# combinations(range(4), 3) --> 012 013 023 123
pool = tuple(iterable)
n = len(pool)
if r > n:
return
indices = range(r)
yield tuple(pool[i] for i in indices)
while True:
for i in reversed(range(r)):
if indices[i] != i + n - r:
break
else:
return
indices[i] += 1
for j in range(i+1, r):
indices[j] = indices[j-1] + 1
yield tuple(pool[i] for i in indices)
else:
from itertools import combinations
......@@ -15,7 +15,7 @@ from theano.tensor import inplace
from copy import copy
from theano import compile, config
from theano import gof
from theano.gof.python25 import any, all
from theano.gof.python25 import any, all, combinations
from theano.compile.mode import get_default_mode
from theano import function
......@@ -218,6 +218,16 @@ def randint_ranged(min, max, shape):
def randc128_ranged(min, max, shape):
return numpy.asarray(numpy.random.rand(*shape) * (max - min) + min, dtype='complex128')
def rand_of_dtype(shape, dtype):
if 'int' in dtype:
return randint(*shape).astype(dtype)
elif 'float' in dtype:
return rand(*shape).astype(dtype)
elif 'complex' in dtype:
return randcomplex(*shape).astype(dtype)
else:
raise TypeError()
def makeBroadcastTester(op, expected, checks = {}, **kwargs):
name = str(op) + "Tester"
if kwargs.has_key('inplace'):
......@@ -792,43 +802,11 @@ DotTester = makeTester(name = 'DotTester',
def _numpy_second(x, y):
return numpy.broadcast_arrays(x, y)[1]
def combinations(iterable, r):
# Borrowed from Python docs - can be removed when we drop
# support for 2.4/2.5
# combinations('ABCD', 2) --> AB AC AD BC BD CD
# combinations(range(4), 3) --> 012 013 023 123
pool = tuple(iterable)
n = len(pool)
if r > n:
return
indices = range(r)
yield tuple(pool[i] for i in indices)
while True:
for i in reversed(range(r)):
if indices[i] != i + n - r:
break
else:
return
indices[i] += 1
for j in range(i+1, r):
indices[j] = indices[j-1] + 1
yield tuple(pool[i] for i in indices)
ALL_DTYPES = ('int8', 'int16', 'int32', 'int64',
'float32', 'float64', 'complex64', 'complex128')
REAL_DTYPES = ALL_DTYPES[:-2]
COMPLEX_DTYPES = ALL_DTYPES[-2:]
def rand_of_dtype(shape, dtype):
if 'int' in dtype:
return randint(*shape).astype(dtype)
elif 'float' in dtype:
return rand(*shape).astype(dtype)
elif 'complex' in dtype:
return randcomplex(*shape).astype(dtype)
else:
raise TypeError()
def multi_dtype_checks(shape1, shape2, dtypes=ALL_DTYPES, nameprefix=''):
for dtype1, dtype2 in combinations(dtypes, 2):
name1 = '%s_%s_%s' % (nameprefix, dtype1, dtype2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论