提交 c2b3e4fa 作者: Mohammad Pezeshki

gradient.py in pep8

上级 020fc625
...@@ -467,16 +467,17 @@ def grad(cost, wrt, consider_constant=None, ...@@ -467,16 +467,17 @@ def grad(cost, wrt, consider_constant=None,
g_cost = known_grads[cost] g_cost = known_grads[cost]
else: else:
g_cost = _float_ones_like(cost) g_cost = _float_ones_like(cost)
# g_cost may be Disconnected or NullType. A creative use of the function, # g_cost may be Disconnected or NullType. A creative use of the
# sure, but nonetheless one we can and should support. So before we try # function, sure, but nonetheless one we can and should support.
# to cast it make sure it even has a dtype # So before we try to cast it make sure it even has a dtype
if (hasattr(g_cost.type, 'dtype') and if (hasattr(g_cost.type, 'dtype') and
cost.type.dtype not in tensor.discrete_dtypes): cost.type.dtype not in tensor.discrete_dtypes):
# Here we enforce the constraint that floating point variables have # Here we enforce the constraint that floating point variables
# the same dtype as their gradient. # have the same dtype as their gradient.
g_cost = g_cost.astype(cost.type.dtype) g_cost = g_cost.astype(cost.type.dtype)
# DO NOT enforce g_cost to be 0 if cost is an integer. # DO NOT enforce g_cost to be 0 if cost is an integer.
# This is to be enforced by the Op.grad method for the Op that outputs cost. # This is to be enforced by the Op.grad method for the
# Op that outputs cost.
if hasattr(g_cost.type, 'dtype'): if hasattr(g_cost.type, 'dtype'):
assert g_cost.type.dtype not in tensor.discrete_dtypes assert g_cost.type.dtype not in tensor.discrete_dtypes
...@@ -494,7 +495,7 @@ def grad(cost, wrt, consider_constant=None, ...@@ -494,7 +495,7 @@ def grad(cost, wrt, consider_constant=None,
'float' not in str(g_var.type.dtype)): 'float' not in str(g_var.type.dtype)):
raise TypeError("Gradients must always be NullType, " raise TypeError("Gradients must always be NullType, "
"DisconnectedType, or continuous, but grad was " "DisconnectedType, or continuous, but grad was "
"given a known_grad of type "+str(g_var.type)) "given a known_grad of type " + str(g_var.type))
# DO NOT check that these gradients are equal to 0 if var is int # DO NOT check that these gradients are equal to 0 if var is int
# The gradient is allowed to be non-zero on var in that case # The gradient is allowed to be non-zero on var in that case
...@@ -846,10 +847,10 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant): ...@@ -846,10 +847,10 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
if ipt not in var_to_app_to_idx: if ipt not in var_to_app_to_idx:
# This object here *must* be an OrderedDict, because # This object here *must* be an OrderedDict, because
# we iterate over its keys when adding up the terms of # we iterate over its keys when adding up the terms of the
# the gradient on ipt. If it is a regular dict, the grad # gradient on ipt. If it is a regular dict, the grad method
# method will return something that is analytically correct, # will return something that is analytically correct, but
# but whose order of doing additions depends on the memory # whose order of doing additions depends on the memory
# location of the apply nodes. # location of the apply nodes.
var_to_app_to_idx[ipt] = OrderedDict() var_to_app_to_idx[ipt] = OrderedDict()
app_to_idx = var_to_app_to_idx[ipt] app_to_idx = var_to_app_to_idx[ipt]
...@@ -923,8 +924,8 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -923,8 +924,8 @@ def _populate_grad_dict(var_to_app_to_idx,
grad_dict: A dictionary mapping variables to their gradients. grad_dict: A dictionary mapping variables to their gradients.
Should be populated by grad function, which should: Should be populated by grad function, which should:
-Set the gradient with respect to the cost to 1 -Set the gradient with respect to the cost to 1
-Load all gradients from known_grads, possibly overriding -Load all gradients from known_grads, possibly
the cost overriding the cost
-Set the gradient for disconnected -Set the gradient for disconnected
inputs to a variable with type DisconnectedType() inputs to a variable with type DisconnectedType()
...@@ -1004,10 +1005,10 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1004,10 +1005,10 @@ def _populate_grad_dict(var_to_app_to_idx,
# call the op's grad method # call the op's grad method
# Each Op's grad function requires inputs and output_grads # Each Op's grad function requires inputs and output_grads
# If the Op destroys any input, but the grad expression uses it, # If the Op destroys any input, but the grad expression uses
# then chances are the resulting graph will have a dependency # it, then chances are the resulting graph will have a
# cycle. We avoid this cycle by passing (symbolic) copies of # dependency cycle. We avoid this cycle by passing (symbolic)
# each destroyed input. # copies of each destroyed input.
try: try:
dinputs = [node.inputs[x[0]] for x in dinputs = [node.inputs[x[0]] for x in
node.op.destroy_map.values()] node.op.destroy_map.values()]
...@@ -1030,9 +1031,10 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1030,9 +1031,10 @@ def _populate_grad_dict(var_to_app_to_idx,
# If an output is of an integer dtype, then we just leave it # If an output is of an integer dtype, then we just leave it
# alone. # alone.
# DO NOT force integer variables to have zero grad. This causes # DO NOT force integer variables to have zero grad. This causes
# bugs where we fail to detect disconnected or undefined gradients. # bugs where we fail to detect disconnected or undefined
# DO NOT force integer variables to have integer dtype. This is # gradients.
# a violation of the op contract. # DO NOT force integer variables to have integer dtype.
# This is a violation of the op contract.
new_output_grads = [] new_output_grads = []
for o, og in zip(node.outputs, output_grads): for o, og in zip(node.outputs, output_grads):
o_dt = getattr(o.type, 'dtype', None) o_dt = getattr(o.type, 'dtype', None)
...@@ -1063,12 +1065,13 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1063,12 +1065,13 @@ def _populate_grad_dict(var_to_app_to_idx,
assert (getattr(ng.type, 'dtype', None) assert (getattr(ng.type, 'dtype', None)
not in theano.tensor.discrete_dtypes) not in theano.tensor.discrete_dtypes)
# If config.compute_test_value is turned on, check that the gradients # If config.compute_test_value is turned on, check that the
# on the outputs of this node have the right shape. # gradients on the outputs of this node have the right shape.
# We also check the gradient on the inputs later--both checks are needed, # We also check the gradient on the inputs later--both checks
# because some gradients are only ever specified by the user, not computed # are needed, because some gradients are only ever specified
# by Op.grad, and some gradients are only computed and returned, but never # by the user, not computed by Op.grad, and some gradients are
# passed as another node's output grads. # only computed and returned, but never passed as another
# node's output grads.
for idx, packed in enumerate(izip(node.outputs, for idx, packed in enumerate(izip(node.outputs,
new_output_grads)): new_output_grads)):
orig_output, new_output_grad = packed orig_output, new_output_grad = packed
...@@ -1104,7 +1107,7 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1104,7 +1107,7 @@ def _populate_grad_dict(var_to_app_to_idx,
# raise ValueError( # raise ValueError(
# "%s returned the wrong type for gradient terms." # "%s returned the wrong type for gradient terms."
# " Sparse inputs must have sparse grads and dense" # " Sparse inputs must have sparse grads and dense"
# " inputs must have dense grad. Got %s, expected %s" % ( # " inputs must have dense grad. Got %s, expected %s" %(
# str(node.op), ig.type, i.type)) # str(node.op), ig.type, i.type))
# must convert to list in case the op returns a tuple # must convert to list in case the op returns a tuple
...@@ -1138,7 +1141,8 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1138,7 +1141,8 @@ def _populate_grad_dict(var_to_app_to_idx,
'the grad_undefined or grad_unimplemented helper ' 'the grad_undefined or grad_unimplemented helper '
'functions.') % node.op) 'functions.') % node.op)
# Check that the gradient term for this input has the right shape # Check that the gradient term for this input
# has the right shape
if hasattr(term, 'shape'): if hasattr(term, 'shape'):
orig_ipt = inputs[i] orig_ipt = inputs[i]
for orig_ipt_v, term_v in get_debug_values(orig_ipt, term): for orig_ipt_v, term_v in get_debug_values(orig_ipt, term):
...@@ -1389,7 +1393,8 @@ class numeric_grad(object): ...@@ -1389,7 +1393,8 @@ class numeric_grad(object):
# create un-initialized memory # create un-initialized memory
x = numpy.ndarray((total_size,), dtype=working_dtype) x = numpy.ndarray((total_size,), dtype=working_dtype)
if (not out_type is None) and (out_type.startswith('complex')): # (not out_type is None) --> (out_type is not None) ???
if (out_type is not None) and (out_type.startswith('complex')):
gx = numpy.ndarray((total_size,), dtype=out_type) gx = numpy.ndarray((total_size,), dtype=out_type)
else: else:
gx = numpy.ndarray((total_size,), dtype=working_dtype) gx = numpy.ndarray((total_size,), dtype=working_dtype)
...@@ -1974,6 +1979,7 @@ def disconnected_grad(x): ...@@ -1974,6 +1979,7 @@ def disconnected_grad(x):
class GradClip(ViewOp): class GradClip(ViewOp):
# See doc in user fct grad_clip # See doc in user fct grad_clip
__props__ = () __props__ = ()
def __init__(self, clip_lower_bound, clip_upper_bound): def __init__(self, clip_lower_bound, clip_upper_bound):
# We do not put those member in __eq__ or __hash__ # We do not put those member in __eq__ or __hash__
# as they do not influence the perform of this op. # as they do not influence the perform of this op.
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论