提交 9bbd74d7 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

got rid of non-empty sum, replaced with reduce(lambda,terms)

上级 78ac9eb5
...@@ -652,15 +652,6 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -652,15 +652,6 @@ def _populate_grad_dict(var_to_node_to_idx,
return term_dict[node] return term_dict[node]
# built-in python sum adds an extraneous TensorConstant(0)
# we can exploit the knowledge that iterable always has at
# least one element to avoid starting the sum at 0
def nonempty_sum(iterable):
rval = iterable[0]
for elem in iterable[1:]:
rval = rval + elem
return rval
# populate grad_dict[var] and return it # populate grad_dict[var] and return it
def access_grad_cache(var): def access_grad_cache(var):
if var not in grad_dict: if var not in grad_dict:
...@@ -688,7 +679,9 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -688,7 +679,9 @@ def _populate_grad_dict(var_to_node_to_idx,
term.type.why_null) term.type.why_null)
terms.append(term) terms.append(term)
grad_dict[var] = nonempty_sum(terms) #the next line is like sum(terms) but doesn't add an
#extraneous TensorConstant(0)
grad_dict[var] = reduce(lambda x,y: x+y, terms)
if cost_name is not None and var.name is not None: if cost_name is not None and var.name is not None:
grad_dict[var].name = '(d%s/d%s)' % (cost_name, var.name) grad_dict[var].name = '(d%s/d%s)' % (cost_name, var.name)
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论