提交 5f78c193 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

made grad_sources_inputs postprocess out the DisconnectedTypes

上级 936d499a
...@@ -738,6 +738,12 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = True): ...@@ -738,6 +738,12 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = True):
_populate_grad_dict(var_to_node_to_idx, _populate_grad_dict(var_to_node_to_idx,
grad_dict, wrt, warn_type) grad_dict, wrt, warn_type)
#post-process out the DisconnectedTypes
for key in grad_dict:
if isinstance(grad_dict[key].type,DisconnectedType):
if hasattr(key,'zeros_like'):
grad_dict[key] = key.zeros_like()
return grad_dict return grad_dict
class numeric_grad(object): class numeric_grad(object):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论