提交 167cf345 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Implement grad of Alloc.

上级 5548fee1
...@@ -2472,7 +2472,11 @@ class Alloc(gof.Op): ...@@ -2472,7 +2472,11 @@ class Alloc(gof.Op):
return [node.inputs[1:]] return [node.inputs[1:]]
def grad(self, inputs, grads): def grad(self, inputs, grads):
return [None for i in inputs] x = inputs[0]
gz = grads[0]
n_axes_to_sum = gz.ndim - x.ndim
gx = gz.sum(axis=range(n_axes_to_sum))
return [gx] + [None for i in inputs[1:]]
def __call__(self, val, *shapes): def __call__(self, val, *shapes):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论