Renamed SparseR to SparseResult

上级 fac5027f
...@@ -18,23 +18,23 @@ import tensor ...@@ -18,23 +18,23 @@ import tensor
def assparse(sp, **kwargs): def assparse(sp, **kwargs):
""" """
Wrapper around SparseR constructor. Wrapper around SparseResult constructor.
@param sp: A sparse matrix. assparse reads dtype and format properties @param sp: A sparse matrix. assparse reads dtype and format properties
out of this sparse matrix. out of this sparse matrix.
@return: SparseR version of sp. @return: SparseResult version of sp.
@todo Verify that sp is sufficiently sparse, and raise a warning if it is not @todo Verify that sp is sufficiently sparse, and raise a warning if it is not
""" """
if isinstance(sp, SparseR): if isinstance(sp, SparseResult):
return sp return sp
else: else:
# @todo Verify that sp is sufficiently sparse, and raise a # @todo Verify that sp is sufficiently sparse, and raise a
# warning if it is not # warning if it is not
rval = SparseR(str(sp.dtype), sp.format, **kwargs) rval = SparseResult(str(sp.dtype), sp.format, **kwargs)
rval.data = sp rval.data = sp
return rval return rval
class SparseR(gof.result.Result): class SparseResult(gof.result.Result):
""" """
Attribute: Attribute:
format - a string identifying the type of sparsity format - a string identifying the type of sparsity
...@@ -58,26 +58,26 @@ class SparseR(gof.result.Result): ...@@ -58,26 +58,26 @@ class SparseR(gof.result.Result):
Fundamental way to do create a sparse node. Fundamental way to do create a sparse node.
@param dtype: Type of numbers in the matrix. @param dtype: Type of numbers in the matrix.
@param format: The sparse storage strategy. @param format: The sparse storage strategy.
@return An empty SparseR instance. @return An empty SparseResult instance.
""" """
gof.Result.__init__(self, **kwargs) gof.Result.__init__(self, **kwargs)
if dtype in SparseR.dtype_set: if dtype in SparseResult.dtype_set:
self._dtype = dtype self._dtype = dtype
assert isinstance(format, str) assert isinstance(format, str)
#print format, type(format), SparseR.format_cls.keys(), format in SparseR.format_cls #print format, type(format), SparseResult.format_cls.keys(), format in SparseResult.format_cls
if format in SparseR.format_cls: if format in SparseResult.format_cls:
self._format = format self._format = format
else: else:
raise NotImplementedError('unsupported format "%s" not in list' % format, SparseR.format_cls.keys()) raise NotImplementedError('unsupported format "%s" not in list' % format, SparseResult.format_cls.keys())
def filter(self, value): def filter(self, value):
if isinstance(value, SparseR.format_cls[self.format])\ if isinstance(value, SparseResult.format_cls[self.format])\
and value.dtype == self.dtype: and value.dtype == self.dtype:
return value return value
#print 'pass-through failed', type(value) #print 'pass-through failed', type(value)
sp = SparseR.format_cls[self.format](value) sp = SparseResult.format_cls[self.format](value)
if str(sp.dtype) != self.dtype: if str(sp.dtype) != self.dtype:
raise NotImplementedError() raise NotImplementedError()
if sp.format != self.format: if sp.format != self.format:
...@@ -86,9 +86,9 @@ class SparseR(gof.result.Result): ...@@ -86,9 +86,9 @@ class SparseR(gof.result.Result):
def __copy__(self): def __copy__(self):
if self.name is not None: if self.name is not None:
rval = SparseR(self._dtype, self._format, name=self.name) rval = SparseResult(self._dtype, self._format, name=self.name)
else: else:
rval = SparseR(self._dtype, self._format) rval = SparseResult(self._dtype, self._format)
rval.data = copy.copy(self.data) rval.data = copy.copy(self.data)
return rval return rval
...@@ -126,11 +126,11 @@ class SparseFromDense(gof.op.Op): ...@@ -126,11 +126,11 @@ class SparseFromDense(gof.op.Op):
else: else:
self.inputs = [tensor.astensor(x), gof.result.PythonResult()] self.inputs = [tensor.astensor(x), gof.result.PythonResult()]
self.inputs[1].data = format self.inputs[1].data = format
self.outputs = [SparseR(x.dtype, self.inputs[1].data)] self.outputs = [SparseResult(x.dtype, self.inputs[1].data)]
def impl(self, x, fmt): def impl(self, x, fmt):
# this would actually happen anyway when we try to assign to # this would actually happen anyway when we try to assign to
# self.outputs[0].data, but that seems hackish -JB # self.outputs[0].data, but that seems hackish -JB
return SparseR.format_cls[fmt](x) return SparseResult.format_cls[fmt](x)
def grad(self, (x, fmt), gz): def grad(self, (x, fmt), gz):
return dense_from_sparse(gz) return dense_from_sparse(gz)
sparse_from_dense = gof.op.constructor(SparseFromDense) sparse_from_dense = gof.op.constructor(SparseFromDense)
...@@ -145,7 +145,7 @@ class Transpose(gof.op.Op): ...@@ -145,7 +145,7 @@ class Transpose(gof.op.Op):
gof.op.Op.__init__(self, **kwargs) gof.op.Op.__init__(self, **kwargs)
x = assparse(x) x = assparse(x)
self.inputs = [x] self.inputs = [x]
self.outputs = [SparseR(x.dtype, Transpose.format_map[x.format])] self.outputs = [SparseResult(x.dtype, Transpose.format_map[x.format])]
def impl(self, x): def impl(self, x):
return x.transpose() return x.transpose()
def grad(self, x, gz): def grad(self, x, gz):
...@@ -161,7 +161,7 @@ class AddSS(gof.op.Op): #add two sparse matrices ...@@ -161,7 +161,7 @@ class AddSS(gof.op.Op): #add two sparse matrices
raise NotImplementedError() raise NotImplementedError()
if x.format != y.format: if x.format != y.format:
raise NotImplementedError() raise NotImplementedError()
self.outputs = [SparseR(x.dtype, x.format)] self.outputs = [SparseResult(x.dtype, x.format)]
def impl(self, x,y): def impl(self, x,y):
return x + y return x + y
def grad(self, (x, y), gz): def grad(self, (x, y), gz):
...@@ -174,14 +174,14 @@ class Dot(gof.op.Op): ...@@ -174,14 +174,14 @@ class Dot(gof.op.Op):
grad_preserves_dense - a boolean flags [default: True]. grad_preserves_dense - a boolean flags [default: True].
grad_preserves_dense controls whether gradients with respect to inputs grad_preserves_dense controls whether gradients with respect to inputs
are converted to dense matrices when the corresponding input y is are converted to dense matrices when the corresponding input y is
dense (not in a L{SparseR} wrapper). This is generally a good idea dense (not in a L{SparseResult} wrapper). This is generally a good idea
when L{Dot} is in the middle of a larger graph, because the types when L{Dot} is in the middle of a larger graph, because the types
of gy will match that of y. This conversion might be inefficient if of gy will match that of y. This conversion might be inefficient if
the gradients are graph outputs though, hence this mask. the gradients are graph outputs though, hence this mask.
""" """
def __init__(self, x, y, grad_preserves_dense=True): def __init__(self, x, y, grad_preserves_dense=True):
""" """
Because of trickiness of implementing, we assume that the left argument x is SparseR (not dense) Because of trickiness of implementing, we assume that the left argument x is SparseResult (not dense)
""" """
if x.dtype != y.dtype: if x.dtype != y.dtype:
raise NotImplementedError() raise NotImplementedError()
...@@ -195,7 +195,7 @@ class Dot(gof.op.Op): ...@@ -195,7 +195,7 @@ class Dot(gof.op.Op):
raise NotImplementedError() raise NotImplementedError()
self.inputs = [x, y] # Need to convert? e.g. assparse self.inputs = [x, y] # Need to convert? e.g. assparse
self.outputs = [SparseR(x.dtype, myformat)] self.outputs = [SparseResult(x.dtype, myformat)]
self.grad_preserves_dense = grad_preserves_dense self.grad_preserves_dense = grad_preserves_dense
def perform(self): def perform(self):
""" """
...@@ -205,8 +205,8 @@ class Dot(gof.op.Op): ...@@ -205,8 +205,8 @@ class Dot(gof.op.Op):
self.outputs[0].data = self.inputs[0].data.dot(self.inputs[1].data) self.outputs[0].data = self.inputs[0].data.dot(self.inputs[1].data)
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
rval = [dot(gz, y.T), dot(x.T, gz)] rval = [dot(gz, y.T), dot(x.T, gz)]
assert isinstance(self.inputs[0], SparseR) assert isinstance(self.inputs[0], SparseResult)
if not isinstance(self.inputs[1], SparseR): if not isinstance(self.inputs[1], SparseResult):
if self.grad_preserves_dense: if self.grad_preserves_dense:
rval[1] = dense_from_sparse(rval[1]) rval[1] = dense_from_sparse(rval[1])
return rval return rval
...@@ -222,8 +222,8 @@ def dot(x, y, grad_preserves_dense=True): ...@@ -222,8 +222,8 @@ def dot(x, y, grad_preserves_dense=True):
if hasattr(x, 'getnnz'): x = assparse(x) if hasattr(x, 'getnnz'): x = assparse(x)
if hasattr(y, 'getnnz'): y = assparse(y) if hasattr(y, 'getnnz'): y = assparse(y)
x_is_sparse = isinstance(x, SparseR) x_is_sparse = isinstance(x, SparseResult)
y_is_sparse = isinstance(y, SparseR) y_is_sparse = isinstance(y, SparseResult)
if not x_is_sparse and not y_is_sparse: if not x_is_sparse and not y_is_sparse:
raise TypeError() raise TypeError()
if x_is_sparse: if x_is_sparse:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论