提交 e62834b9 authored 作者: Yann N. Dauphin's avatar Yann N. Dauphin

added structured_monoid to help write some elementwise operations

上级 f1643043
...@@ -361,6 +361,9 @@ class Sum(gof.op.Op): ...@@ -361,6 +361,9 @@ class Sum(gof.op.Op):
def perform(self, node, (x, a), (out, )): def perform(self, node, (x, a), (out, )):
assert _is_sparse(x) assert _is_sparse(x)
out[0] = numpy.asarray(x.sum(a), dtype=x.dtype).flatten() out[0] = numpy.asarray(x.sum(a), dtype=x.dtype).flatten()
def grad(self, (x, a, ), (gz, )):
return None, None
sum = Sum() sum = Sum()
...@@ -394,66 +397,69 @@ class Binomial(gof.op.Op): ...@@ -394,66 +397,69 @@ class Binomial(gof.op.Op):
out[0] = getattr(res, 'to' + self.format)() out[0] = getattr(res, 'to' + self.format)()
out[0].data = numpy.ones_like(out[0].data) out[0].data = numpy.ones_like(out[0].data)
def grad(self, (n, p, shape, ), (gz,)):
return None, None, None
csr_fbinomial = Binomial('csr', 'float32') csr_fbinomial = Binomial('csr', 'float32')
csc_fbinomial = Binomial('csc', 'float32') csc_fbinomial = Binomial('csc', 'float32')
csr_dbinomial = Binomial('csr', 'float64') csr_dbinomial = Binomial('csr', 'float64')
csc_dbinomial = Binomial('csc', 'float64') csc_dbinomial = Binomial('csc', 'float64')
def structured_sigmoid(x): def structured_monoid(tensor_op):
""" """
Element-wise sigmoid function only to the non-zero elements. Generic operation to perform many kinds of monoid element-wise
operations on the non-zeros of a sparse matrix.
The first parameter must always be a sparse matrix. The other parameters
must be scalars which will be passed as argument to the tensor_op.
""" """
x = as_sparse_variable(x) def decorator(f):
def wrapper(*args):
x = as_sparse_variable(args[0])
x_data, x_ind, x_ptr, x_shape = csm_properties(x) xs = [tensor.as_tensor_variable(arg) for arg in args[1:]]
x_data = tensor.nnet.sigmoid(x_data) data, ind, ptr, shape = csm_properties(x)
return CSR(x_data, x_ind, x_ptr, x_shape) data = tensor_op(data, *xs)
return CSR(data, ind, ptr, shape)
return wrapper
return decorator
def structured_exp(x):
"""
Element-wise exponential function to the non-zero elements.
"""
x = as_sparse_variable(x)
x_data, x_ind, x_ptr, x_shape = csm_properties(x)
x_data = tensor.exp(x_data) @structured_monoid(tensor.nnet.sigmoid)
def structured_sigmoid(x):
"""structured elemwise sigmoid.
"""
# see decorator for function body
return CSR(x_data, x_ind, x_ptr, x_shape) @structured_monoid(tensor.exp)
def structured_exp(x):
"""structured elemwise exponential.
"""
# see decorator for function body
@structured_monoid(tensor.log)
def structured_log(x):
"""structured elemwise logarithm.
"""
# see decorator for function body
@structured_monoid(tensor.pow)
def structured_pow(x, y): def structured_pow(x, y):
"""structured elemwise power of sparse matrix
x by scalar y.
""" """
Element-wise power function only to non-zero elements. # see decorator for function body
"""
x = as_sparse_variable(x)
y = tensor.as_tensor_variable(y)
x_data, x_ind, x_ptr, x_shape = csm_properties(x)
x_data = tensor.pow(x_data, y)
return CSR(x_data, x_ind, x_ptr, x_shape)
@structured_monoid(tensor.minimum)
def structured_minimum(x, y): def structured_minimum(x, y):
"""structured elemwise minimum of sparse matrix
x by scalar y.
""" """
Element-wise minimum function only to non-zero elements. # see decorator for function body
"""
x = as_sparse_variable(x)
y = tensor.as_tensor_variable(y)
x_data, x_ind, x_ptr, x_shape = csm_properties(x)
x_data = tensor.minimum(x_data, y)
return CSR(x_data, x_ind, x_ptr, x_shape)
class StructuredAddSV(gof.op.Op): class StructuredAddSV(gof.op.Op):
...@@ -486,9 +492,9 @@ class StructuredAddSV(gof.op.Op): ...@@ -486,9 +492,9 @@ class StructuredAddSV(gof.op.Op):
out[0] = x.__class__(x + (x.toarray() != 0) * y) out[0] = x.__class__(x + (x.toarray() != 0) * y)
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
assert _is_sparse_variable(x) and _is_sparse_variable(y) assert _is_sparse_variable(x) and not _is_sparse_variable(y)
assert _is_sparse_variable(gz) assert _is_sparse_variable(gz)
return gz, gz return gz, gz.sum(0)
structured_add_s_v = StructuredAddSV() structured_add_s_v = StructuredAddSV()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论