提交 84c40292 authored 作者: james@X40's avatar james@X40

very ugly hack to __mul__ to illustrate how to decide between sparse and dense operators

上级 46a148d3
......@@ -163,21 +163,28 @@ class Sparse(gof.Type):
def __repr__(self):
return "Sparse[%s, %s]" % (str(self.dtype), str(self.format))
csc_matrix = Sparse(format='csc')
csr_matrix = Sparse(format='csr')
class _sparse_py_operators:
T = property(lambda self: transpose(self), doc = "Return aliased transpose of self (read-only)")
def __add__(left, right): return add(left, right)
def __radd__(right, left): return add(left, right)
def __mul__(left, right): return mul(left, right)
def __rmul__(left, right): return mul(left, right)
class SparseResult(gof.Result, _sparse_py_operators):
pass
dtype = property(lambda self: self.type.dtype)
format = property(lambda self: self.type.format)
class SparseConstant(gof.Constant, _sparse_py_operators):
pass
dtype = property(lambda self: self.type.dtype)
format = property(lambda self: self.type.format)
class SparseValue(gof.Value, _sparse_py_operators):
pass
dtype = property(lambda self: self.type.dtype)
format = property(lambda self: self.type.format)
......@@ -193,7 +200,7 @@ class DenseFromSparse(gof.op.Op):
[tensor.Tensor(dtype = x.type.dtype,
broadcastable = (False, False)).make_result()])
def perform(self, node, (x, ), (out, )):
out[0] = numpy.asarray(x.todense())
out[0] = x.toarray()
def grad(self, (x, ), (gz, )):
return SparseFromDense(x.type.format)(gz),
dense_from_sparse = DenseFromSparse()
......@@ -296,6 +303,48 @@ def add(x,y):
elif y_is_sparse_result and not x_is_sparse_result: return add_s_d(y,x)
else: raise NotImplementedError()
class MulSD(gof.op.Op):
''' Elementwise multiply a sparse and a ndarray '''
def make_node(self, x, y):
x, y = as_sparse(x), tensor.as_tensor(y)
if x.type.dtype != y.type.dtype:
raise NotImplementedError()
# The magic number two here arises because L{scipy.sparse}
# objects must be matrices (have dimension 2)
# Broadcasting of the sparse matrix is not supported.
assert y.type.ndim <= 2
return gof.Apply(self, [x, y], [x.type()])
def perform(self, node, (x, y), (out, )):
assert _is_sparse(x) and _is_dense(y)
if len(y.shape) == 0:
out[0] = x.copy()
out[0].data *= y
elif len(y.shape) == 1:
raise NotImplementedError() #RowScale / ColScale
elif len(y.shape) == 2:
#if we have enough memory to fit y, maybe we can fit x.asarray() too?
#TODO: change runtime from O(M*N) to O(nonzeros)
return DenseFromSparse(x) * y
def grad(self, (x, y), (gz,)):
assert _is_sparse_result(x) and _is_dense_result(y)
assert _is_dense_result(gz)
return SparseFromDense(x.type.format)(gz), gz
mul_s_d = MulSD()
def mul(x,y):
"""
Multiply (elementwise) two matrices, at least one of which is sparse.
"""
if hasattr(x, 'getnnz'): x = as_sparse(x)
if hasattr(y, 'getnnz'): y = as_sparse(y)
x_is_sparse_result = _is_sparse_result(x)
y_is_sparse_result = _is_sparse_result(y)
assert x_is_sparse_result or y_is_sparse_result
if x_is_sparse_result and y_is_sparse_result: return mul_s_s(x,y)
elif x_is_sparse_result and not y_is_sparse_result: return mul_s_d(x,y)
elif y_is_sparse_result and not x_is_sparse_result: return mul_s_d(y,x)
else: raise NotImplementedError()
class Dot(gof.op.Op):
"""
......
......@@ -492,7 +492,14 @@ class _tensor_py_operators:
#ARITHMETIC - NORMAL
def __add__(self,other): return add(self,other)
def __sub__(self,other): return sub(self,other)
def __mul__(self,other): return mul(self,other)
def __mul__(self,other):
try:
return mul(self,other)
except Exception, e:
try:
return other * self
except:
raise e
def __div__(self,other): return div(self,other)
def __pow__(self,other): return pow(self,other)
def __mod__(self,other): return mod(self,other)
......@@ -2189,7 +2196,8 @@ def verify_grad(testcase, op, pt, n_tests=1, rng=numpy.random, eps=1.0e-7, tol=0
testcase.failUnless(analytic gradient matches finite-diff gradient)
:param pt: the list of numpy.ndarrays to use as inputs to the op
:param op: something that behaves like an Op instance.
:param op: something that behaves like an Op instance with a single output (can be a
function)
:param testcase: the thing to call `fail` on if things go awry.
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论