提交 ae6a8a7d authored 作者: James Bergstra's avatar James Bergstra

added eq, hash, ne functions to several sparse ops

上级 16d68785
......@@ -64,6 +64,14 @@ def _is_dense(x):
raise NotImplementedError("this function should only be called on sparse.scipy.sparse.spmatrix or numpy.ndarray, not,", x)
return isinstance(x, numpy.ndarray)
def _kmap_eq(a, b):
if a is None and b is None:
return True
return numpy.all(a == b)
def _kmap_hash(a):
if a is None: return 12345
return hash(numpy.str(a))
# Wrapper type
......@@ -211,11 +219,23 @@ class SparseValue(gof.Value, _sparse_py_operators):
# CONSTRUCTION
class CSMProperties(gof.Op):
"""Extract all of .data .indices and .indptr"""
view_map = {0:[0],1:[0],2:[0],3:[0]}
kmap = None
""" WRITEME """
def __init__(self, kmap=None):
self.kmap = kmap
def __eq__(self, other):
return type(self) == type(other) and _kmap_eq(self.kmap, other.kmap)
def __ne__(self, other): return not (self == other)
def __hash__(self):
return 8234 ^ hash(type(self)) ^ _kmap_hash(self.kmap)
def make_node(self, csm):
csm = as_sparse(csm)
data = tensor.Tensor(dtype=csm.type.dtype, broadcastable = (False,)).make_result()
......@@ -248,6 +268,15 @@ class CSM(gof.Op):
view_map = {0:[0]} #should view the other inputs too, but viewing multiple inputs is not
#currently supported by the destroyhandler
format = None
"""WRITEME"""
kmap = None
"""WRITEME"""
_hashval = None
"""Pre-computed hash value, defined by __init__"""
def __init__(self, format, kmap=None):
if format not in ('csr', 'csc'):
raise ValueError("format must be one of: 'csr', 'csc'", format)
......@@ -259,12 +288,14 @@ class CSM(gof.Op):
self.kmap = kmap
self._hashval = hash(type(self)) ^ hash(self.format) ^ _kmap_hash(self.kmap)
def __eq__(self, other):
return type(other) is CSM \
and other.format == self.format and numpy.all(other.kmap==self.kmap)
and other.format == self.format and _kmap_eq(self.kmap, other.kmap)
def __hash__(self):
return hash(type(self)) ^ hash(self.format) ^ hash(numpy.str(self.kmap))
return self._hashval
def make_node(self, data, indices, indptr, shape):
"""Build a SparseResult from the internal parametrization
......@@ -335,6 +366,15 @@ class CSMGrad(gof.op.Op):
def __init__(self, kmap=None):
self.kmap = kmap
def __eq__(self, other):
return type(self) == type(other) and _kmap_eq(self.kmap, other.kmap)
def __ne__(self, other): return not (self == other)
def __hash__(self):
return 82345 ^ hash(type(self)) ^ _kmap_hash(self.kmap)
def make_node(self, data, gout_data, gout_indices):
g_data = data.type()
return gof.Apply(self, [data, gout_data, gout_indices], [g_data])
......@@ -369,6 +409,8 @@ class DenseFromSparse(gof.op.Op):
Convert a sparse matrix to an `ndarray`.
"""
sparse_grad = True
"""WRITEME"""
def make_node(self, x):
x = as_sparse(x)
return gof.Apply(self,
......@@ -392,6 +434,13 @@ dense_from_sparse = DenseFromSparse()
class SparseFromDense(gof.op.Op):
def __init__(self, format):
self.format = format
def __eq__(self, other):
return type(self) == type(other) and self.format == other.format
def __ne__(self, other):
return not (self == other)
def __hash__(self):
return 982374 ^ hash(self.format) ^ hash(DenseFromSparse)
def make_node(self, x):
x = tensor.as_tensor(x)
return gof.Apply(self,
......@@ -402,10 +451,6 @@ class SparseFromDense(gof.op.Op):
out[0] = Sparse.format_cls[self.format](x)
def grad(self, (x, ), (gz, )):
return dense_from_sparse(gz),
def __eq__(self, other):
return type(self) == type(other) and self.format == other.format
def __hash__(self):
return hash(self.format) ^ hash(DenseFromSparse)
csr_from_dense = SparseFromDense('csr')
csc_from_dense = SparseFromDense('csc')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论