提交 077bc706 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

made gradient computation raise an explicit ValueError instead of an

IndexError if it gets the wrong number of gradient terms from an Op
上级 166bd8ba
......@@ -25,8 +25,6 @@ from theano.gof.nan_type import NaNType
from theano.printing import min_informative_str
_msg_retType = 'op.grad(...) returned a non-list'
_msg_badlen = 'op.grad(...) returned wrong number of gradients'
def format_as(use_list, use_tuple, outputs):
"""
......@@ -547,9 +545,14 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
output_grads = [ access_grad_cache(var) for var in node.outputs ]
input_grads = node.op.grad(inputs, output_grads)
if len(input_grads) != len(inputs):
raise ValueError(("%s returned the wrong number of gradient"+\
"terms.") % str(node.op))
#must convert to list in case the op returns a tuple
#we won't be able to post-process out the Nones if it does that
term_dict[node] = list(input_grads)
for i in xrange(len(term_dict[node])):
if term_dict[node][i] is None:
term_dict[node][i] = node.inputs[i].zeros_like()
......@@ -661,10 +664,18 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'):
#populate term_dict[node] and return it
def access_term_cache(node):
if node not in term_dict:
inputs = node.inputs
input_grads = node.op.grad(node.inputs,
[access_grad_cache(var) for var in node.outputs])
if len(input_grads) != len(inputs):
raise ValueError(("%s returned the wrong number of gradient"+\
"terms.") % str(node.op))
#must convert to list in case the op returns a tuple
#we won't be able to post-process out the Nones if it does that
term_dict[node] = list(node.op.grad(node.inputs,
[access_grad_cache(var) for var in node.outputs]))
term_dict[node] = list(input_grads)
for i in xrange(len(term_dict[node])):
if term_dict[node][i] is None:
term_dict[node][i] = node.inputs[i].zeros_like()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论