提交 b11dd92e authored 作者: Yann N. Dauphin's avatar Yann N. Dauphin

replaced if

上级 dfa37dea
...@@ -561,10 +561,7 @@ class CSMProperties(gof.Op): ...@@ -561,10 +561,7 @@ class CSMProperties(gof.Op):
def grad(self, (csm,), g): def grad(self, (csm,), g):
assert [gg is None for gg in g[1:]] assert [gg is None for gg in g[1:]]
data, indices, indptr, shape = csm_properties(csm) data, indices, indptr, shape = csm_properties(csm)
if csm.format == 'csc': return [CSM(csm.format)(g[0], indices, indptr, shape)]
return [CSC(g[0], indices, indptr, shape)]
else:
return [CSR(g[0], indices, indptr, shape)]
# don't make this a function or it breaks some optimizations below # don't make this a function or it breaks some optimizations below
csm_properties = CSMProperties() csm_properties = CSMProperties()
...@@ -1270,7 +1267,7 @@ class MulSD(gof.op.Op): ...@@ -1270,7 +1267,7 @@ class MulSD(gof.op.Op):
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
assert _is_sparse_variable(x) and _is_dense_variable(y) assert _is_sparse_variable(x) and _is_dense_variable(y)
#assert _is_sparse_variable(gz) assert _is_sparse_variable(gz)
return y * gz, x * gz return y * gz, x * gz
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论