提交 d1c3a6bb authored 作者: David Warde-Farley's avatar David Warde-Farley

Cleaned up crazy nested loops, style fixes.

itertools.product(), or theano.gof.python25.product() in this case. Also PEP8 fixes.
上级 ae17e382
...@@ -91,5 +91,17 @@ if sys.version_info[:2] < (2,6): ...@@ -91,5 +91,17 @@ if sys.version_info[:2] < (2,6):
for j in range(i+1, r): for j in range(i+1, r):
indices[j] = indices[j-1] + 1 indices[j] = indices[j-1] + 1
yield tuple(pool[i] for i in indices) yield tuple(pool[i] for i in indices)
def product(*args, **kwds):
# product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
# product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111
pools = map(tuple, args) * kwds.get('repeat', 1)
result = [[]]
for pool in pools:
result = [x+[y] for x in result for y in pool]
for prod in result:
yield tuple(prod)
else: else:
from itertools import combinations from itertools import combinations, product
...@@ -12,6 +12,8 @@ except ImportError: ...@@ -12,6 +12,8 @@ except ImportError:
import theano import theano
from theano import compile, config from theano import compile, config
from theano.sparse import enable_sparse from theano.sparse import enable_sparse
from theano.gof.python25 import product
if enable_sparse == False: if enable_sparse == False:
raise SkipTest('Optional package sparse disabled') raise SkipTest('Optional package sparse disabled')
...@@ -592,69 +594,80 @@ class UsmmTests(unittest.TestCase): ...@@ -592,69 +594,80 @@ class UsmmTests(unittest.TestCase):
else: else:
return theano.sparse.matrix(format, name, dtype=dtype) return theano.sparse.matrix(format, name, dtype=dtype)
for dtype1 in ['float32', 'float64']: params = product(*([['float32', 'float64']] * 4 +
for dtype2 in ['float32', 'float64']: [['dense', 'csc', 'csr']] * 2))
for dtype3 in ['float32', 'float64']:
for dtype4 in ['float32', 'float64']: for dtype1, dtype2, dtype3, dtype4, format1, format2 in params:
for format1 in ['dense', 'csc','csr']: if format1 == 'dense' and format2 == 'dense':
for format2 in ['dense', 'csc','csr']: # Usmm won't be used!
if format1 == 'dense' and format2 == 'dense': continue
# Usmm won't be used! x = mat(format1, 'x', dtype1)
continue y = mat(format2, 'y', dtype2)
x = mat(format1, 'x', dtype1) a = theano.tensor.scalar('a', dtype=dtype3)
y = mat(format2, 'y', dtype2) z = theano.tensor.shared(
a = theano.tensor.scalar('a', dtype=dtype3) numpy.asarray(self.z, dtype=dtype4).copy()
z = theano.tensor.shared(numpy.asarray(self.z,dtype=dtype4).copy()) )
f_b = lambda z, a, x, y: z - a * (x * y) f_b = lambda z, a, x, y: z - a * (x * y)
x_data = numpy.asarray(self.x, dtype = dtype1) x_data = numpy.asarray(self.x, dtype=dtype1)
if format1 != 'dense': if format1 != 'dense':
x_data = as_sparse_format(x_data, format1) x_data = as_sparse_format(x_data, format1)
y_data = numpy.asarray(self.y, dtype = dtype2) y_data = numpy.asarray(self.y, dtype=dtype2)
if format2 != 'dense': if format2 != 'dense':
y_data = as_sparse_format(y_data, format2) y_data = as_sparse_format(y_data, format2)
z_data = numpy.asarray(self.z, dtype = dtype3) z_data = numpy.asarray(self.z, dtype=dtype3)
f_b_out = f_b(z_data, 1, x_data, y_data) f_b_out = f_b(z_data, 1, x_data, y_data)
# Can it work inplace? # Can it work inplace?
inplace = dtype4 == theano.scalar.upcast(dtype1, dtype2, dtype3) inplace = dtype4 == theano.scalar.upcast(dtype1, dtype2, dtype3)
# To make it easier to check the toposort # To make it easier to check the toposort
mode = theano.compile.mode.get_default_mode().excluding('fusion') mode = theano.compile.mode.get_default_mode().excluding('fusion')
if inplace: if inplace:
f_a = theano.function([a, x, y], [], updates = {z: z - a * theano.sparse.dot(x, y)}
updates={ z : z - a * theano.sparse.dot(x, y)}, f_a = theano.function([a, x, y], [],
mode = mode) updates=updates,
f_a(1, x_data, y_data) mode=mode)
assert abs(z.get_value(borrow=True) - f_b_out).max() < 1e-4 f_a(1, x_data, y_data)
else: assert abs(z.get_value(borrow=True) - f_b_out).max() < 1e-4
f_a = theano.function([a, x, y], z - a * theano.sparse.dot(x, y), else:
mode = mode) f_a = theano.function([a, x, y],
f_a_out = f_a(1, x_data, y_data) z - a * theano.sparse.dot(x, y),
assert abs(f_a_out - f_b_out).max() < 1e-4 mode=mode)
topo = f_a.maker.env.toposort() f_a_out = f_a(1, x_data, y_data)
if (y.type.dtype == theano.scalar.upcast(dtype1, dtype2, dtype3, dtype4) assert abs(f_a_out - f_b_out).max() < 1e-4
and format1=='csc' and format2=='dense'): topo = f_a.maker.env.toposort()
up = theano.scalar.upcast(dtype1, dtype2, dtype3, dtype4)
assert sum([isinstance(node.op, tensor.Elemwise) and isinstance(node.op.scalar_op, theano.scalar.basic.Cast) for node in topo])==len(topo)-5 if y.type.dtype == up and format1 == 'csc' and format2 == 'dense':
topo = [node for node in topo if not(isinstance(node.op, tensor.Elemwise) and isinstance(node.op.scalar_op, theano.scalar.basic.Cast))] assert (sum([isinstance(node.op, tensor.Elemwise) and
assert len(topo)==5, topo isinstance(node.op.scalar_op,
# Usmm is tested at the same time in debugmode theano.scalar.basic.Cast)
# Check if the optimization local_usmm and local_usmm_csx is applied for node in topo]) == len(topo) - 5)
assert isinstance(topo[0].op, theano.sparse.basic.CSMProperties) new_topo = []
assert isinstance(topo[1].op, theano.tensor.DimShuffle) for node in topo:
assert isinstance(topo[2].op, theano.tensor.Subtensor) if not isinstance(node.op, tensor.Elemwise) and \
assert topo[3].op == theano.tensor.neg isinstance(node.op.scalar_op, theano.scalar.basic.Cast):
assert isinstance(topo[4].op, theano.sparse.UsmmCscDense) new_topo.append(node)
if inplace: topo = new_topo
assert topo[4].op.inplace assert len(topo) == 5, topo
else: # Usmm is tested at the same time in debugmode
assert len(topo)==3, topo # Check if the optimization local_usmm and local_usmm_csx is
assert isinstance(topo[0].op, theano.tensor.DimShuffle) # applied
assert topo[1].op == theano.tensor.neg assert isinstance(topo[0].op,
assert isinstance(topo[2].op, theano.sparse.Usmm) theano.sparse.basic.CSMProperties)
assert isinstance(topo[1].op, theano.tensor.DimShuffle)
assert isinstance(topo[2].op, theano.tensor.Subtensor)
assert topo[3].op == theano.tensor.neg
assert isinstance(topo[4].op, theano.sparse.UsmmCscDense)
if inplace:
assert topo[4].op.inplace
else:
assert len(topo)==3, topo
assert isinstance(topo[0].op, theano.tensor.DimShuffle)
assert topo[1].op == theano.tensor.neg
assert isinstance(topo[2].op, theano.sparse.Usmm)
def test_shape_i(): def test_shape_i():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论