Commit ba4ba5ef authored by Ian Goodfellow

fix pep8 for gradient.py. (pep8 time 20 mins)

Parent 1cb68ffc
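Most of the changes below apply PEP 8 whitespace rules: no spaces around "=" for keyword arguments and default parameter values (E251), no whitespace just inside brackets (E201/E202), a space after commas (E231), plus wrapping of over-long lines (E501). A minimal before/after sketch with a hypothetical function, not taken from gradient.py:

    # flagged by the pep8 checker (e.g. E251, E201, E202, E231):
    def scale(x, factor = 2.0):
        return [ x + factor,x ]

    # PEP 8 compliant, matching the style adopted in this commit:
    def scale(x, factor=2.0):
        return [x + factor, x]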
@@ -355,8 +355,8 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
 # Gradient
 #########################
-def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = False,
-        disconnected_inputs = 'raise', add_names = True):
+def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
+        disconnected_inputs='raise', add_names=True):
     """
     :type cost: Scalar (0-dimensional) Variable.
     :type wrt: Variable or list of Variables.
@@ -398,9 +398,9 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = False,
     if cost.ndim != 0:
         raise TypeError("cost must be a scalar.")
     if isinstance(cost.type, NullType):
-        raise ValueError("Can't differentiate a NaN cost. cost is NaN because "+\
+        raise ValueError("Can't differentiate a NaN cost."
+                "cost is NaN because " + \
                 cost.type.why_null)
     if consider_constant is None:
@@ -419,11 +419,10 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = False,
             raise TypeError('Elements of consider_constant must be '
                             'variables, but got ' + str(type(elem)))
-    using_list = isinstance(wrt,list)
-    using_tuple = isinstance(wrt,tuple)
+    using_list = isinstance(wrt, list)
+    using_tuple = isinstance(wrt, tuple)
     if not using_list and not using_tuple:
-        wrt = [ wrt ]
+        wrt = [wrt]
     var_to_node_to_idx = _populate_var_to_node_to_idx([cost])
@@ -475,7 +474,7 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = False,
     if using_tuple:
         rval = tuple(rval)
     elif not using_list:
-        rval ,= rval
+        rval, = rval
     return rval
@@ -494,8 +493,8 @@ def _populate_var_to_node_to_idx(outputs):
     """
-    #var_to_node_to_idx[var][node] = [i,j] means node has var as input at positions i and j
+    #var_to_node_to_idx[var][node] = [i,j] means node has
+    #var as input at positions i and j
     var_to_node_to_idx = {}
     #set of variables that have been added to their parents
     accounted_for = set([])
@@ -527,8 +526,9 @@ def _populate_var_to_node_to_idx(outputs):
     return var_to_node_to_idx

+
 def _populate_grad_dict(var_to_node_to_idx,\
-        grad_dict, wrt, warn_type, cost_name = None):
+        grad_dict, wrt, warn_type, cost_name=None):
     """
     Common code shared between grad_sources_inputs and grad
@@ -554,7 +554,8 @@ def _populate_grad_dict(var_to_node_to_idx,\
             has a different type from that variable
     cost_name: The name of the cost being differentiated, optional.
-            used to name the grad with respect to x as (d<cost_name>/dx)
+            used to name the grad with respect to x as
+            (d<cost_name>/dx)
     returns: a list of gradients corresponding to wrt
@@ -575,21 +576,22 @@ def _populate_grad_dict(var_to_node_to_idx,\
             # cycle. We avoid this cycle by passing (symbolic) copies of
             # each destroyed input.
             try:
-                dinputs = [node.inputs[x[0]] for x in node.op.destroy_map.values()]
+                dinputs = [node.inputs[x[0]] for x in
+                        node.op.destroy_map.values()]
             except AttributeError:
                 dinputs = []

             def try_to_copy_if_needed(var):
-                if var in dinputs and hasattr(var,'copy'):
+                if var in dinputs and hasattr(var, 'copy'):
                     return var.copy()
                 return var

-            inputs = [try_to_copy_if_needed(ipt) for ipt in inputs ]
+            inputs = [try_to_copy_if_needed(ipt) for ipt in inputs]

-            output_grads = [ access_grad_cache(var) for var in node.outputs ]
+            output_grads = [access_grad_cache(var) for var in node.outputs]

-            if False in [ isinstance(g.type, DisconnectedType)
-                    for g in output_grads ]:
+            if False in [isinstance(g.type, DisconnectedType)
+                    for g in output_grads]:
                 #Some outputs of this op are connected to the cost so we must
                 #call the ops grad method
@@ -600,15 +602,15 @@ def _populate_grad_dict(var_to_node_to_idx,\
                             "expected iterable." % str(node.op))

                 if len(input_grads) != len(inputs):
-                    raise ValueError(("%s returned the wrong number of gradient"+\
-                            "terms.") % str(node.op))
+                    raise ValueError(("%s returned the wrong number of" +\
+                            " gradient terms.") % str(node.op))
             else:
                 #All outputs of this op are disconnected so we can skip
                 #Calling the op's grad method and report that the inputs
                 #are disconnected
                 #(The op's grad method could do this too, but this saves the
                 #implementer the trouble of worrying about this case)
-                input_grads = [ DisconnectedType()() for ipt in inputs ]
+                input_grads = [DisconnectedType()() for ipt in inputs]

             #must convert to list in case the op returns a tuple
             #we won't be able to post-process out the Nones if it does that
@@ -617,12 +619,17 @@ def _populate_grad_dict(var_to_node_to_idx,\
             for i in xrange(len(term_dict[node])):
                 if term_dict[node][i] is None:
-                    #we don't know what None means. in the past it has been used to
-                    #mean undefined, zero, or disconnected. So for now we assume it is
-                    #zero. Assuming it is zero prevents us from disconnecting NaNs above.
-                    #eventually we should disallow this return type and force all ops
+                    #we don't know what None means. in the past it has been
+                    #used to
+                    #mean undefined, zero, or disconnected. So for now we
+                    #assume it is
+                    #zero. Assuming it is zero prevents
+                    #us from disconnecting NaNs above.
+                    #eventually we should disallow this
+                    #return type and force all ops
                     #to return the correct thing
-                    #raise AssertionError('%s returned None for a gradient term, '
+                    #raise AssertionError('%s returned None for' +\
+                    #        ' a gradient term, '
                     #        'this is prohibited' % node.op)
                     term_dict[node][i] = node.inputs[i].zeros_like()
@@ -630,17 +637,17 @@ def _populate_grad_dict(var_to_node_to_idx,\
                     g_r_type = term_dict[node][i].type
                     r_type = inputs[i].type
                     if g_r_type != r_type:
-                        _logger.warning('%s.grad returned a different type (%s) '
+                        _logger.warning(
+                                '%s.grad returned a different type (%s) '
                                 'for input %i of type (%s)',
                                 node.op, g_r_type, i, r_type)

         return term_dict[node]

     #built-in python sum adds an extraneous TensorConstant(0)
     #we can exploit the knowledge that iterable always has at
     #least one element to avoid starting the sum at 0
-    def nonempty_sum( iterable ):
+    def nonempty_sum(iterable):
         rval = iterable[0]
         for elem in iterable[1:]:
             rval = rval + elem
@@ -667,28 +674,27 @@ def _populate_grad_dict(var_to_node_to_idx,\
                                    " Variable instance." % (str(node.op),
                                    type(term)))
-                        if isinstance(term.type,NullType):
-                            raise TypeError("tensor.grad encountered a NaN. "+\
+                        if isinstance(term.type, NullType):
+                            raise TypeError("tensor.grad "
+                                    "encountered a NaN. " +\
                                     term.type.why_null)
-                        terms.append( term)
+                        terms.append(term)

                 grad_dict[var] = nonempty_sum(terms)

                 if cost_name is not None and var.name is not None:
                     grad_dict[var].name = '(d%s/d%s)' % (cost_name, var.name)
             else:
-                #this variable is not connected to the cost in the computational
+                #this variable isn't connected to the cost in the computational
                 #graph
                 grad_dict[var] = DisconnectedType()()

         return grad_dict[var]

-    rval = [ access_grad_cache(elem) for elem in wrt ]
+    rval = [access_grad_cache(elem) for elem in wrt]

     return rval


-def grad_sources_inputs(sources, graph_inputs, warn_type = True):
+def grad_sources_inputs(sources, graph_inputs, warn_type=True):
     """
     Used to compute the gradient of a cost with respect to all the
     variables between graph_input and cost, but in the special
@@ -776,18 +782,18 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = True):
         if elem not in var_to_node_to_idx and elem not in outputs:
             grad_dict[elem] = DisconnectedType()()

     _populate_grad_dict(var_to_node_to_idx,
             grad_dict, wrt, warn_type)

     #post-process out the DisconnectedTypes
     for key in grad_dict:
-        if isinstance(grad_dict[key].type,DisconnectedType):
-            if hasattr(key,'zeros_like'):
+        if isinstance(grad_dict[key].type, DisconnectedType):
+            if hasattr(key, 'zeros_like'):
                 grad_dict[key] = key.zeros_like()

     return grad_dict


 class numeric_grad(object):
     """
     Compute the numeric derivative of a scalar-valued function at a particular
@@ -986,6 +992,7 @@ class numeric_grad(object):
         max_pos = pos[max_arg]
         return (max_arg, pos[max_arg], abs_errs[max_arg], rel_errs[max_arg])

+
 def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
                 out_type=None, abs_tol=None,
                 rel_tol=None, mode=None, cast_to_output_type=False):
@@ -1119,7 +1126,6 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
     symbolic_grad = grad(cost, tensor_pt, g_cost,
                          disconnected_inputs='ignore')

     grad_fn = function(tensor_pt, symbolic_grad)
-
     for test_num in xrange(n_tests):
......
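One of the hunks above keeps the comment over nonempty_sum, which explains why gradient terms are folded by hand: Python's built-in sum starts from 0, so summing symbolic variables with it splices an extraneous TensorConstant(0) into the graph. A minimal standalone sketch of the same idiom (plain Python, illustrative only, not the Theano implementation):

    def nonempty_sum(terms):
        # assumes terms is a non-empty sequence; fold left without a 0 start
        # value, so no extra "+ 0" term is introduced when elements are symbolic
        rval = terms[0]
        for elem in terms[1:]:
            rval = rval + elem
        return rval

    # built-in sum(terms) is equivalent to 0 + terms[0] + terms[1] + ...,
    # which is the extra constant the gradient code wants to avoid
    print(nonempty_sum([1, 2, 3]))  # prints 6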