提交 b11dd92e authored 作者: Yann N. Dauphin's avatar Yann N. Dauphin

replaced if

上级 dfa37dea
......@@ -561,10 +561,7 @@ class CSMProperties(gof.Op):
def grad(self, (csm,), g):
assert [gg is None for gg in g[1:]]
data, indices, indptr, shape = csm_properties(csm)
if csm.format == 'csc':
return [CSC(g[0], indices, indptr, shape)]
else:
return [CSR(g[0], indices, indptr, shape)]
return [CSM(csm.format)(g[0], indices, indptr, shape)]
# don't make this a function or it breaks some optimizations below
csm_properties = CSMProperties()
......@@ -1270,7 +1267,7 @@ class MulSD(gof.op.Op):
def grad(self, (x, y), (gz,)):
assert _is_sparse_variable(x) and _is_dense_variable(y)
#assert _is_sparse_variable(gz)
assert _is_sparse_variable(gz)
return y * gz, x * gz
def infer_shape(self, node, shapes):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论