Commit 0103f893, authored by Ian Goodfellow

moved some code shared between grad and grad_sources_inputs to a third
helper function

Parent 3ea17d76
@@ -534,7 +534,51 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = False,
                     "'ignore', 'warn' and 'raise'.")
             grad_dict[elem] = elem.zeros_like()

-    #build a dict mapping node to the terms node contributes to each of its inputs' gradients
+    rval = _populate_grad_dict(var_to_node_to_idx,
+            grad_dict, wrt, warn_type,
+            cost.name)
+
+    if using_tuple:
+        rval = tuple(rval)
+    elif not using_list:
+        rval ,= rval
+
+    return rval
+
+
+def _populate_grad_dict(var_to_node_to_idx,\
+        grad_dict, wrt, warn_type, cost_name = None):
+    """
+    Common code shared between grad_sources_inputs and grad
+
+    var_to_node_to_idx: a dictionary mapping a variable to
+            a second dictionary.
+            the second dictionary maps apply nodes acting on
+            this variable to the variable's index in the apply
+            node's input list
+
+    grad_dict: a dictionary mapping variables to their gradients
+            should be populated by grad or grad_sources_inputs
+            grad should set gradients to zeros_like for
+            variables to be considered constant, set the
+            gradient for the cost variable to g_cost, etc.
+            both should set the gradient for disconnected
+            inputs to zeros_like
+
+    wrt: the minimal set of variables that must be included in grad_dict
+
+    warn_type: if True, log a warning when a gradient term for a variable
+            has a different type from that variable
+
+    cost_name: The name of the cost being differentiated, optional.
+            used to name the grad with respect to x as (d<cost_name>/dx)
+
+    returns: a list of gradients corresponding to wrt
+    """
+
+    #build a dict mapping node to the terms node contributes to each of
+    #its inputs' gradients
     term_dict = {}

     #populate term_dict[node] and return it
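To make the new helper's first argument concrete, here is a minimal, hypothetical sketch (plain Python stand-ins, not Theano's actual Apply/Variable types) of how a var_to_node_to_idx mapping of the shape described in the docstring could be built:

class FakeApply(object):
    """Hypothetical stand-in for a Theano Apply node; only .inputs matters here."""
    def __init__(self, name, inputs):
        self.name = name
        self.inputs = inputs

x, y = 'x', 'y'                       # stand-ins for graph variables
add_node = FakeApply('add', [x, y])   # consumes x at position 0, y at position 1
mul_node = FakeApply('mul', [y, x])   # consumes y at position 0, x at position 1

var_to_node_to_idx = {}
for node in (add_node, mul_node):
    for i, var in enumerate(node.inputs):
        # var -> {consuming node -> index of var in that node's input list}
        var_to_node_to_idx.setdefault(var, {})[node] = i

assert var_to_node_to_idx['x'][add_node] == 0
assert var_to_node_to_idx['x'][mul_node] == 1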
@@ -603,8 +647,8 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = False,
                     terms.append( term)

                 grad_dict[var] = nonempty_sum(terms)
-                if cost.name is not None and var.name is not None:
-                    grad_dict[var].name = '(d%s/d%s)' % (cost.name, var.name)
+                if cost_name is not None and var.name is not None:
+                    grad_dict[var].name = '(d%s/d%s)' % (cost_name, var.name)
             else:
                 #this variable is not connected to the cost in the computational
                 #graph so the gradient on it is zero
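A quick usage note on the renamed parameter (the values below are purely illustrative): when both the cost and the variable carry names, the summed gradient is given a readable derivative name.

# Illustrative values only: the naming scheme from the hunk above.
cost_name = 'loss'
var_name = 'W'
assert '(d%s/d%s)' % (cost_name, var_name) == '(dloss/dW)'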
@@ -614,10 +658,6 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = False,
     rval = [ access_grad_cache(elem) for elem in wrt ]

-    if using_tuple:
-        rval = tuple(rval)
-    elif not using_list:
-        rval ,= rval
-
     return rval
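The four lines removed here are the ones re-inserted near the top of the first hunk: they stay in grad itself, since only grad needs to normalize its return shape to match how wrt was passed. A rough sketch of that behavior, assuming using_list and using_tuple record whether wrt arrived as a list, a tuple, or a single variable (the helper name below is hypothetical):

def normalize_output(rval, using_list, using_tuple):
    # rval is the list of gradients computed for wrt; shown only to
    # illustrate the relocated lines, not part of the diff
    if using_tuple:
        return tuple(rval)
    elif not using_list:
        rval, = rval          # a single variable was requested: unpack it
        return rval
    return rval

assert normalize_output(['gx'], False, False) == 'gx'
assert normalize_output(['gx', 'gy'], True, False) == ['gx', 'gy']
assert normalize_output(['gx', 'gy'], False, True) == ('gx', 'gy')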
@@ -637,7 +677,6 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = True):
     wrt = graph_inputs

     #var_to_node_to_idx[var][node] = i means node has var as input at position i
     var_to_node_to_idx = {}

     #set of variables that have been added to their parents
@@ -678,72 +717,9 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = True):
                     "only by a non-differentiable operator: %s" % elem)
             grad_dict[elem] = elem.zeros_like()

-    #build a dict mapping node to the terms node contributes to each of its inputs' gradients
-    term_dict = {}
-
-    #populate term_dict[node] and return it
-    def access_term_cache(node):
-        if node not in term_dict:
-            inputs = node.inputs
-
-            input_grads = node.op.grad(node.inputs,
-                    [access_grad_cache(var) for var in node.outputs])
-
-            if input_grads is None:
-                raise TypeError("%s.grad returned NoneType, "
-                        "expected iterable." % str(node.op))
-
-            if len(input_grads) != len(inputs):
-                raise ValueError(("%s returned the wrong number of gradient"+\
-                        "terms.") % str(node.op))
-
-            #must convert to list in case the op returns a tuple
-            #we won't be able to post-process out the Nones if it does that
-            term_dict[node] = list(input_grads)
-
-            for i in xrange(len(term_dict[node])):
-                if term_dict[node][i] is None:
-                    term_dict[node][i] = node.inputs[i].zeros_like()
-
-        return term_dict[node]
-
-    #built-in python sum adds an extraneous TensorConstant(0)
-    #we can exploit the knowledge that iterable always has at
-    #least one element to avoid starting the sum at 0
-    def nonempty_sum( iterable ):
-        rval = iterable[0]
-        for elem in iterable[1:]:
-            rval = rval + elem
-        return rval
-
-    #populate grad_dict[var] and return it
-    def access_grad_cache(var):
-        if var not in grad_dict:
-            if var in var_to_node_to_idx:
-                terms = []
-
-                node_to_idx = var_to_node_to_idx[var]
-
-                for node in node_to_idx:
-                    idx = node_to_idx[node]
-
-                    term = access_term_cache(node)[idx]
-
-                    if not isinstance(term, gof.Variable):
-                        raise TypeError("%s.grad returned %s, expected"
-                                " Variable instance." % (str(node.op),
-                                type(term)))
-
-                    if isinstance(term.type, NaNType):
-                        raise TypeError("tensor.grad encountered a NaN. "+\
-                                term.type.why_nan)
-
-                    terms.append( term)
-
-                grad_dict[var] = nonempty_sum(terms)
-            else:
-                #this variable is not connected to the cost in the computational
-                #graph so the gradient on it is zero
-                grad_dict[var] = var.zeros_like()
-
-        return grad_dict[var]
-
-    rval = [ access_grad_cache(elem) for elem in wrt ]
+    _populate_grad_dict(var_to_node_to_idx,
+            grad_dict, wrt, warn_type)

     return grad_dict
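For orientation, a stripped-down sketch of what the code removed above (now living in _populate_grad_dict) does: two mutually recursive, memoized lookups walk the graph backwards, and nonempty_sum accumulates per-node contributions without introducing the extraneous constant 0 that the built-in sum would. This is a simplification over plain Python callables; names like grad_fn are placeholders, not Theano API, and the type/NaN checks are omitted.

def nonempty_sum(terms):
    # built-in sum() starts from the integer 0, which in Theano would add an
    # extraneous TensorConstant(0) to the graph; start from the first term instead
    rval = terms[0]
    for elem in terms[1:]:
        rval = rval + elem
    return rval

def populate_grad_dict(var_to_node_to_idx, grad_dict, wrt):
    # Simplified stand-in for _populate_grad_dict; node.grad_fn plays the role
    # of node.op.grad here.
    term_dict = {}

    def access_term_cache(node):
        # memoize the list of gradient terms this node contributes to its inputs
        if node not in term_dict:
            output_grads = [access_grad_cache(v) for v in node.outputs]
            term_dict[node] = list(node.grad_fn(node.inputs, output_grads))
        return term_dict[node]

    def access_grad_cache(var):
        # memoize the total gradient on var: the sum of the terms contributed
        # by every node that consumes var
        if var not in grad_dict:
            if var in var_to_node_to_idx:
                terms = [access_term_cache(node)[idx]
                         for node, idx in var_to_node_to_idx[var].items()]
                grad_dict[var] = nonempty_sum(terms)
            else:
                # disconnected from the cost: gradient is zero
                grad_dict[var] = 0
        return grad_dict[var]

    return [access_grad_cache(elem) for elem in wrt]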