提交 42080b8d authored 作者: James Bergstra's avatar James Bergstra

sparse operators neg and sub, and changed definition of grad(sparse.add)

上级 58f2b986
......@@ -128,6 +128,7 @@ def __src_version__():
def dot(l, r):
"""Return a symbolic matrix/dot product between l and r """
rval = NotImplemented
e0, e1 = None, None
if rval == NotImplemented and hasattr(l, '__dot__'):
try:
......
......@@ -6,7 +6,7 @@ To read about different sparse formats, see U{http://www-users.cs.umn.edu/~saad/
@todo: Automatic methods for determining best sparse format?
"""
import sys
import sys, operator
import numpy
from scipy import sparse
......@@ -177,8 +177,11 @@ csr_matrix = Sparse(format='csr')
class _sparse_py_operators:
T = property(lambda self: transpose(self), doc = "Return aliased transpose of self (read-only)")
def __neg__(self): return neg(self)
def __add__(left, right): return add(left, right)
def __radd__(right, left): return add(left, right)
def __sub__(left, right): return sub(left, right)
def __rsub__(right, left): return sub(left, right)
def __mul__(left, right): return mul(left, right)
def __rmul__(left, right): return mul(left, right)
......@@ -410,8 +413,20 @@ class Transpose(gof.op.Op):
return transpose(gz),
transpose = Transpose()
class Neg(gof.op.Op):
def make_node(self, x):
x = as_sparse(x)
return gof.Apply(self, [x], [x.type()])
def perform(self, node, (x, ), (out, )):
assert _is_sparse(x)
out[0] = -x
def grad(self, (x,), (gz,)):
assert _is_sparse_result(x) and _is_sparse_result(gz)
return -gz,
neg = Neg()
class AddSS(gof.op.Op):
''' Add two sparse matrices '''
'''Add two sparse matrices '''
def make_node(self, x, y):
x, y = map(as_sparse, [x, y])
if x.type.dtype != y.type.dtype:
......@@ -450,7 +465,7 @@ class AddSD(gof.op.Op):
def grad(self, (x, y), (gz,)):
assert _is_sparse_result(x) and _is_dense_result(y)
assert _is_dense_result(gz)
return SparseFromDense(x.type.format)(gz), gz
return sp_one_like(x) * gz, gz
add_s_d = AddSD()
def add(x,y):
"""
......@@ -467,6 +482,10 @@ def add(x,y):
elif x_is_sparse_result and not y_is_sparse_result: return add_s_d(x,y)
elif y_is_sparse_result and not x_is_sparse_result: return add_s_d(y,x)
else: raise NotImplementedError()
def sub(x,y):
return x + (-y)
class MulSS(gof.op.Op):
''' Elementwise multiply a sparse and a ndarray '''
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论