提交 29e35e59 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed bug where grad raised error for undefined grad not on path to wrt

上级 799f4b7d
......@@ -554,9 +554,6 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
for i in xrange(len(term_dict[node])):
if term_dict[node][i] is None:
term_dict[node][i] = tensor.zeros_like(node.inputs[i])
if isinstance(term_dict[node][i].type,NaNType):
raise TypeError("tensor.grad encountered a NaN. "+\
term_dict[node][i].type.why_nan)
return term_dict[node]
......@@ -576,7 +573,13 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
terms = []
for child in var._children.keys():
idx = var._children[child]
terms.append( access_term_cache(child)[idx])
term = access_term_cache(child)[idx]
if isinstance(term.type,NaNType):
raise TypeError("tensor.grad encountered a NaN. "+\
term.type.why_nan)
terms.append( term)
grad_dict[var] = nonempty_sum(terms)
if cost.name is not None and var.name is not None:
grad_dict[var].name = '(d%s/d%s)' % (cost.name, var.name)
......@@ -671,9 +674,6 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'):
for i in xrange(len(term_dict[node])):
if term_dict[node][i] is None:
term_dict[node][i] = tensor.zeros_like(node.inputs[i])
if isinstance(term_dict[node][i].type,NaNType):
raise TypeError("tensor.grad encountered a NaN. "+\
term_dict[node][i].type.why_nan)
return term_dict[node]
......@@ -693,7 +693,13 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'):
terms = []
for child in var._children.keys():
idx = var._children[child]
terms.append( access_term_cache(child)[idx])
term = access_term_cache(child)[idx]
if isinstance(term.type,NaNType):
raise TypeError("tensor.grad encountered a NaN. "+\
term.type.why_nan)
terms.append( term)
grad_dict[var] = nonempty_sum(terms)
else:
#this variable is not connected to the cost in the computational
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论