提交 d251eb7e authored 作者: David Warde-Farley's avatar David Warde-Farley

PEP8: fix all instances of E302 (2 blank lines)

上级 50387206
...@@ -21,6 +21,7 @@ from theano.tensor import blas ...@@ -21,6 +21,7 @@ from theano.tensor import blas
sparse_formats = ['csc', 'csr'] sparse_formats = ['csc', 'csr']
#TODO: move this decorator to the compile submodule #TODO: move this decorator to the compile submodule
def register_specialize(lopt, *tags, **kwargs): def register_specialize(lopt, *tags, **kwargs):
compile.optdb['specialize'].register((kwargs and kwargs.pop('name')) or lopt.__name__, lopt, 'fast_run', *tags) compile.optdb['specialize'].register((kwargs and kwargs.pop('name')) or lopt.__name__, lopt, 'fast_run', *tags)
...@@ -33,6 +34,7 @@ _mtypes = [scipy.sparse.csc_matrix, scipy.sparse.csr_matrix] ...@@ -33,6 +34,7 @@ _mtypes = [scipy.sparse.csc_matrix, scipy.sparse.csr_matrix]
#* new class ``bsr_matrix`` : the Block CSR format #* new class ``bsr_matrix`` : the Block CSR format
_mtype_to_str = {scipy.sparse.csc_matrix: "csc", scipy.sparse.csr_matrix: "csr"} _mtype_to_str = {scipy.sparse.csc_matrix: "csc", scipy.sparse.csr_matrix: "csr"}
def _is_sparse_variable(x): def _is_sparse_variable(x):
""" """
@rtype: boolean @rtype: boolean
...@@ -41,6 +43,8 @@ def _is_sparse_variable(x): ...@@ -41,6 +43,8 @@ def _is_sparse_variable(x):
if not isinstance(x.type, (SparseType, tensor.TensorType)): if not isinstance(x.type, (SparseType, tensor.TensorType)):
raise NotImplementedError("this function should only be called on *variables* (of type sparse.SparseType or tensor.TensorType), not,", x) raise NotImplementedError("this function should only be called on *variables* (of type sparse.SparseType or tensor.TensorType), not,", x)
return isinstance(x.type, SparseType) return isinstance(x.type, SparseType)
def _is_dense_variable(x): def _is_dense_variable(x):
""" """
@rtype: boolean @rtype: boolean
...@@ -50,6 +54,7 @@ def _is_dense_variable(x): ...@@ -50,6 +54,7 @@ def _is_dense_variable(x):
raise NotImplementedError("this function should only be called on *variables* (of type sparse.SparseType or tensor.TensorType), not,", x) raise NotImplementedError("this function should only be called on *variables* (of type sparse.SparseType or tensor.TensorType), not,", x)
return isinstance(x.type, tensor.TensorType) return isinstance(x.type, tensor.TensorType)
def _is_sparse(x): def _is_sparse(x):
""" """
@rtype: boolean @rtype: boolean
...@@ -58,6 +63,8 @@ def _is_sparse(x): ...@@ -58,6 +63,8 @@ def _is_sparse(x):
if not isinstance(x, (scipy.sparse.spmatrix, numpy.ndarray)): if not isinstance(x, (scipy.sparse.spmatrix, numpy.ndarray)):
raise NotImplementedError("this function should only be called on sparse.scipy.sparse.spmatrix or numpy.ndarray, not,", x) raise NotImplementedError("this function should only be called on sparse.scipy.sparse.spmatrix or numpy.ndarray, not,", x)
return isinstance(x, scipy.sparse.spmatrix) return isinstance(x, scipy.sparse.spmatrix)
def _is_dense(x): def _is_dense(x):
""" """
@rtype: boolean @rtype: boolean
...@@ -67,18 +74,19 @@ def _is_dense(x): ...@@ -67,18 +74,19 @@ def _is_dense(x):
raise NotImplementedError("this function should only be called on sparse.scipy.sparse.spmatrix or numpy.ndarray, not,", x) raise NotImplementedError("this function should only be called on sparse.scipy.sparse.spmatrix or numpy.ndarray, not,", x)
return isinstance(x, numpy.ndarray) return isinstance(x, numpy.ndarray)
def _kmap_eq(a, b): def _kmap_eq(a, b):
if a is None and b is None: if a is None and b is None:
return True return True
return numpy.all(a == b) return numpy.all(a == b)
def _kmap_hash(a): def _kmap_hash(a):
if a is None: return 12345 if a is None: return 12345
return hash(numpy.str(a)) return hash(numpy.str(a))
# Wrapper type # Wrapper type
def as_sparse_variable(x, name=None): def as_sparse_variable(x, name=None):
""" """
Wrapper around SparseVariable constructor. Wrapper around SparseVariable constructor.
...@@ -101,9 +109,9 @@ def as_sparse_variable(x, name=None): ...@@ -101,9 +109,9 @@ def as_sparse_variable(x, name=None):
return constant(x, name=name) return constant(x, name=name)
except TypeError: except TypeError:
raise TypeError("Cannot convert %s to SparseType" % x, type(x)) raise TypeError("Cannot convert %s to SparseType" % x, type(x))
as_sparse = as_sparse_variable
as_sparse = as_sparse_variable
def as_sparse_or_tensor_variable(x, name=None): def as_sparse_or_tensor_variable(x, name=None):
""" """
If we can't make a sparse variable, we try to make a tensor variable. If we can't make a sparse variable, we try to make a tensor variable.
...@@ -133,6 +141,7 @@ if 0: ...@@ -133,6 +141,7 @@ if 0:
except TypeError: except TypeError:
raise TypeError("Could not convert %s to SparseType" % x, type(x)) raise TypeError("Could not convert %s to SparseType" % x, type(x))
def sp_ones_like(x): def sp_ones_like(x):
data, indices, indptr, shape = csm_properties(x) #TODO: don't restrict to CSM formats data, indices, indptr, shape = csm_properties(x) #TODO: don't restrict to CSM formats
return CSM(format=x.format)(tensor.ones_like(data), indices, indptr, shape) return CSM(format=x.format)(tensor.ones_like(data), indices, indptr, shape)
...@@ -213,6 +222,7 @@ class SparseVariable(gof.Variable, _sparse_py_operators): ...@@ -213,6 +222,7 @@ class SparseVariable(gof.Variable, _sparse_py_operators):
def __repr__(self): def __repr__(self):
return str(self) return str(self)
class SparseConstantSignature(tuple): class SparseConstantSignature(tuple):
def __eq__(self, other): def __eq__(self, other):
(a, b), (x,y) = self, other (a, b), (x,y) = self, other
...@@ -225,6 +235,7 @@ class SparseConstantSignature(tuple): ...@@ -225,6 +235,7 @@ class SparseConstantSignature(tuple):
(a,b) = self (a,b) = self
return hash(type(self)) ^ hash(a) ^ hash(type(b)) return hash(type(self)) ^ hash(a) ^ hash(type(b))
class SparseConstant(gof.Constant, _sparse_py_operators): class SparseConstant(gof.Constant, _sparse_py_operators):
dtype = property(lambda self: self.type.dtype) dtype = property(lambda self: self.type.dtype)
format = property(lambda self: self.type.format) format = property(lambda self: self.type.format)
...@@ -242,10 +253,12 @@ class SparseConstant(gof.Constant, _sparse_py_operators): ...@@ -242,10 +253,12 @@ class SparseConstant(gof.Constant, _sparse_py_operators):
def __repr__(self): def __repr__(self):
return str(self) return str(self)
class SparseValue(gof.Value, _sparse_py_operators): class SparseValue(gof.Value, _sparse_py_operators):
dtype = property(lambda self: self.type.dtype) dtype = property(lambda self: self.type.dtype)
format = property(lambda self: self.type.format) format = property(lambda self: self.type.format)
class SparseType(gof.Type): class SparseType(gof.Type):
""" """
@type dtype: numpy dtype string such as 'int64' or 'float64' (among others) @type dtype: numpy dtype string such as 'int64' or 'float64' (among others)
...@@ -366,8 +379,12 @@ def matrix(format, name=None, dtype=None): ...@@ -366,8 +379,12 @@ def matrix(format, name=None, dtype=None):
dtype = config.floatX dtype = config.floatX
type = SparseType(format=format, dtype=dtype) type = SparseType(format=format, dtype=dtype)
return type(name) return type(name)
def csc_matrix(name=None, dtype=None): def csc_matrix(name=None, dtype=None):
return matrix('csc', name, dtype) return matrix('csc', name, dtype)
def csr_matrix(name=None, dtype=None): def csr_matrix(name=None, dtype=None):
return matrix('csr', name, dtype) return matrix('csr', name, dtype)
# for more dtypes, call SparseType(format, dtype) # for more dtypes, call SparseType(format, dtype)
...@@ -378,6 +395,7 @@ csr_dmatrix = SparseType(format='csr', dtype='float64') ...@@ -378,6 +395,7 @@ csr_dmatrix = SparseType(format='csr', dtype='float64')
csc_fmatrix = SparseType(format='csc', dtype='float32') csc_fmatrix = SparseType(format='csc', dtype='float32')
csr_fmatrix = SparseType(format='csr', dtype='float32') csr_fmatrix = SparseType(format='csr', dtype='float32')
# CONSTRUCTION # CONSTRUCTION
class CSMProperties(gof.Op): class CSMProperties(gof.Op):
"""Extract all of .data .indices and .indptr""" """Extract all of .data .indices and .indptr"""
...@@ -427,11 +445,20 @@ class CSMProperties(gof.Op): ...@@ -427,11 +445,20 @@ class CSMProperties(gof.Op):
else: else:
return [CSR('csm')(g_data, indices, indptr, shape)] return [CSR('csm')(g_data, indices, indptr, shape)]
csm_properties = CSMProperties() #don't make this a function or it breaks some optimizations below csm_properties = CSMProperties() #don't make this a function or it breaks some optimizations below
def csm_data(csm): return csm_properties(csm)[0] def csm_data(csm): return csm_properties(csm)[0]
def csm_indices(csm): return csm_properties(csm)[1] def csm_indices(csm): return csm_properties(csm)[1]
def csm_indptr(csm): return csm_properties(csm)[2] def csm_indptr(csm): return csm_properties(csm)[2]
def csm_shape(csm): return csm_properties(csm)[3] def csm_shape(csm): return csm_properties(csm)[3]
class CSM(gof.Op): class CSM(gof.Op):
"""Construct a CSC or CSR matrix from the internal representation """ """Construct a CSC or CSR matrix from the internal representation """
view_map = {0:[0]} #should view the other inputs too, but viewing multiple inputs is not view_map = {0:[0]} #should view the other inputs too, but viewing multiple inputs is not
...@@ -536,6 +563,7 @@ class CSM(gof.Op): ...@@ -536,6 +563,7 @@ class CSM(gof.Op):
CSC = CSM('csc') CSC = CSM('csc')
CSR = CSM('csr') CSR = CSM('csr')
class CSMGrad(gof.op.Op): class CSMGrad(gof.op.Op):
def __init__(self, kmap=None): def __init__(self, kmap=None):
self.kmap = kmap self.kmap = kmap
...@@ -563,6 +591,7 @@ class CSMGrad(gof.op.Op): ...@@ -563,6 +591,7 @@ class CSMGrad(gof.op.Op):
g_data[0] = grad g_data[0] = grad
csm_grad = CSMGrad csm_grad = CSMGrad
@gof.local_optimizer([csm_properties]) @gof.local_optimizer([csm_properties])
def skip_pack_csc01(node): def skip_pack_csc01(node):
"""if we find csm_properties(CSM(*args)), then we can replace that with the *args """if we find csm_properties(CSM(*args)), then we can replace that with the *args
...@@ -580,7 +609,6 @@ def skip_pack_csc01(node): ...@@ -580,7 +609,6 @@ def skip_pack_csc01(node):
register_specialize(skip_pack_csc01) register_specialize(skip_pack_csc01)
# #
# Conversion # Conversion
# #
...@@ -617,6 +645,7 @@ class DenseFromSparse(gof.op.Op): ...@@ -617,6 +645,7 @@ class DenseFromSparse(gof.op.Op):
return [ishape] return [ishape]
dense_from_sparse = DenseFromSparse() dense_from_sparse = DenseFromSparse()
class SparseFromDense(gof.op.Op): class SparseFromDense(gof.op.Op):
def __init__(self, format): def __init__(self, format):
self.format = format self.format = format
...@@ -817,6 +846,7 @@ class Transpose(gof.op.Op): ...@@ -817,6 +846,7 @@ class Transpose(gof.op.Op):
return transpose(gz), return transpose(gz),
transpose = Transpose() transpose = Transpose()
class Neg(gof.op.Op): class Neg(gof.op.Op):
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) return (type(self) == type(other))
...@@ -833,6 +863,7 @@ class Neg(gof.op.Op): ...@@ -833,6 +863,7 @@ class Neg(gof.op.Op):
return -gz, return -gz,
neg = Neg() neg = Neg()
class AddSS(gof.op.Op): class AddSS(gof.op.Op):
'''Add two sparse matrices ''' '''Add two sparse matrices '''
def __eq__(self, other): def __eq__(self, other):
...@@ -858,6 +889,8 @@ class AddSS(gof.op.Op): ...@@ -858,6 +889,8 @@ class AddSS(gof.op.Op):
assert _is_sparse_variable(gz) assert _is_sparse_variable(gz)
return gz, gz return gz, gz
add_s_s = AddSS() add_s_s = AddSS()
class AddSD(gof.op.Op): class AddSD(gof.op.Op):
''' Add a sparse and a dense matrix ''' ''' Add a sparse and a dense matrix '''
def __eq__(self, other): def __eq__(self, other):
...@@ -885,6 +918,8 @@ class AddSD(gof.op.Op): ...@@ -885,6 +918,8 @@ class AddSD(gof.op.Op):
assert _is_dense_variable(gz) assert _is_dense_variable(gz)
return sp_ones_like(x) * gz, gz return sp_ones_like(x) * gz, gz
add_s_d = AddSD() add_s_d = AddSD()
def add(x,y): def add(x,y):
""" """
Add two matrices, at least one of which is sparse. Add two matrices, at least one of which is sparse.
...@@ -900,11 +935,12 @@ def add(x,y): ...@@ -900,11 +935,12 @@ def add(x,y):
elif x_is_sparse_variable and not y_is_sparse_variable: return add_s_d(x,y) elif x_is_sparse_variable and not y_is_sparse_variable: return add_s_d(x,y)
elif y_is_sparse_variable and not x_is_sparse_variable: return add_s_d(y,x) elif y_is_sparse_variable and not x_is_sparse_variable: return add_s_d(y,x)
else: raise NotImplementedError() else: raise NotImplementedError()
def sub(x,y): def sub(x,y):
return x + (-y) return x + (-y)
class MulSS(gof.op.Op): class MulSS(gof.op.Op):
''' Elementwise multiply a sparse and a sparse ''' ''' Elementwise multiply a sparse and a sparse '''
def __eq__(self, other): def __eq__(self, other):
...@@ -928,6 +964,8 @@ class MulSS(gof.op.Op): ...@@ -928,6 +964,8 @@ class MulSS(gof.op.Op):
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
return y * gz, x * gz return y * gz, x * gz
mul_s_s = MulSS() mul_s_s = MulSS()
class MulSD(gof.op.Op): class MulSD(gof.op.Op):
''' Elementwise multiply a sparse and a ndarray ''' ''' Elementwise multiply a sparse and a ndarray '''
def __eq__(self, other): def __eq__(self, other):
...@@ -995,6 +1033,8 @@ class MulSD(gof.op.Op): ...@@ -995,6 +1033,8 @@ class MulSD(gof.op.Op):
assert _is_sparse_variable(gz) assert _is_sparse_variable(gz)
return y * gz, x * gz return y * gz, x * gz
mul_s_d = MulSD() mul_s_d = MulSD()
def mul(x,y): def mul(x,y):
""" """
Multiply (elementwise) two matrices, at least one of which is sparse. Multiply (elementwise) two matrices, at least one of which is sparse.
...@@ -1011,6 +1051,7 @@ def mul(x,y): ...@@ -1011,6 +1051,7 @@ def mul(x,y):
elif y_is_sparse_variable and not x_is_sparse_variable: return mul_s_d(y,x) elif y_is_sparse_variable and not x_is_sparse_variable: return mul_s_d(y,x)
else: raise NotImplementedError() else: raise NotImplementedError()
############### ###############
# #
# StructuredDot # StructuredDot
...@@ -1073,9 +1114,9 @@ class StructuredDot(gof.Op): ...@@ -1073,9 +1114,9 @@ class StructuredDot(gof.Op):
# ga = g_out x b.T # ga = g_out x b.T
# gb = a.T x g_out # gb = a.T x g_out
return [structured_dot_grad(a, b, g_out), structured_dot(a.T,g_out)] return [structured_dot_grad(a, b, g_out), structured_dot(a.T,g_out)]
_structured_dot = StructuredDot() _structured_dot = StructuredDot()
def structured_dot(x, y): def structured_dot(x, y):
""" """
@todo: Maybe the triple-transposition formulation (when x is dense) @todo: Maybe the triple-transposition formulation (when x is dense)
...@@ -1096,6 +1137,7 @@ def structured_dot(x, y): ...@@ -1096,6 +1137,7 @@ def structured_dot(x, y):
assert y_is_sparse_variable assert y_is_sparse_variable
return _structured_dot(y.T, x.T).T return _structured_dot(y.T, x.T).T
class StructuredDotCSC(gof.Op): class StructuredDotCSC(gof.Op):
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) return (type(self) == type(other))
...@@ -1262,6 +1304,7 @@ class StructuredDotCSC(gof.Op): ...@@ -1262,6 +1304,7 @@ class StructuredDotCSC(gof.Op):
return (2,) return (2,)
sd_csc = StructuredDotCSC() sd_csc = StructuredDotCSC()
class StructuredDotCSR(gof.Op): class StructuredDotCSR(gof.Op):
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) return (type(self) == type(other))
...@@ -1394,6 +1437,7 @@ class StructuredDotCSR(gof.Op): ...@@ -1394,6 +1437,7 @@ class StructuredDotCSR(gof.Op):
return (1,) return (1,)
sd_csr = StructuredDotCSR() sd_csr = StructuredDotCSR()
# register a specialization to replace StructuredDot -> StructuredDotCSx # register a specialization to replace StructuredDot -> StructuredDotCSx
@gof.local_optimizer([_structured_dot]) @gof.local_optimizer([_structured_dot])
def local_structured_dot(node): def local_structured_dot(node):
...@@ -1414,6 +1458,7 @@ def local_structured_dot(node): ...@@ -1414,6 +1458,7 @@ def local_structured_dot(node):
# involved. dimension mismatches are hard to detect sensibly. # involved. dimension mismatches are hard to detect sensibly.
#register_specialize(local_structured_dot) #register_specialize(local_structured_dot)
def structured_dot_grad(sparse_A, dense_B, ga): def structured_dot_grad(sparse_A, dense_B, ga):
if sparse_A.type.format in ('csc','csr'): if sparse_A.type.format in ('csc','csr'):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论