提交 d251eb7e authored 作者: David Warde-Farley's avatar David Warde-Farley

PEP8: fix all instances of E302 (2 blank lines)

上级 50387206
......@@ -21,6 +21,7 @@ from theano.tensor import blas
sparse_formats = ['csc', 'csr']
#TODO: move this decorator to the compile submodule
def register_specialize(lopt, *tags, **kwargs):
compile.optdb['specialize'].register((kwargs and kwargs.pop('name')) or lopt.__name__, lopt, 'fast_run', *tags)
......@@ -33,6 +34,7 @@ _mtypes = [scipy.sparse.csc_matrix, scipy.sparse.csr_matrix]
#* new class ``bsr_matrix`` : the Block CSR format
_mtype_to_str = {scipy.sparse.csc_matrix: "csc", scipy.sparse.csr_matrix: "csr"}
def _is_sparse_variable(x):
"""
@rtype: boolean
......@@ -41,6 +43,8 @@ def _is_sparse_variable(x):
if not isinstance(x.type, (SparseType, tensor.TensorType)):
raise NotImplementedError("this function should only be called on *variables* (of type sparse.SparseType or tensor.TensorType), not,", x)
return isinstance(x.type, SparseType)
def _is_dense_variable(x):
"""
@rtype: boolean
......@@ -50,6 +54,7 @@ def _is_dense_variable(x):
raise NotImplementedError("this function should only be called on *variables* (of type sparse.SparseType or tensor.TensorType), not,", x)
return isinstance(x.type, tensor.TensorType)
def _is_sparse(x):
"""
@rtype: boolean
......@@ -58,6 +63,8 @@ def _is_sparse(x):
if not isinstance(x, (scipy.sparse.spmatrix, numpy.ndarray)):
raise NotImplementedError("this function should only be called on sparse.scipy.sparse.spmatrix or numpy.ndarray, not,", x)
return isinstance(x, scipy.sparse.spmatrix)
def _is_dense(x):
"""
@rtype: boolean
......@@ -67,18 +74,19 @@ def _is_dense(x):
raise NotImplementedError("this function should only be called on sparse.scipy.sparse.spmatrix or numpy.ndarray, not,", x)
return isinstance(x, numpy.ndarray)
def _kmap_eq(a, b):
if a is None and b is None:
return True
return numpy.all(a == b)
def _kmap_hash(a):
if a is None: return 12345
return hash(numpy.str(a))
# Wrapper type
def as_sparse_variable(x, name=None):
"""
Wrapper around SparseVariable constructor.
......@@ -101,9 +109,9 @@ def as_sparse_variable(x, name=None):
return constant(x, name=name)
except TypeError:
raise TypeError("Cannot convert %s to SparseType" % x, type(x))
as_sparse = as_sparse_variable
as_sparse = as_sparse_variable
def as_sparse_or_tensor_variable(x, name=None):
"""
If we can't make a sparse variable, we try to make a tensor variable.
......@@ -133,6 +141,7 @@ if 0:
except TypeError:
raise TypeError("Could not convert %s to SparseType" % x, type(x))
def sp_ones_like(x):
data, indices, indptr, shape = csm_properties(x) #TODO: don't restrict to CSM formats
return CSM(format=x.format)(tensor.ones_like(data), indices, indptr, shape)
......@@ -213,6 +222,7 @@ class SparseVariable(gof.Variable, _sparse_py_operators):
def __repr__(self):
return str(self)
class SparseConstantSignature(tuple):
def __eq__(self, other):
(a, b), (x,y) = self, other
......@@ -225,6 +235,7 @@ class SparseConstantSignature(tuple):
(a,b) = self
return hash(type(self)) ^ hash(a) ^ hash(type(b))
class SparseConstant(gof.Constant, _sparse_py_operators):
dtype = property(lambda self: self.type.dtype)
format = property(lambda self: self.type.format)
......@@ -242,10 +253,12 @@ class SparseConstant(gof.Constant, _sparse_py_operators):
def __repr__(self):
return str(self)
class SparseValue(gof.Value, _sparse_py_operators):
dtype = property(lambda self: self.type.dtype)
format = property(lambda self: self.type.format)
class SparseType(gof.Type):
"""
@type dtype: numpy dtype string such as 'int64' or 'float64' (among others)
......@@ -366,8 +379,12 @@ def matrix(format, name=None, dtype=None):
dtype = config.floatX
type = SparseType(format=format, dtype=dtype)
return type(name)
def csc_matrix(name=None, dtype=None):
return matrix('csc', name, dtype)
def csr_matrix(name=None, dtype=None):
return matrix('csr', name, dtype)
# for more dtypes, call SparseType(format, dtype)
......@@ -378,6 +395,7 @@ csr_dmatrix = SparseType(format='csr', dtype='float64')
csc_fmatrix = SparseType(format='csc', dtype='float32')
csr_fmatrix = SparseType(format='csr', dtype='float32')
# CONSTRUCTION
class CSMProperties(gof.Op):
"""Extract all of .data .indices and .indptr"""
......@@ -427,11 +445,20 @@ class CSMProperties(gof.Op):
else:
return [CSR('csm')(g_data, indices, indptr, shape)]
csm_properties = CSMProperties() #don't make this a function or it breaks some optimizations below
def csm_data(csm): return csm_properties(csm)[0]
def csm_indices(csm): return csm_properties(csm)[1]
def csm_indptr(csm): return csm_properties(csm)[2]
def csm_shape(csm): return csm_properties(csm)[3]
class CSM(gof.Op):
"""Construct a CSC or CSR matrix from the internal representation """
view_map = {0:[0]} #should view the other inputs too, but viewing multiple inputs is not
......@@ -536,6 +563,7 @@ class CSM(gof.Op):
CSC = CSM('csc')
CSR = CSM('csr')
class CSMGrad(gof.op.Op):
def __init__(self, kmap=None):
self.kmap = kmap
......@@ -563,6 +591,7 @@ class CSMGrad(gof.op.Op):
g_data[0] = grad
csm_grad = CSMGrad
@gof.local_optimizer([csm_properties])
def skip_pack_csc01(node):
"""if we find csm_properties(CSM(*args)), then we can replace that with the *args
......@@ -580,7 +609,6 @@ def skip_pack_csc01(node):
register_specialize(skip_pack_csc01)
#
# Conversion
#
......@@ -617,6 +645,7 @@ class DenseFromSparse(gof.op.Op):
return [ishape]
dense_from_sparse = DenseFromSparse()
class SparseFromDense(gof.op.Op):
def __init__(self, format):
self.format = format
......@@ -817,6 +846,7 @@ class Transpose(gof.op.Op):
return transpose(gz),
transpose = Transpose()
class Neg(gof.op.Op):
def __eq__(self, other):
return (type(self) == type(other))
......@@ -833,6 +863,7 @@ class Neg(gof.op.Op):
return -gz,
neg = Neg()
class AddSS(gof.op.Op):
'''Add two sparse matrices '''
def __eq__(self, other):
......@@ -858,6 +889,8 @@ class AddSS(gof.op.Op):
assert _is_sparse_variable(gz)
return gz, gz
add_s_s = AddSS()
class AddSD(gof.op.Op):
''' Add a sparse and a dense matrix '''
def __eq__(self, other):
......@@ -885,6 +918,8 @@ class AddSD(gof.op.Op):
assert _is_dense_variable(gz)
return sp_ones_like(x) * gz, gz
add_s_d = AddSD()
def add(x,y):
"""
Add two matrices, at least one of which is sparse.
......@@ -900,11 +935,12 @@ def add(x,y):
elif x_is_sparse_variable and not y_is_sparse_variable: return add_s_d(x,y)
elif y_is_sparse_variable and not x_is_sparse_variable: return add_s_d(y,x)
else: raise NotImplementedError()
def sub(x,y):
return x + (-y)
class MulSS(gof.op.Op):
''' Elementwise multiply a sparse and a sparse '''
def __eq__(self, other):
......@@ -928,6 +964,8 @@ class MulSS(gof.op.Op):
def grad(self, (x, y), (gz,)):
return y * gz, x * gz
mul_s_s = MulSS()
class MulSD(gof.op.Op):
''' Elementwise multiply a sparse and a ndarray '''
def __eq__(self, other):
......@@ -995,6 +1033,8 @@ class MulSD(gof.op.Op):
assert _is_sparse_variable(gz)
return y * gz, x * gz
mul_s_d = MulSD()
def mul(x,y):
"""
Multiply (elementwise) two matrices, at least one of which is sparse.
......@@ -1011,6 +1051,7 @@ def mul(x,y):
elif y_is_sparse_variable and not x_is_sparse_variable: return mul_s_d(y,x)
else: raise NotImplementedError()
###############
#
# StructuredDot
......@@ -1073,9 +1114,9 @@ class StructuredDot(gof.Op):
# ga = g_out x b.T
# gb = a.T x g_out
return [structured_dot_grad(a, b, g_out), structured_dot(a.T,g_out)]
_structured_dot = StructuredDot()
def structured_dot(x, y):
"""
@todo: Maybe the triple-transposition formulation (when x is dense)
......@@ -1096,6 +1137,7 @@ def structured_dot(x, y):
assert y_is_sparse_variable
return _structured_dot(y.T, x.T).T
class StructuredDotCSC(gof.Op):
def __eq__(self, other):
return (type(self) == type(other))
......@@ -1262,6 +1304,7 @@ class StructuredDotCSC(gof.Op):
return (2,)
sd_csc = StructuredDotCSC()
class StructuredDotCSR(gof.Op):
def __eq__(self, other):
return (type(self) == type(other))
......@@ -1394,6 +1437,7 @@ class StructuredDotCSR(gof.Op):
return (1,)
sd_csr = StructuredDotCSR()
# register a specialization to replace StructuredDot -> StructuredDotCSx
@gof.local_optimizer([_structured_dot])
def local_structured_dot(node):
......@@ -1414,6 +1458,7 @@ def local_structured_dot(node):
# involved. dimension mismatches are hard to detect sensibly.
#register_specialize(local_structured_dot)
def structured_dot_grad(sparse_A, dense_B, ga):
if sparse_A.type.format in ('csc','csr'):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论