提交 d4ea7573 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

CSMProperties Op: simplify logic

上级 c1a53d7d
...@@ -297,15 +297,12 @@ class CSMProperties(Op): ...@@ -297,15 +297,12 @@ class CSMProperties(Op):
csm = as_sparse_variable(csm) csm = as_sparse_variable(csm)
assert csm.format in ("csr", "csc") assert csm.format in ("csr", "csc")
data = TensorType(dtype=csm.type.dtype, shape=(None,))() data = vector(dtype=csm.type.dtype)
return Apply(self, [csm], [data, ivector(), ivector(), ivector()]) return Apply(self, [csm], [data, ivector(), ivector(), ivector()])
def perform(self, node, inputs, out): def perform(self, node, inputs, out):
(csm,) = inputs (csm,) = inputs
out[0][0] = csm.data out[0][0] = np.asarray(csm.data)
if str(csm.data.dtype) == "int32":
out[0][0] = np.asarray(out[0][0], dtype="int32")
# backport
out[1][0] = np.asarray(csm.indices, dtype="int32") out[1][0] = np.asarray(csm.indices, dtype="int32")
out[2][0] = np.asarray(csm.indptr, dtype="int32") out[2][0] = np.asarray(csm.indptr, dtype="int32")
out[3][0] = np.asarray(csm.shape, dtype="int32") out[3][0] = np.asarray(csm.shape, dtype="int32")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论