提交 077bc706 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

Made gradient computation raise an explicit ValueError instead of an IndexError if it gets the wrong number of gradient terms from an Op.
上级 166bd8ba
...@@ -25,8 +25,6 @@ from theano.gof.nan_type import NaNType ...@@ -25,8 +25,6 @@ from theano.gof.nan_type import NaNType
from theano.printing import min_informative_str from theano.printing import min_informative_str
_msg_retType = 'op.grad(...) returned a non-list' _msg_retType = 'op.grad(...) returned a non-list'
_msg_badlen = 'op.grad(...) returned wrong number of gradients'
def format_as(use_list, use_tuple, outputs): def format_as(use_list, use_tuple, outputs):
""" """
...@@ -547,9 +545,14 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore ...@@ -547,9 +545,14 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
output_grads = [ access_grad_cache(var) for var in node.outputs ] output_grads = [ access_grad_cache(var) for var in node.outputs ]
input_grads = node.op.grad(inputs, output_grads) input_grads = node.op.grad(inputs, output_grads)
if len(input_grads) != len(inputs):
raise ValueError(("%s returned the wrong number of gradient"+\
"terms.") % str(node.op))
#must convert to list in case the op returns a tuple #must convert to list in case the op returns a tuple
#we won't be able to post-process out the Nones if it does that #we won't be able to post-process out the Nones if it does that
term_dict[node] = list(input_grads) term_dict[node] = list(input_grads)
for i in xrange(len(term_dict[node])): for i in xrange(len(term_dict[node])):
if term_dict[node][i] is None: if term_dict[node][i] is None:
term_dict[node][i] = node.inputs[i].zeros_like() term_dict[node][i] = node.inputs[i].zeros_like()
...@@ -661,10 +664,18 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'): ...@@ -661,10 +664,18 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'):
#populate term_dict[node] and return it #populate term_dict[node] and return it
def access_term_cache(node): def access_term_cache(node):
if node not in term_dict: if node not in term_dict:
inputs = node.inputs
input_grads = node.op.grad(node.inputs,
[access_grad_cache(var) for var in node.outputs])
if len(input_grads) != len(inputs):
raise ValueError(("%s returned the wrong number of gradient"+\
"terms.") % str(node.op))
#must convert to list in case the op returns a tuple #must convert to list in case the op returns a tuple
#we won't be able to post-process out the Nones if it does that #we won't be able to post-process out the Nones if it does that
term_dict[node] = list(node.op.grad(node.inputs, term_dict[node] = list(input_grads)
[access_grad_cache(var) for var in node.outputs]))
for i in xrange(len(term_dict[node])): for i in xrange(len(term_dict[node])):
if term_dict[node][i] is None: if term_dict[node][i] is None:
term_dict[node][i] = node.inputs[i].zeros_like() term_dict[node][i] = node.inputs[i].zeros_like()
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论