提交 0bced96f authored 作者: Frederic's avatar Frederic

Update DiagonalSubtensor.grad to don't return None.

上级 aa315810
from theano.gradient import DisconnectedType
from theano.gof import Op, Apply
from theano import tensor
def get_diagonal_subtensor_view(x, i0, i1):
if x.shape[i0] < x.shape[i1]:
raise NotImplementedError('is this allowed?')
......@@ -31,10 +33,16 @@ class DiagonalSubtensor(Op):
output_storage[0][0] = xview
else:
output_storage[0][0] = xview.copy()
def grad(self, inputs, g_outputs):
z = tensor.zeros_like(inputs[0])
gx = inc_diagonal_subtensor(z, inputs[1], inputs[2], g_outputs[0])
return [gx] + [None] * (len(inputs)-1)
return [gx, DisconnectedType()(), DisconnectedType()()]
def connection_pattern(self, node):
rval = [[True], [False], [False]]
return rval
diagonal_subtensor = DiagonalSubtensor(False)
class IncDiagonalSubtensor(Op):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论