提交 5382cc34 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier 提交者: --global

Implement connection_pattern

上级 36ca71e1
......@@ -93,8 +93,19 @@ class PdbBreakpoint(Op):
output_storage[i][0] = monitored[i]
def grad(self, inputs, output_gradients):
return ([DisconnectedType()] + output_gradients)
return ([DisconnectedType()()] + output_gradients)
def infer_shape(self, inputs, input_shapes):
# Return the shape of every input but the condition (first input)
return input_shapes[1:]
def connection_pattern(self, node):
nb_inp = len(node.inputs)
nb_out = nb_inp - 1
# First input is connected to no output and every other input n is
# connected to input n-1
connections = [[out_idx == inp_idx - 1 for out_idx in range(nb_out)]
for inp_idx in range(nb_inp)]
return connections
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论