提交 9adbcad2 authored 作者: Yann N. Dauphin's avatar Yann N. Dauphin

fixes to various gradients including samplingdot

上级 57e65e2b
......@@ -363,7 +363,7 @@ class Sum(gof.op.Op):
out[0] = numpy.asarray(x.sum(a), dtype=x.dtype).flatten()
def grad(self, (x, a, ), (gz, )):
return None, None
return sp_ones_like(x) * gz, None
sum = Sum()
......@@ -424,7 +424,7 @@ def structured_monoid(tensor_op):
data = tensor_op(data, *xs)
return CSR(data, ind, ptr, shape)
return CSM(x.format)(data, ind, ptr, shape)
return wrapper
return decorator
......@@ -461,6 +461,19 @@ def structured_minimum(x, y):
"""
# see decorator for function body
@structured_monoid(tensor.maximum)
def structured_maximum(x, y):
"""structured elemwise maximum of sparse matrix
x by scalar y.
"""
# see decorator for function body
@structured_monoid(tensor.add)
def structured_add(x):
"""structured addition of sparse matrix
x and scalar y.
"""
# see decorator for function body
class StructuredAddSV(gof.op.Op):
'''Structured addition of a sparse matrix and a dense vector.
......@@ -676,8 +689,8 @@ class SamplingDot(gof.op.Op):
def grad(self, (x, y, p), (gz,)):
rval = [
dot(gz, y),
dot(gz.T, x),
dot(p * gz, y),
dot(p.T * gz.T, x),
None
]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论