提交 3a857acb authored 作者: Ian Goodfellow's avatar Ian Goodfellow

changed from marking variables directly to making a dictionary mapping

variables to marks
上级 fb9cb2f3
......@@ -481,14 +481,11 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
if not using_list and not using_tuple:
wrt = [ wrt ]
#set of variables that has had children added to it
marked = set([])
#var_to_node_to_idx[var][node] = i means node has var as input at position i
var_to_node_to_idx = {}
#set of variables that have been added to their parents
accounted_for = set([])
#use a try/finally to make sure we don't leave any marks
#on the variables
try:
#mark the variables in the relevant subgraph with
#a dictionary called chidlren
#var._children[node] gives the index of var in _children.inputs
......@@ -499,11 +496,9 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
if var.owner is not None:
node = var.owner
for i, ipt in enumerate(node.inputs):
if not hasattr(ipt, '_children'):
marked.add(ipt)
ipt._children = {}
if node not in ipt._children:
ipt._children[node] = i
if ipt not in var_to_node_to_idx:
var_to_node_to_idx[ipt] = {}
var_to_node_to_idx[ipt][node] = i
account_for(ipt)
account_for(cost)
......@@ -520,11 +515,11 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
grad_dict[const] = tensor.zeros_like(const)
#variables that do not influence the cost have zero gradient.
#if wrt is such a varibale, populate the grad_dict with this info
#so that wrt not having _children won't cause an error below
#if wrt is such a variable, populate the grad_dict with this info
#so that wrt not being in var_to_node_to_idx won't cause an error below
#according to the flag, possibly raise an error if wrt is disconnected
for elem in wrt:
if elem not in marked and elem is not cost:
if elem not in var_to_node_to_idx and elem is not cost:
message = ("grad method was asked to compute the gradient "
"with respect to a variable that is not part of "
"the computational graph of the cost, or is used "
......@@ -547,10 +542,14 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
#populate term_dict[node] and return it
def access_term_cache(node):
if node not in term_dict:
inputs = node.inputs
output_grads = [ access_grad_cache(var) for var in node.outputs ]
input_grads = node.op.grad(inputs, output_grads)
#must convert to list in case the op returns a tuple
#we won't be able to post-process out the Nones if it does that
term_dict[node] = list(node.op.grad(node.inputs,
[access_grad_cache(var) for var in node.outputs]))
term_dict[node] = list(input_grads)
for i in xrange(len(term_dict[node])):
if term_dict[node][i] is None:
term_dict[node][i] = tensor.zeros_like(node.inputs[i])
......@@ -569,11 +568,12 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
#populate grad_dict[var] and return it
def access_grad_cache(var):
if var not in grad_dict:
if hasattr(var,'_children'):
if var in var_to_node_to_idx:
terms = []
for child in var._children.keys():
idx = var._children[child]
term = access_term_cache(child)[idx]
node_to_idx = var_to_node_to_idx[var]
for node in node_to_idx:
idx = node_to_idx[node]
term = access_term_cache(node)[idx]
if isinstance(term.type,NaNType):
raise TypeError("tensor.grad encountered a NaN. "+\
......@@ -591,10 +591,6 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
rval = [ access_grad_cache(elem) for elem in wrt ]
finally:
#take the marks out
for node in marked:
del node._children
if using_tuple:
rval = tuple(rval)
......@@ -614,17 +610,13 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'):
wrt = graph_inputs
#set of variables that has had children added to it
marked = set([])
#var_to_node_to_idx[var][node] = i means node has var as input at position i
var_to_node_to_idx = {}
#set of variables that have been added to their parents
accounted_for = set([])
#use a try/finally to make sure we don't leave any marks
#on the variables
try:
#mark the variables in the relevant subgraph with
#a dictionary called chidlren
#var._children[node] gives the index of var in _children.inputs
#notify the parents the variables in the relevant subgraph
#that they have children
def account_for(var):
if var in accounted_for:
return
......@@ -632,11 +624,9 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'):
if var.owner is not None:
node = var.owner
for i, ipt in enumerate(node.inputs):
if not hasattr(ipt, '_children'):
marked.add(ipt)
ipt._children = {}
if node not in ipt._children:
ipt._children[node] = i
if ipt not in var_to_node_to_idx:
var_to_node_to_idx[ipt] = {}
var_to_node_to_idx[ipt][node] = i
account_for(ipt)
for output in outputs:
......@@ -649,16 +639,15 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'):
grad_dict[output] = output_grad
#variables that do not influence the cost have zero gradient.
#if wrt is such a varibale, populate the grad_dict with this info
#so that wrt not having _children won't cause an error below
#if wrt is such a variable, populate the grad_dict with this info
#so that wrt not being in var_to_node_to_idx won't cause an error below
#according to the flag, possibly raise an error if wrt is disconnected
for elem in wrt:
if elem not in marked and elem not in outputs:
if elem not in var_to_node_to_idx and elem not in outputs:
message = ("grad method was asked to compute the gradient "
"with respect to a variable that is not part of "
"the computational graph of the cost, or is used "
"only by a non-differentiable operator: %s" % elem)
#raise ValueError(message)
grad_dict[elem] = elem.zeros_like()
#build a dict mapping node to the terms node contributes to each of its inputs' gradients
......@@ -689,11 +678,12 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'):
#populate grad_dict[var] and return it
def access_grad_cache(var):
if var not in grad_dict:
if hasattr(var,'_children'):
if var in var_to_node_to_idx:
terms = []
for child in var._children.keys():
idx = var._children[child]
term = access_term_cache(child)[idx]
node_to_idx = var_to_node_to_idx[var]
for node in node_to_idx:
idx = node_to_idx[node]
term = access_term_cache(node)[idx]
if isinstance(term.type,NaNType):
raise TypeError("tensor.grad encountered a NaN. "+\
......@@ -709,10 +699,6 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'):
rval = [ access_grad_cache(elem) for elem in wrt ]
finally:
#take the marks out
for node in marked:
del node._children
return grad_dict
......@@ -1149,6 +1135,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
return plain
t_r = shared(random_projection())
t_r.name = 'random_projection'
# random projection of o onto t_r
# This sum() is defined above, it's not the builtin sum.
......@@ -1178,6 +1165,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
num_grad.max_err(analytic_grad, abs_tol, rel_tol)
if max_abs_err > abs_tol and max_rel_err > rel_tol:
raise verify_grad.E_grad(max_arg, max_err_pos,
max_abs_err, max_rel_err, abs_tol, rel_tol)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论