提交 ae6a8a7d authored 作者: James Bergstra's avatar James Bergstra

added eq, hash, ne functions to several sparse ops

上级 16d68785
...@@ -64,6 +64,14 @@ def _is_dense(x): ...@@ -64,6 +64,14 @@ def _is_dense(x):
raise NotImplementedError("this function should only be called on sparse.scipy.sparse.spmatrix or numpy.ndarray, not,", x) raise NotImplementedError("this function should only be called on sparse.scipy.sparse.spmatrix or numpy.ndarray, not,", x)
return isinstance(x, numpy.ndarray) return isinstance(x, numpy.ndarray)
def _kmap_eq(a, b):
if a is None and b is None:
return True
return numpy.all(a == b)
def _kmap_hash(a):
if a is None: return 12345
return hash(numpy.str(a))
# Wrapper type # Wrapper type
...@@ -211,11 +219,23 @@ class SparseValue(gof.Value, _sparse_py_operators): ...@@ -211,11 +219,23 @@ class SparseValue(gof.Value, _sparse_py_operators):
# CONSTRUCTION # CONSTRUCTION
class CSMProperties(gof.Op): class CSMProperties(gof.Op):
"""Extract all of .data .indices and .indptr""" """Extract all of .data .indices and .indptr"""
view_map = {0:[0],1:[0],2:[0],3:[0]} view_map = {0:[0],1:[0],2:[0],3:[0]}
kmap = None
""" WRITEME """
def __init__(self, kmap=None): def __init__(self, kmap=None):
self.kmap = kmap self.kmap = kmap
def __eq__(self, other):
return type(self) == type(other) and _kmap_eq(self.kmap, other.kmap)
def __ne__(self, other): return not (self == other)
def __hash__(self):
return 8234 ^ hash(type(self)) ^ _kmap_hash(self.kmap)
def make_node(self, csm): def make_node(self, csm):
csm = as_sparse(csm) csm = as_sparse(csm)
data = tensor.Tensor(dtype=csm.type.dtype, broadcastable = (False,)).make_result() data = tensor.Tensor(dtype=csm.type.dtype, broadcastable = (False,)).make_result()
...@@ -248,6 +268,15 @@ class CSM(gof.Op): ...@@ -248,6 +268,15 @@ class CSM(gof.Op):
view_map = {0:[0]} #should view the other inputs too, but viewing multiple inputs is not view_map = {0:[0]} #should view the other inputs too, but viewing multiple inputs is not
#currently supported by the destroyhandler #currently supported by the destroyhandler
format = None
"""WRITEME"""
kmap = None
"""WRITEME"""
_hashval = None
"""Pre-computed hash value, defined by __init__"""
def __init__(self, format, kmap=None): def __init__(self, format, kmap=None):
if format not in ('csr', 'csc'): if format not in ('csr', 'csc'):
raise ValueError("format must be one of: 'csr', 'csc'", format) raise ValueError("format must be one of: 'csr', 'csc'", format)
...@@ -259,12 +288,14 @@ class CSM(gof.Op): ...@@ -259,12 +288,14 @@ class CSM(gof.Op):
self.kmap = kmap self.kmap = kmap
self._hashval = hash(type(self)) ^ hash(self.format) ^ _kmap_hash(self.kmap)
def __eq__(self, other): def __eq__(self, other):
return type(other) is CSM \ return type(other) is CSM \
and other.format == self.format and numpy.all(other.kmap==self.kmap) and other.format == self.format and _kmap_eq(self.kmap, other.kmap)
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ hash(self.format) ^ hash(numpy.str(self.kmap)) return self._hashval
def make_node(self, data, indices, indptr, shape): def make_node(self, data, indices, indptr, shape):
"""Build a SparseResult from the internal parametrization """Build a SparseResult from the internal parametrization
...@@ -335,6 +366,15 @@ class CSMGrad(gof.op.Op): ...@@ -335,6 +366,15 @@ class CSMGrad(gof.op.Op):
def __init__(self, kmap=None): def __init__(self, kmap=None):
self.kmap = kmap self.kmap = kmap
def __eq__(self, other):
return type(self) == type(other) and _kmap_eq(self.kmap, other.kmap)
def __ne__(self, other): return not (self == other)
def __hash__(self):
return 82345 ^ hash(type(self)) ^ _kmap_hash(self.kmap)
def make_node(self, data, gout_data, gout_indices): def make_node(self, data, gout_data, gout_indices):
g_data = data.type() g_data = data.type()
return gof.Apply(self, [data, gout_data, gout_indices], [g_data]) return gof.Apply(self, [data, gout_data, gout_indices], [g_data])
...@@ -369,6 +409,8 @@ class DenseFromSparse(gof.op.Op): ...@@ -369,6 +409,8 @@ class DenseFromSparse(gof.op.Op):
Convert a sparse matrix to an `ndarray`. Convert a sparse matrix to an `ndarray`.
""" """
sparse_grad = True sparse_grad = True
"""WRITEME"""
def make_node(self, x): def make_node(self, x):
x = as_sparse(x) x = as_sparse(x)
return gof.Apply(self, return gof.Apply(self,
...@@ -392,6 +434,13 @@ dense_from_sparse = DenseFromSparse() ...@@ -392,6 +434,13 @@ dense_from_sparse = DenseFromSparse()
class SparseFromDense(gof.op.Op): class SparseFromDense(gof.op.Op):
def __init__(self, format): def __init__(self, format):
self.format = format self.format = format
def __eq__(self, other):
return type(self) == type(other) and self.format == other.format
def __ne__(self, other):
return not (self == other)
def __hash__(self):
return 982374 ^ hash(self.format) ^ hash(DenseFromSparse)
def make_node(self, x): def make_node(self, x):
x = tensor.as_tensor(x) x = tensor.as_tensor(x)
return gof.Apply(self, return gof.Apply(self,
...@@ -402,10 +451,6 @@ class SparseFromDense(gof.op.Op): ...@@ -402,10 +451,6 @@ class SparseFromDense(gof.op.Op):
out[0] = Sparse.format_cls[self.format](x) out[0] = Sparse.format_cls[self.format](x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return dense_from_sparse(gz), return dense_from_sparse(gz),
def __eq__(self, other):
return type(self) == type(other) and self.format == other.format
def __hash__(self):
return hash(self.format) ^ hash(DenseFromSparse)
csr_from_dense = SparseFromDense('csr') csr_from_dense = SparseFromDense('csr')
csc_from_dense = SparseFromDense('csc') csc_from_dense = SparseFromDense('csc')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论