提交 fcb0bc9e authored 作者: Ian Goodfellow's avatar Ian Goodfellow

MakeVector grad

上级 c63048fc
...@@ -542,15 +542,12 @@ class MakeVector(T.Op): ...@@ -542,15 +542,12 @@ class MakeVector(T.Op):
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
# If the output is of an integer dtype, no gradient shall pass # If the output is of an integer dtype, no gradient shall pass
if 'int' in self.dtype: if 'int' in self.dtype:
return [None] * len(inputs) return [ipt.zeros_like().astype(theano.config.floatX)
for ipt in inputs]
grads = [] grads = []
for i, inp in enumerate(inputs): for i, inp in enumerate(inputs):
if 'int' in inp.dtype: grads.append(output_gradients[0][i])
# No gradient wrt integer inputs
grads.append(None)
else:
grads.append(output_gradients[0][i])
return grads return grads
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论