提交 fa4617bf authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed bug where DisconnectedInputs passed in by known_grads may not

raise an error (the correct value was still computed though)
上级 7e2f4871
...@@ -482,18 +482,11 @@ def grad(cost, wrt, consider_constant=None, ...@@ -482,18 +482,11 @@ def grad(cost, wrt, consider_constant=None,
grad_dict[var] = g_var grad_dict[var] = g_var
def handle_disconnected(var):
# variables that do not influence the cost have zero gradient.
# if wrt is such a variable, populate the grad_dict with this info
# so that wrt not being in var_to_node_to_idx won't cause an error below
# according to the flag, possibly raise an error if wrt is disconnected
for elem in wrt:
if elem not in var_to_node_to_idx and elem is not cost \
and elem not in grad_dict:
message = ("grad method was asked to compute the gradient " message = ("grad method was asked to compute the gradient "
"with respect to a variable that is not part of " "with respect to a variable that is not part of "
"the computational graph of the cost, or is used " "the computational graph of the cost, or is used "
"only by a non-differentiable operator: %s" % elem) "only by a non-differentiable operator: %s" % var)
if disconnected_inputs == 'ignore': if disconnected_inputs == 'ignore':
pass pass
elif disconnected_inputs == 'warn': elif disconnected_inputs == 'warn':
...@@ -504,6 +497,16 @@ def grad(cost, wrt, consider_constant=None, ...@@ -504,6 +497,16 @@ def grad(cost, wrt, consider_constant=None,
raise ValueError("Invalid value for keyword " raise ValueError("Invalid value for keyword "
"'disconnected_inputs', valid values are " "'disconnected_inputs', valid values are "
"'ignore', 'warn' and 'raise'.") "'ignore', 'warn' and 'raise'.")
# variables that do not influence the cost have zero gradient.
# if wrt is such a variable, populate the grad_dict with this info
# so that wrt not being in var_to_node_to_idx won't cause an error below
# according to the flag, possibly raise an error if wrt is disconnected
for elem in wrt:
if elem not in var_to_node_to_idx and elem is not cost \
and elem not in grad_dict:
handle_disconnected(elem)
grad_dict[elem] = DisconnectedType()() grad_dict[elem] = DisconnectedType()()
cost_name = None cost_name = None
...@@ -523,6 +526,7 @@ def grad(cost, wrt, consider_constant=None, ...@@ -523,6 +526,7 @@ def grad(cost, wrt, consider_constant=None,
for i in xrange(len(rval)): for i in xrange(len(rval)):
if isinstance(rval[i].type, DisconnectedType): if isinstance(rval[i].type, DisconnectedType):
handle_disconnected(rval[i])
if return_disconnected == 'zero': if return_disconnected == 'zero':
rval[i] = _float_zeros_like(wrt[i]) rval[i] = _float_zeros_like(wrt[i])
elif return_disconnected == 'None': elif return_disconnected == 'None':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论