提交 09a3c807 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added check that Op.grad and Op.connection_pattern agree

上级 34430def
......@@ -734,7 +734,38 @@ def _populate_grad_dict(var_to_node_to_idx,
# we won't be able to post-process out the Nones if it does that
input_grads = list(input_grads)
#Do type checking on the result
# Do type checking on the result
#Check that op.connection_pattern matches the connectivity
#logic driving the op.grad method
for i, packed in \
enumerate(zip(inputs, input_grads, inputs_connected)):
ipt, ig, connected = packed
actually_connected = \
not isinstance(ig.type, DisconnectedType)
if actually_connected and not connected:
msg = "%s.grad returned %s of type %s for input %d."
msg += " Expected DisconnectedType instance based on "
msg += " the output of the op's connection_pattern "
msg += "method."
msg = msg % (str(node.op), str(ig), str(ig.type), i)
raise TypeError(msg)
if connected and not actually_connected:
msg = "%s.grad returned DisconnectedType for input"
msg += " %d."
msg = msg % (str(node.op), i)
if hasattr(node.op,'connection_pattern'):
msg += ' Its connection_pattern method does not'
msg += ' allow this.'
raise TypeError(msg)
else:
msg += ' You may want to implement a '
msg += ' connection_pattern method for it.'
warnings.warn(msg)
# Process out any Nones
for i, term in enumerate(input_grads):
if term is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论