提交 0bced96f authored 作者: Frederic's avatar Frederic

Update DiagonalSubtensor.grad to don't return None.

上级 aa315810
from theano.gradient import DisconnectedType
from theano.gof import Op, Apply from theano.gof import Op, Apply
from theano import tensor from theano import tensor
def get_diagonal_subtensor_view(x, i0, i1): def get_diagonal_subtensor_view(x, i0, i1):
if x.shape[i0] < x.shape[i1]: if x.shape[i0] < x.shape[i1]:
raise NotImplementedError('is this allowed?') raise NotImplementedError('is this allowed?')
...@@ -31,10 +33,16 @@ class DiagonalSubtensor(Op): ...@@ -31,10 +33,16 @@ class DiagonalSubtensor(Op):
output_storage[0][0] = xview output_storage[0][0] = xview
else: else:
output_storage[0][0] = xview.copy() output_storage[0][0] = xview.copy()
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
z = tensor.zeros_like(inputs[0]) z = tensor.zeros_like(inputs[0])
gx = inc_diagonal_subtensor(z, inputs[1], inputs[2], g_outputs[0]) gx = inc_diagonal_subtensor(z, inputs[1], inputs[2], g_outputs[0])
return [gx] + [None] * (len(inputs)-1) return [gx, DisconnectedType()(), DisconnectedType()()]
def connection_pattern(self, node):
rval = [[True], [False], [False]]
return rval
diagonal_subtensor = DiagonalSubtensor(False) diagonal_subtensor = DiagonalSubtensor(False)
class IncDiagonalSubtensor(Op): class IncDiagonalSubtensor(Op):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论