提交 0d4824de authored 作者: Yann N. Dauphin's avatar Yann N. Dauphin

use as_scalar for StructuredMonoid

上级 171c015c
...@@ -11,7 +11,7 @@ from theano.sparse.basic import ( ...@@ -11,7 +11,7 @@ from theano.sparse.basic import (
_is_sparse_variable, CSC, CSR, _is_sparse_variable, CSC, CSR,
csm_properties, csm_data, csm_indices, csm_indptr, csm_shape, csm_properties, csm_data, csm_indices, csm_indptr, csm_shape,
_is_sparse) _is_sparse)
from theano.sparse.sandbox.sp import sp_sum
class Cast(gof.op.Op): class Cast(gof.op.Op):
def __init__(self, out_type): def __init__(self, out_type):
...@@ -396,7 +396,7 @@ def structured_monoid(tensor_op): ...@@ -396,7 +396,7 @@ def structured_monoid(tensor_op):
def wrapper(*args): def wrapper(*args):
x = as_sparse_variable(args[0]) x = as_sparse_variable(args[0])
xs = [tensor.as_tensor_variable(arg) for arg in args[1:]] xs = [scalar.as_scalar(arg) for arg in args[1:]]
data, ind, ptr, shape = csm_properties(x) data, ind, ptr, shape = csm_properties(x)
...@@ -485,7 +485,7 @@ class StructuredAddSV(gof.op.Op): ...@@ -485,7 +485,7 @@ class StructuredAddSV(gof.op.Op):
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
assert _is_sparse_variable(x) and not _is_sparse_variable(y) assert _is_sparse_variable(x) and not _is_sparse_variable(y)
assert _is_sparse_variable(gz) assert _is_sparse_variable(gz)
return gz, sum(gz, 0) return gz, sp_sum(gz, axis=0, sparse_grad=True)
structured_add_s_v = StructuredAddSV() structured_add_s_v = StructuredAddSV()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论