提交 4990aaf9 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

implemented correction_pattern and corrected grad for Subtensor

上级 d1879e8b
...@@ -3978,7 +3978,7 @@ class Subtensor(Op): ...@@ -3978,7 +3978,7 @@ class Subtensor(Op):
x = inputs[0] x = inputs[0]
rest = inputs[1:] rest = inputs[1:]
return ([IncSubtensor(self.idx_list)(zeros_like(x), gz, *rest)] return ([IncSubtensor(self.idx_list)(zeros_like(x), gz, *rest)]
+ [None] * len(rest)) + [DisconnectedType()()] * len(rest))
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.idx_list == other.idx_list return type(self) == type(other) and self.idx_list == other.idx_list
...@@ -4632,6 +4632,15 @@ class IncSubtensor(Op): ...@@ -4632,6 +4632,15 @@ class IncSubtensor(Op):
return self.make_node(eval_points[0], eval_points[1], return self.make_node(eval_points[0], eval_points[1],
*inputs[2:]).outputs *inputs[2:]).outputs
def connection_pattern(self, node):
rval = [ [True] ]
for ipt in node.inputs[1:]:
rval.append([False])
return rval
def grad(self, inputs, grads): def grad(self, inputs, grads):
g_output, = grads g_output, = grads
x, y = inputs[:2] x, y = inputs[:2]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论