提交 85b6b618 authored 作者: Frederic's avatar Frederic

pep8

上级 4d879ee9
......@@ -40,7 +40,9 @@ def local_add_s_s(node):
This is useful for sparse weight updates.
Work also for add(x, neg(y)) in the same case.
As of this writting sub is only implemented as x + neg(y) for sparse matrix.
As of this writting sub is only implemented as x + neg(y) for
sparse matrix.
"""
if node.op == add_s_s:
x, y = node.inputs
......@@ -57,16 +59,20 @@ def local_add_s_s(node):
def same_pattern(node):
"""Check node has same sparsity as x."""
# In case the sparse matrix is multiplied by a scalar (ex: learning rate)
# In case the sparse matrix is multiplied by a scalar (ex:
# learning rate)
if hasattr(node.owner, 'op') and node.owner.op == mul_scalar:
node = node.owner.inputs[1]
# Check node creates a matrix
if not hasattr(node.owner, 'op') or not isinstance(node.owner.op, CSM):
if not hasattr(node.owner, 'op') or not isinstance(node.owner.op,
CSM):
return False
# Check matrix is creates from CSMProperties
if filter(lambda i: not hasattr(i.owner, 'op') or not isinstance(i.owner.op, CSMProperties), node.owner.inputs[1:]):
if filter(lambda i: not hasattr(i.owner, 'op') or
not isinstance(i.owner.op, CSMProperties),
node.owner.inputs[1:]):
return False
# Verify indices, indptr and shape are the same as x
......@@ -80,12 +86,16 @@ def local_add_s_s(node):
return False
register_specialize(local_add_s_s)
class AddSSData(gof.op.Op):
'''Add two sparse matrices assuming they have the same sparsity pattern. '''
'''Add two sparse matrices assuming they have the same sparsity
pattern. '''
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, y):
x, y = map(as_sparse_variable, [x, y])
if x.type.dtype != y.type.dtype:
......@@ -94,8 +104,9 @@ class AddSSData(gof.op.Op):
raise NotImplementedError()
return gof.Apply(self,
[x, y],
[SparseType(dtype = x.type.dtype,
format = x.type.format).make_variable()])
[SparseType(dtype=x.type.dtype,
format=x.type.format).make_variable()])
def perform(self, node, (x, y), (out, )):
assert _is_sparse(x) and _is_sparse(y)
assert x.shape == y.shape
......@@ -103,6 +114,7 @@ class AddSSData(gof.op.Op):
out[0].data += y.data
add_s_s_data = AddSSData()
# register a specialization to replace MulSD -> MulSDCSX
@gof.local_optimizer([mul_s_d])
def local_mul_s_d(node):
......@@ -110,7 +122,7 @@ def local_mul_s_d(node):
x, y = node.inputs
x_is_sparse_variable = _is_sparse_variable(x)
y_is_sparse_variable = _is_sparse_variable(y)
# y_is_sparse_variable = _is_sparse_variable(y)
if x_is_sparse_variable:
svar = x
......@@ -130,9 +142,11 @@ def local_mul_s_d(node):
else:
raise NotImplemented()
c_data = mul_s_d_csx(csm_data(svar), csm_indices(svar), csm_indptr(svar), dvar)
c_data = mul_s_d_csx(csm_data(svar), csm_indices(svar),
csm_indptr(svar), dvar)
return [CSx(c_data, csm_indices(svar), csm_indptr(svar), csm_shape(svar))]
return [CSx(c_data, csm_indices(svar), csm_indptr(svar),
csm_shape(svar))]
return False
register_specialize(local_mul_s_d)
......@@ -141,15 +155,19 @@ register_specialize(local_mul_s_d)
class MulSDCSC(gof.Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, a_data, a_indices, a_indptr, b):
assert b.type.ndim == 2
return gof.Apply(self, [a_data, a_indices, a_indptr, b],
[tensor.tensor(b.dtype, (False,))])
#def perform(self, node, (a_data, a_indices, a_indptr, b), (out,)):
# return NotImplementedError()
def c_code(self, node, name, (_data, _indices, _indptr, _b,), (_zout, ), sub):
def c_code(self, node, name, (_data, _indices, _indptr, _b,),
(_zout, ), sub):
if node.inputs[0].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for a')
......@@ -209,22 +227,26 @@ class MulSDCSC(gof.Op):
}
}
"""% dict(locals(), **sub)
""" % dict(locals(), **sub)
mul_s_d_csc = MulSDCSC()
class MulSDCSR(gof.Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, a_data, a_indices, a_indptr, b):
assert b.type.ndim == 2
return gof.Apply(self, [a_data, a_indices, a_indptr, b],
[tensor.tensor(b.dtype, (False,))])
#def perform(self, node, (a_data, a_indices, a_indptr, b), (out,)):
# return NotImplemented()
def c_code(self, node, name, (_data, _indices, _indptr, _b,), (_zout, ), sub):
def c_code(self, node, name, (_data, _indices, _indptr, _b,),
(_zout, ), sub):
if node.inputs[0].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for a')
......@@ -284,9 +306,10 @@ class MulSDCSR(gof.Op):
}
}
"""% dict(locals(), **sub)
""" % dict(locals(), **sub)
mul_s_d_csr = MulSDCSR()
class Poisson(gof.op.Op):
def __eq__(self, other):
return (type(self) == type(other))
......@@ -306,6 +329,7 @@ class Poisson(gof.op.Op):
out[0].eliminate_zeros()
poisson = Poisson()
class Multinomial(gof.op.Op):
def __eq__(self, other):
return (type(self) == type(other))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论