提交 345e4745 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

made tensor.grad automatically detect if an op is entirely disconnected

上级 5f78c193
...@@ -558,15 +558,28 @@ def _populate_grad_dict(var_to_node_to_idx,\ ...@@ -558,15 +558,28 @@ def _populate_grad_dict(var_to_node_to_idx,\
inputs = [try_to_copy(ipt) for ipt in inputs ] inputs = [try_to_copy(ipt) for ipt in inputs ]
output_grads = [ access_grad_cache(var) for var in node.outputs ] output_grads = [ access_grad_cache(var) for var in node.outputs ]
input_grads = node.op.grad(inputs, output_grads)
if input_grads is None: if False in [ isinstance(g.type, DisconnectedType)
raise TypeError("%s.grad returned NoneType, " for g in output_grads ]:
"expected iterable." % str(node.op)) #Some outputs of this op are connected to the cost so we must
#call the ops grad method
if len(input_grads) != len(inputs): input_grads = node.op.grad(inputs, output_grads)
raise ValueError(("%s returned the wrong number of gradient"+\
"terms.") % str(node.op)) if input_grads is None:
raise TypeError("%s.grad returned NoneType, "
"expected iterable." % str(node.op))
if len(input_grads) != len(inputs):
raise ValueError(("%s returned the wrong number of gradient"+\
"terms.") % str(node.op))
else:
#All outputs of this op are disconnected so we can skip
#Calling the op's grad method and report that the inputs
#are disconnected
#(The op's grad method could do this too, but this saves the
#implementer the trouble of worrying about this case)
input_grads = [ DisconnectedType()() for ipt in inputs ]
#must convert to list in case the op returns a tuple #must convert to list in case the op returns a tuple
#we won't be able to post-process out the Nones if it does that #we won't be able to post-process out the Nones if it does that
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论