提交 85b6b618 authored 作者: Frederic's avatar Frederic

pep8

上级 4d879ee9
...@@ -33,19 +33,21 @@ def local_add_s_s(node): ...@@ -33,19 +33,21 @@ def local_add_s_s(node):
""" """
If two matrices are known to have the same sparsity pattern, If two matrices are known to have the same sparsity pattern,
optimize the addition by only adding their data vector. optimize the addition by only adding their data vector.
Very special case optimization. Activate when for add(x, y), Very special case optimization. Activate when for add(x, y),
y is an expression like sp_ones_like(x) * another_matrix. y is an expression like sp_ones_like(x) * another_matrix.
This is useful for sparse weight updates. This is useful for sparse weight updates.
Work also for add(x, neg(y)) in the same case. Work also for add(x, neg(y)) in the same case.
As of this writting sub is only implemented as x + neg(y) for sparse matrix.
As of this writting sub is only implemented as x + neg(y) for
sparse matrix.
""" """
if node.op == add_s_s: if node.op == add_s_s:
x, y = node.inputs x, y = node.inputs
# In case addition was transformed to subtraction # In case addition was transformed to subtraction
if hasattr(y.owner, 'op') and y.owner.op == neg: if hasattr(y.owner, 'op') and y.owner.op == neg:
y_ = y.owner.inputs[0] y_ = y.owner.inputs[0]
else: else:
...@@ -54,38 +56,46 @@ def local_add_s_s(node): ...@@ -54,38 +56,46 @@ def local_add_s_s(node):
return False return False
if hasattr(y_.owner, 'op') and y_.owner.op not in [mul_s_s, mul_s_d]: if hasattr(y_.owner, 'op') and y_.owner.op not in [mul_s_s, mul_s_d]:
return False return False
def same_pattern(node): def same_pattern(node):
"""Check node has same sparsity as x.""" """Check node has same sparsity as x."""
# In case the sparse matrix is multiplied by a scalar (ex: learning rate) # In case the sparse matrix is multiplied by a scalar (ex:
# learning rate)
if hasattr(node.owner, 'op') and node.owner.op == mul_scalar: if hasattr(node.owner, 'op') and node.owner.op == mul_scalar:
node = node.owner.inputs[1] node = node.owner.inputs[1]
# Check node creates a matrix # Check node creates a matrix
if not hasattr(node.owner, 'op') or not isinstance(node.owner.op, CSM): if not hasattr(node.owner, 'op') or not isinstance(node.owner.op,
return False CSM):
return False
# Check matrix is creates from CSMProperties # Check matrix is creates from CSMProperties
if filter(lambda i: not hasattr(i.owner, 'op') or not isinstance(i.owner.op, CSMProperties), node.owner.inputs[1:]): if filter(lambda i: not hasattr(i.owner, 'op') or
return False not isinstance(i.owner.op, CSMProperties),
node.owner.inputs[1:]):
return False
# Verify indices, indptr and shape are the same as x # Verify indices, indptr and shape are the same as x
if filter(lambda i: i.owner.inputs[0] != x, node.owner.inputs[1:]): if filter(lambda i: i.owner.inputs[0] != x, node.owner.inputs[1:]):
return False return False
return True return True
if filter(same_pattern, y_.owner.inputs): if filter(same_pattern, y_.owner.inputs):
return [add_s_s_data(x, y)] return [add_s_s_data(x, y)]
return False return False
register_specialize(local_add_s_s) register_specialize(local_add_s_s)
class AddSSData(gof.op.Op): class AddSSData(gof.op.Op):
'''Add two sparse matrices assuming they have the same sparsity pattern. ''' '''Add two sparse matrices assuming they have the same sparsity
pattern. '''
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) return (type(self) == type(other))
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def make_node(self, x, y): def make_node(self, x, y):
x, y = map(as_sparse_variable, [x, y]) x, y = map(as_sparse_variable, [x, y])
if x.type.dtype != y.type.dtype: if x.type.dtype != y.type.dtype:
...@@ -94,46 +104,50 @@ class AddSSData(gof.op.Op): ...@@ -94,46 +104,50 @@ class AddSSData(gof.op.Op):
raise NotImplementedError() raise NotImplementedError()
return gof.Apply(self, return gof.Apply(self,
[x, y], [x, y],
[SparseType(dtype = x.type.dtype, [SparseType(dtype=x.type.dtype,
format = x.type.format).make_variable()]) format=x.type.format).make_variable()])
def perform(self, node, (x, y), (out, )):
def perform(self, node, (x, y), (out, )):
assert _is_sparse(x) and _is_sparse(y) assert _is_sparse(x) and _is_sparse(y)
assert x.shape == y.shape assert x.shape == y.shape
out[0] = x.copy() out[0] = x.copy()
out[0].data += y.data out[0].data += y.data
add_s_s_data = AddSSData() add_s_s_data = AddSSData()
# register a specialization to replace MulSD -> MulSDCSX # register a specialization to replace MulSD -> MulSDCSX
@gof.local_optimizer([mul_s_d]) @gof.local_optimizer([mul_s_d])
def local_mul_s_d(node): def local_mul_s_d(node):
if node.op == mul_s_d: if node.op == mul_s_d:
x, y = node.inputs x, y = node.inputs
x_is_sparse_variable = _is_sparse_variable(x) x_is_sparse_variable = _is_sparse_variable(x)
y_is_sparse_variable = _is_sparse_variable(y) # y_is_sparse_variable = _is_sparse_variable(y)
if x_is_sparse_variable: if x_is_sparse_variable:
svar = x svar = x
dvar = y dvar = y
else: else:
svar = y svar = y
dvar = x dvar = x
if dvar.type.ndim != 2: if dvar.type.ndim != 2:
return False return False
if svar.type.format == 'csc': if svar.type.format == 'csc':
CSx = CSC CSx = CSC
mul_s_d_csx = mul_s_d_csc mul_s_d_csx = mul_s_d_csc
elif svar.type.format == 'csr': elif svar.type.format == 'csr':
CSx = CSR CSx = CSR
mul_s_d_csx = mul_s_d_csr mul_s_d_csx = mul_s_d_csr
else: else:
raise NotImplemented() raise NotImplemented()
c_data = mul_s_d_csx(csm_data(svar), csm_indices(svar), csm_indptr(svar), dvar) c_data = mul_s_d_csx(csm_data(svar), csm_indices(svar),
csm_indptr(svar), dvar)
return [CSx(c_data, csm_indices(svar), csm_indptr(svar), csm_shape(svar))]
return [CSx(c_data, csm_indices(svar), csm_indptr(svar),
csm_shape(svar))]
return False return False
register_specialize(local_mul_s_d) register_specialize(local_mul_s_d)
...@@ -141,15 +155,19 @@ register_specialize(local_mul_s_d) ...@@ -141,15 +155,19 @@ register_specialize(local_mul_s_d)
class MulSDCSC(gof.Op): class MulSDCSC(gof.Op):
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) return (type(self) == type(other))
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def make_node(self, a_data, a_indices, a_indptr, b): def make_node(self, a_data, a_indices, a_indptr, b):
assert b.type.ndim == 2 assert b.type.ndim == 2
return gof.Apply(self, [a_data, a_indices, a_indptr, b], return gof.Apply(self, [a_data, a_indices, a_indptr, b],
[tensor.tensor(b.dtype, (False,))]) [tensor.tensor(b.dtype, (False,))])
#def perform(self, node, (a_data, a_indices, a_indptr, b), (out,)): #def perform(self, node, (a_data, a_indices, a_indptr, b), (out,)):
# return NotImplementedError() # return NotImplementedError()
def c_code(self, node, name, (_data, _indices, _indptr, _b,), (_zout, ), sub): def c_code(self, node, name, (_data, _indices, _indptr, _b,),
(_zout, ), sub):
if node.inputs[0].type.dtype in ('complex64', 'complex128'): if node.inputs[0].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for a') raise NotImplementedError('Complex types are not supported for a')
...@@ -209,22 +227,26 @@ class MulSDCSC(gof.Op): ...@@ -209,22 +227,26 @@ class MulSDCSC(gof.Op):
} }
} }
"""% dict(locals(), **sub) """ % dict(locals(), **sub)
mul_s_d_csc = MulSDCSC() mul_s_d_csc = MulSDCSC()
class MulSDCSR(gof.Op): class MulSDCSR(gof.Op):
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) return (type(self) == type(other))
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def make_node(self, a_data, a_indices, a_indptr, b): def make_node(self, a_data, a_indices, a_indptr, b):
assert b.type.ndim == 2 assert b.type.ndim == 2
return gof.Apply(self, [a_data, a_indices, a_indptr, b], return gof.Apply(self, [a_data, a_indices, a_indptr, b],
[tensor.tensor(b.dtype, (False,))]) [tensor.tensor(b.dtype, (False,))])
#def perform(self, node, (a_data, a_indices, a_indptr, b), (out,)): #def perform(self, node, (a_data, a_indices, a_indptr, b), (out,)):
# return NotImplemented() # return NotImplemented()
def c_code(self, node, name, (_data, _indices, _indptr, _b,), (_zout, ), sub): def c_code(self, node, name, (_data, _indices, _indptr, _b,),
(_zout, ), sub):
if node.inputs[0].type.dtype in ('complex64', 'complex128'): if node.inputs[0].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for a') raise NotImplementedError('Complex types are not supported for a')
...@@ -284,9 +306,10 @@ class MulSDCSR(gof.Op): ...@@ -284,9 +306,10 @@ class MulSDCSR(gof.Op):
} }
} }
"""% dict(locals(), **sub) """ % dict(locals(), **sub)
mul_s_d_csr = MulSDCSR() mul_s_d_csr = MulSDCSR()
class Poisson(gof.op.Op): class Poisson(gof.op.Op):
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) return (type(self) == type(other))
...@@ -306,6 +329,7 @@ class Poisson(gof.op.Op): ...@@ -306,6 +329,7 @@ class Poisson(gof.op.Op):
out[0].eliminate_zeros() out[0].eliminate_zeros()
poisson = Poisson() poisson = Poisson()
class Multinomial(gof.op.Op): class Multinomial(gof.op.Op):
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) return (type(self) == type(other))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论