提交 3a857acb authored 作者: Ian Goodfellow's avatar Ian Goodfellow

changed from marking variables directly to making a dictionary mapping

variables to marks
上级 fb9cb2f3
...@@ -481,120 +481,116 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore ...@@ -481,120 +481,116 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
if not using_list and not using_tuple: if not using_list and not using_tuple:
wrt = [ wrt ] wrt = [ wrt ]
#set of variables that has had children added to it #var_to_node_to_idx[var][node] = i means node has var as input at position i
marked = set([]) var_to_node_to_idx = {}
#set of variables that have been added to their parents #set of variables that have been added to their parents
accounted_for = set([]) accounted_for = set([])
#use a try/finally to make sure we don't leave any marks #mark the variables in the relevant subgraph with
#on the variables #a dictionary called chidlren
try: #var._children[node] gives the index of var in _children.inputs
#mark the variables in the relevant subgraph with def account_for(var):
#a dictionary called chidlren if var in accounted_for:
#var._children[node] gives the index of var in _children.inputs return
def account_for(var): accounted_for.add(var)
if var in accounted_for: if var.owner is not None:
return node = var.owner
accounted_for.add(var) for i, ipt in enumerate(node.inputs):
if var.owner is not None: if ipt not in var_to_node_to_idx:
node = var.owner var_to_node_to_idx[ipt] = {}
for i, ipt in enumerate(node.inputs): var_to_node_to_idx[ipt][node] = i
if not hasattr(ipt, '_children'): account_for(ipt)
marked.add(ipt)
ipt._children = {} account_for(cost)
if node not in ipt._children:
ipt._children[node] = i #build a dict mapping var to the gradient of cost with respect to var
account_for(ipt) grad_dict = {}
#by default, the gradient of the cost is 1
account_for(cost) if g_cost is None:
g_cost = tensor.ones_like(cost)
#build a dict mapping var to the gradient of cost with respect to var grad_dict[cost] = g_cost
grad_dict = {}
#by default, the gradient of the cost is 1 #the gradient of the constants is 0
if g_cost is None: for const in consider_constant:
g_cost = tensor.ones_like(cost) grad_dict[const] = tensor.zeros_like(const)
grad_dict[cost] = g_cost
#variables that do not influence the cost have zero gradient.
#the gradient of the constants is 0 #if wrt is such a variable, populate the grad_dict with this info
for const in consider_constant: #so that wrt not being in var_to_node_to_idx won't cause an error below
grad_dict[const] = tensor.zeros_like(const) #according to the flag, possibly raise an error if wrt is disconnected
for elem in wrt:
#variables that do not influence the cost have zero gradient. if elem not in var_to_node_to_idx and elem is not cost:
#if wrt is such a varibale, populate the grad_dict with this info message = ("grad method was asked to compute the gradient "
#so that wrt not having _children won't cause an error below "with respect to a variable that is not part of "
#according to the flag, possibly raise an error if wrt is disconnected "the computational graph of the cost, or is used "
for elem in wrt: "only by a non-differentiable operator: %s" % elem)
if elem not in marked and elem is not cost: if disconnected_inputs == 'ignore':
message = ("grad method was asked to compute the gradient " pass
"with respect to a variable that is not part of " elif disconnected_inputs == 'warn':
"the computational graph of the cost, or is used " warnings.warn(message, stacklevel=1)
"only by a non-differentiable operator: %s" % elem) elif disconnected_inputs == 'raise':
if disconnected_inputs == 'ignore': raise ValueError(message)
pass else:
elif disconnected_inputs == 'warn': raise ValueError("Invalid value for keyword "
warnings.warn(message, stacklevel=1) "'disconnected_inputs', valid values are "
elif disconnected_inputs == 'raise': "'ignore', 'warn' and 'raise'.")
raise ValueError(message) grad_dict[elem] = elem.zeros_like()
else:
raise ValueError("Invalid value for keyword "
"'disconnected_inputs', valid values are "
"'ignore', 'warn' and 'raise'.")
grad_dict[elem] = elem.zeros_like()
#build a dict mapping node to the terms node contributes to each of its inputs' gradients
term_dict = {}
#populate term_dict[node] and return it
def access_term_cache(node):
if node not in term_dict:
#must convert to list in case the op returns a tuple
#we won't be able to post-process out the Nones if it does that
term_dict[node] = list(node.op.grad(node.inputs,
[access_grad_cache(var) for var in node.outputs]))
for i in xrange(len(term_dict[node])):
if term_dict[node][i] is None:
term_dict[node][i] = tensor.zeros_like(node.inputs[i])
return term_dict[node]
#built-in python sum adds an extraneous TensorConstant(0)
#we can exploit the knowledge that iterable always has at
#least one element to avoid starting the sum at 0
def nonempty_sum( iterable ):
rval = iterable[0]
for elem in iterable[1:]:
rval = rval + elem
return rval
#populate grad_dict[var] and return it #build a dict mapping node to the terms node contributes to each of its inputs' gradients
def access_grad_cache(var): term_dict = {}
if var not in grad_dict:
if hasattr(var,'_children'): #populate term_dict[node] and return it
terms = [] def access_term_cache(node):
for child in var._children.keys(): if node not in term_dict:
idx = var._children[child]
term = access_term_cache(child)[idx]
if isinstance(term.type,NaNType):
raise TypeError("tensor.grad encountered a NaN. "+\
term.type.why_nan)
terms.append( term)
grad_dict[var] = nonempty_sum(terms)
if cost.name is not None and var.name is not None:
grad_dict[var].name = '(d%s/d%s)' % (cost.name, var.name)
else:
#this variable is not connected to the cost in the computational
#graph so the gradient on it is zero
grad_dict[var] = tensor.zeros_like(var)
return grad_dict[var]
inputs = node.inputs
output_grads = [ access_grad_cache(var) for var in node.outputs ]
input_grads = node.op.grad(inputs, output_grads)
#must convert to list in case the op returns a tuple
#we won't be able to post-process out the Nones if it does that
term_dict[node] = list(input_grads)
for i in xrange(len(term_dict[node])):
if term_dict[node][i] is None:
term_dict[node][i] = tensor.zeros_like(node.inputs[i])
return term_dict[node]
#built-in python sum adds an extraneous TensorConstant(0)
#we can exploit the knowledge that iterable always has at
#least one element to avoid starting the sum at 0
def nonempty_sum( iterable ):
rval = iterable[0]
for elem in iterable[1:]:
rval = rval + elem
return rval
#populate grad_dict[var] and return it
def access_grad_cache(var):
if var not in grad_dict:
if var in var_to_node_to_idx:
terms = []
node_to_idx = var_to_node_to_idx[var]
for node in node_to_idx:
idx = node_to_idx[node]
term = access_term_cache(node)[idx]
if isinstance(term.type,NaNType):
raise TypeError("tensor.grad encountered a NaN. "+\
term.type.why_nan)
terms.append( term)
grad_dict[var] = nonempty_sum(terms)
if cost.name is not None and var.name is not None:
grad_dict[var].name = '(d%s/d%s)' % (cost.name, var.name)
else:
#this variable is not connected to the cost in the computational
#graph so the gradient on it is zero
grad_dict[var] = tensor.zeros_like(var)
return grad_dict[var]
rval = [ access_grad_cache(elem) for elem in wrt ]
finally: rval = [ access_grad_cache(elem) for elem in wrt ]
#take the marks out
for node in marked:
del node._children
if using_tuple: if using_tuple:
rval = tuple(rval) rval = tuple(rval)
...@@ -614,105 +610,95 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'): ...@@ -614,105 +610,95 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'):
wrt = graph_inputs wrt = graph_inputs
#set of variables that has had children added to it #var_to_node_to_idx[var][node] = i means node has var as input at position i
marked = set([]) var_to_node_to_idx = {}
#set of variables that have been added to their parents #set of variables that have been added to their parents
accounted_for = set([]) accounted_for = set([])
#use a try/finally to make sure we don't leave any marks #notify the parents the variables in the relevant subgraph
#on the variables #that they have children
try: def account_for(var):
#mark the variables in the relevant subgraph with if var in accounted_for:
#a dictionary called chidlren return
#var._children[node] gives the index of var in _children.inputs accounted_for.add(var)
def account_for(var): if var.owner is not None:
if var in accounted_for: node = var.owner
return for i, ipt in enumerate(node.inputs):
accounted_for.add(var) if ipt not in var_to_node_to_idx:
if var.owner is not None: var_to_node_to_idx[ipt] = {}
node = var.owner var_to_node_to_idx[ipt][node] = i
for i, ipt in enumerate(node.inputs): account_for(ipt)
if not hasattr(ipt, '_children'):
marked.add(ipt) for output in outputs:
ipt._children = {} account_for(output)
if node not in ipt._children:
ipt._children[node] = i #build a dict mapping var to the gradient of cost with respect to var
account_for(ipt) grad_dict = {}
#by default, the gradient of the cost is 1
for output in outputs: for output, output_grad in sources:
account_for(output) grad_dict[output] = output_grad
#build a dict mapping var to the gradient of cost with respect to var #variables that do not influence the cost have zero gradient.
grad_dict = {} #if wrt is such a variable, populate the grad_dict with this info
#by default, the gradient of the cost is 1 #so that wrt not being in var_to_node_to_idx won't cause an error below
for output, output_grad in sources: #according to the flag, possibly raise an error if wrt is disconnected
grad_dict[output] = output_grad for elem in wrt:
if elem not in var_to_node_to_idx and elem not in outputs:
#variables that do not influence the cost have zero gradient. message = ("grad method was asked to compute the gradient "
#if wrt is such a varibale, populate the grad_dict with this info "with respect to a variable that is not part of "
#so that wrt not having _children won't cause an error below "the computational graph of the cost, or is used "
#according to the flag, possibly raise an error if wrt is disconnected "only by a non-differentiable operator: %s" % elem)
for elem in wrt: grad_dict[elem] = elem.zeros_like()
if elem not in marked and elem not in outputs:
message = ("grad method was asked to compute the gradient " #build a dict mapping node to the terms node contributes to each of its inputs' gradients
"with respect to a variable that is not part of " term_dict = {}
"the computational graph of the cost, or is used "
"only by a non-differentiable operator: %s" % elem) #populate term_dict[node] and return it
#raise ValueError(message) def access_term_cache(node):
grad_dict[elem] = elem.zeros_like() if node not in term_dict:
#must convert to list in case the op returns a tuple
#build a dict mapping node to the terms node contributes to each of its inputs' gradients #we won't be able to post-process out the Nones if it does that
term_dict = {} term_dict[node] = list(node.op.grad(node.inputs,
[access_grad_cache(var) for var in node.outputs]))
#populate term_dict[node] and return it for i in xrange(len(term_dict[node])):
def access_term_cache(node): if term_dict[node][i] is None:
if node not in term_dict: term_dict[node][i] = tensor.zeros_like(node.inputs[i])
#must convert to list in case the op returns a tuple return term_dict[node]
#we won't be able to post-process out the Nones if it does that
term_dict[node] = list(node.op.grad(node.inputs,
[access_grad_cache(var) for var in node.outputs])) #built-in python sum adds an extraneous TensorConstant(0)
for i in xrange(len(term_dict[node])): #we can exploit the knowledge that iterable always has at
if term_dict[node][i] is None: #least one element to avoid starting the sum at 0
term_dict[node][i] = tensor.zeros_like(node.inputs[i]) def nonempty_sum( iterable ):
return term_dict[node] rval = iterable[0]
for elem in iterable[1:]:
rval = rval + elem
#built-in python sum adds an extraneous TensorConstant(0) return rval
#we can exploit the knowledge that iterable always has at
#least one element to avoid starting the sum at 0 #populate grad_dict[var] and return it
def nonempty_sum( iterable ): def access_grad_cache(var):
rval = iterable[0] if var not in grad_dict:
for elem in iterable[1:]: if var in var_to_node_to_idx:
rval = rval + elem terms = []
return rval node_to_idx = var_to_node_to_idx[var]
for node in node_to_idx:
#populate grad_dict[var] and return it idx = node_to_idx[node]
def access_grad_cache(var): term = access_term_cache(node)[idx]
if var not in grad_dict:
if hasattr(var,'_children'): if isinstance(term.type,NaNType):
terms = [] raise TypeError("tensor.grad encountered a NaN. "+\
for child in var._children.keys(): term.type.why_nan)
idx = var._children[child]
term = access_term_cache(child)[idx] terms.append( term)
grad_dict[var] = nonempty_sum(terms)
if isinstance(term.type,NaNType): else:
raise TypeError("tensor.grad encountered a NaN. "+\ #this variable is not connected to the cost in the computational
term.type.why_nan) #graph so the gradient on it is zero
grad_dict[var] = tensor.zeros_like(var)
terms.append( term) return grad_dict[var]
grad_dict[var] = nonempty_sum(terms)
else:
#this variable is not connected to the cost in the computational
#graph so the gradient on it is zero
grad_dict[var] = tensor.zeros_like(var)
return grad_dict[var]
rval = [ access_grad_cache(elem) for elem in wrt ] rval = [ access_grad_cache(elem) for elem in wrt ]
finally:
#take the marks out
for node in marked:
del node._children
return grad_dict return grad_dict
...@@ -1149,6 +1135,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, ...@@ -1149,6 +1135,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
return plain return plain
t_r = shared(random_projection()) t_r = shared(random_projection())
t_r.name = 'random_projection'
# random projection of o onto t_r # random projection of o onto t_r
# This sum() is defined above, it's not the builtin sum. # This sum() is defined above, it's not the builtin sum.
...@@ -1178,6 +1165,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, ...@@ -1178,6 +1165,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
num_grad.max_err(analytic_grad, abs_tol, rel_tol) num_grad.max_err(analytic_grad, abs_tol, rel_tol)
if max_abs_err > abs_tol and max_rel_err > rel_tol: if max_abs_err > abs_tol and max_rel_err > rel_tol:
raise verify_grad.E_grad(max_arg, max_err_pos, raise verify_grad.E_grad(max_arg, max_err_pos,
max_abs_err, max_rel_err, abs_tol, rel_tol) max_abs_err, max_rel_err, abs_tol, rel_tol)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论