提交 da41c9da authored 作者: Ian Goodfellow's avatar Ian Goodfellow

rearranged access_term_cache to have a section clearly devoted to type

checking
上级 ccdcc319
...@@ -716,11 +716,12 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -716,11 +716,12 @@ def _populate_grad_dict(var_to_node_to_idx,
# must convert to list in case the op returns a tuple # must convert to list in case the op returns a tuple
# we won't be able to post-process out the Nones if it does that # we won't be able to post-process out the Nones if it does that
term_dict[node] = list(input_grads) input_grads = list(input_grads)
for i in xrange(len(term_dict[node])): #Do type checking on the result
for i, term in enumerate(input_grads):
if term_dict[node][i] is None: if term is None:
# we don't know what None means. in the past it has been # we don't know what None means. in the past it has been
# used to # used to
# mean undefined, zero, or disconnected. So for now we # mean undefined, zero, or disconnected. So for now we
...@@ -730,10 +731,10 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -730,10 +731,10 @@ def _populate_grad_dict(var_to_node_to_idx,
# eventually we should disallow this # eventually we should disallow this
# return type and force all ops # return type and force all ops
# to return the correct thing # to return the correct thing
# raise AssertionError('%s returned None for' +\ #raise AssertionError(('%s returned None for' +\
# ' a gradient term, ' # ' a gradient term, '
# 'this is prohibited' % node.op) # 'this is prohibited') % node.op)
term_dict[node][i] = node.inputs[i].zeros_like() input_grads[i] = node.inputs[i].zeros_like()
if warn_type: if warn_type:
g_r_type = term_dict[node][i].type g_r_type = term_dict[node][i].type
...@@ -744,6 +745,9 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -744,6 +745,9 @@ def _populate_grad_dict(var_to_node_to_idx,
'for input %i of type (%s)', 'for input %i of type (%s)',
node.op, g_r_type, i, r_type) node.op, g_r_type, i, r_type)
#cache the result
term_dict[node] = list(input_grads)
return term_dict[node] return term_dict[node]
# populate grad_dict[var] and return it # populate grad_dict[var] and return it
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论