提交 7d84b870 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added grad method to extract_diag

上级 7a0b8177
...@@ -438,6 +438,8 @@ class AllocDiag(Op): ...@@ -438,6 +438,8 @@ class AllocDiag(Op):
if x.type.ndim != 1: if x.type.ndim != 1:
raise TypeError('AllocDiag only works on vectors', _x) raise TypeError('AllocDiag only works on vectors', _x)
return Apply(self, [x], [tensor.matrix(dtype=x.type.dtype)]) return Apply(self, [x], [tensor.matrix(dtype=x.type.dtype)])
def grad(self, inputs, g_outputs):
return [extract_diag(g_outputs[0])]
def perform(self, node, (x,), (z,)): def perform(self, node, (x,), (z,)):
if x.ndim != 1: if x.ndim != 1:
raise TypeError(x) raise TypeError(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论