提交 21887b7d authored 作者: Iban Harlouchet's avatar Iban Harlouchet

Additional corrections to sparse/basic.py

上级 0bdc9521
...@@ -370,11 +370,11 @@ class SparseVariable(_sparse_py_operators, gof.Variable): ...@@ -370,11 +370,11 @@ class SparseVariable(_sparse_py_operators, gof.Variable):
class SparseConstantSignature(tuple): class SparseConstantSignature(tuple):
def __eq__(self, other): def __eq__(self, other):
(a, b), (x, y) = self, other (a, b), (x, y) = self, other
return a == x and\ return (a == x and
(b.dtype == y.dtype) and\ (b.dtype == y.dtype) and
(type(b) == type(y)) and\ (type(b) == type(y)) and
(b.shape == y.shape) and\ (b.shape == y.shape) and
(abs(b - y).sum() < 1e-6 * b.nnz) (abs(b - y).sum() < 1e-6 * b.nnz))
def __hash__(self): def __hash__(self):
(a, b) = self (a, b) = self
...@@ -489,7 +489,7 @@ class CSMProperties(gof.Op): ...@@ -489,7 +489,7 @@ class CSMProperties(gof.Op):
csm = as_sparse_variable(csm) csm = as_sparse_variable(csm)
assert csm.format in ["csr", "csc"] assert csm.format in ["csr", "csc"]
data = tensor.TensorType(dtype=csm.type.dtype, data = tensor.TensorType(dtype=csm.type.dtype,
broadcastable=(False,)).make_variable() broadcastable=(False,))()
return gof.Apply(self, [csm], return gof.Apply(self, [csm],
[data, tensor.ivector(), [data, tensor.ivector(),
tensor.ivector(), tensor.ivector()]) tensor.ivector(), tensor.ivector()])
...@@ -648,7 +648,7 @@ class CSM(gof.Op): ...@@ -648,7 +648,7 @@ class CSM(gof.Op):
return gof.Apply(self, return gof.Apply(self,
[data, indices, indptr, shape], [data, indices, indptr, shape],
[SparseType(dtype=data.type.dtype, [SparseType(dtype=data.type.dtype,
format=self.format).make_variable()]) format=self.format)()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
# for efficiency, if remap does nothing, then do not apply it # for efficiency, if remap does nothing, then do not apply it
...@@ -836,7 +836,7 @@ class Cast(gof.op.Op): ...@@ -836,7 +836,7 @@ class Cast(gof.op.Op):
assert x.format in ["csr", "csc"] assert x.format in ["csr", "csc"]
return gof.Apply( return gof.Apply(
self, [x], self, [x],
[SparseType(dtype=self.out_type, format=x.format).make_variable()]) [SparseType(dtype=self.out_type, format=x.format)()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(x,) = inputs (x,) = inputs
...@@ -904,8 +904,8 @@ class DenseFromSparse(gof.op.Op): ...@@ -904,8 +904,8 @@ class DenseFromSparse(gof.op.Op):
self.sparse_grad = structured self.sparse_grad = structured
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) and \ return ((type(self) == type(other)) and
(self.sparse_grad == other.sparse_grad) (self.sparse_grad == other.sparse_grad))
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ hash(self.sparse_grad) return hash(type(self)) ^ hash(self.sparse_grad)
...@@ -920,8 +920,7 @@ class DenseFromSparse(gof.op.Op): ...@@ -920,8 +920,7 @@ class DenseFromSparse(gof.op.Op):
return gof.Apply(self, return gof.Apply(self,
[x], [x],
[tensor.TensorType(dtype=x.type.dtype, [tensor.TensorType(dtype=x.type.dtype,
broadcastable=(False, False) broadcastable=(False, False))()])
).make_variable()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(x,) = inputs (x,) = inputs
...@@ -1004,8 +1003,7 @@ class SparseFromDense(gof.op.Op): ...@@ -1004,8 +1003,7 @@ class SparseFromDense(gof.op.Op):
return gof.Apply(self, return gof.Apply(self,
[x], [x],
[SparseType(dtype=x.type.dtype, [SparseType(dtype=x.type.dtype,
format=self.format format=self.format)()])
).make_variable()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(x,) = inputs (x,) = inputs
...@@ -1440,8 +1438,7 @@ class Transpose(gof.op.Op): ...@@ -1440,8 +1438,7 @@ class Transpose(gof.op.Op):
return gof.Apply(self, return gof.Apply(self,
[x], [x],
[SparseType(dtype=x.type.dtype, [SparseType(dtype=x.type.dtype,
format=self.format_map[x.type.format] format=self.format_map[x.type.format])()])
).make_variable()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(x,) = inputs (x,) = inputs
...@@ -1961,8 +1958,7 @@ class AddSS(gof.op.Op): ...@@ -1961,8 +1958,7 @@ class AddSS(gof.op.Op):
return gof.Apply(self, return gof.Apply(self,
[x, y], [x, y],
[SparseType(dtype=out_dtype, [SparseType(dtype=out_dtype,
format=x.type.format format=x.type.format)()])
).make_variable()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(x, y) = inputs (x, y) = inputs
...@@ -2003,8 +1999,7 @@ class AddSSData(gof.op.Op): ...@@ -2003,8 +1999,7 @@ class AddSSData(gof.op.Op):
return gof.Apply(self, return gof.Apply(self,
[x, y], [x, y],
[SparseType(dtype=x.type.dtype, [SparseType(dtype=x.type.dtype,
format=x.type.format format=x.type.format)()])
).make_variable()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(x, y) = inputs (x, y) = inputs
...@@ -2070,7 +2065,7 @@ class AddSD(gof.op.Op): ...@@ -2070,7 +2065,7 @@ class AddSD(gof.op.Op):
[x, y], [x, y],
[tensor.TensorType(dtype=out_dtype, [tensor.TensorType(dtype=out_dtype,
broadcastable=y.type.broadcastable broadcastable=y.type.broadcastable
).make_variable()]) )()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(x, y) = inputs (x, y) = inputs
...@@ -2113,8 +2108,7 @@ class StructuredAddSV(gof.op.Op): ...@@ -2113,8 +2108,7 @@ class StructuredAddSV(gof.op.Op):
return gof.Apply(self, return gof.Apply(self,
[x, y], [x, y],
[SparseType(dtype=x.type.dtype, [SparseType(dtype=x.type.dtype,
format=x.type.format format=x.type.format)()])
).make_variable()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(x, y) = inputs (x, y) = inputs
...@@ -2370,8 +2364,7 @@ class MulSV(gof.op.Op): ...@@ -2370,8 +2364,7 @@ class MulSV(gof.op.Op):
return gof.Apply(self, return gof.Apply(self,
[x, y], [x, y],
[SparseType(dtype=x.type.dtype, [SparseType(dtype=x.type.dtype,
format=x.type.format format=x.type.format)()])
).make_variable()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(x, y) = inputs (x, y) = inputs
...@@ -2486,8 +2479,7 @@ class __ComparisonOpSS(gof.op.Op): ...@@ -2486,8 +2479,7 @@ class __ComparisonOpSS(gof.op.Op):
return gof.Apply(self, return gof.Apply(self,
[x, y], [x, y],
[SparseType(dtype='uint8', [SparseType(dtype='uint8',
format=x.type.format format=x.type.format)()])
).make_variable()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(x, y) = inputs (x, y) = inputs
...@@ -2531,8 +2523,7 @@ class __ComparisonOpSD(gof.op.Op): ...@@ -2531,8 +2523,7 @@ class __ComparisonOpSD(gof.op.Op):
return gof.Apply(self, return gof.Apply(self,
[x, y], [x, y],
[SparseType(dtype='uint8', [SparseType(dtype='uint8',
format=x.type.format format=x.type.format)()])
).make_variable()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(x, y) = inputs (x, y) = inputs
...@@ -2773,8 +2764,7 @@ class HStack(gof.op.Op): ...@@ -2773,8 +2764,7 @@ class HStack(gof.op.Op):
return gof.Apply(self, return gof.Apply(self,
var, var,
[SparseType(dtype=self.dtype, [SparseType(dtype=self.dtype,
format=self.format format=self.format)()])
).make_variable()])
def perform(self, node, block, outputs): def perform(self, node, block, outputs):
(out,) = outputs (out,) = outputs
...@@ -3220,8 +3210,7 @@ class TrueDot(gof.op.Op): ...@@ -3220,8 +3210,7 @@ class TrueDot(gof.op.Op):
raise NotImplementedError() raise NotImplementedError()
inputs = [x, y] # Need to convert? e.g. assparse inputs = [x, y] # Need to convert? e.g. assparse
outputs = [SparseType(dtype=x.type.dtype, outputs = [SparseType(dtype=x.type.dtype, format=myformat)()]
format=myformat).make_variable()]
return gof.Apply(self, inputs, outputs) return gof.Apply(self, inputs, outputs)
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论