Commit d1e926cb authored by Ian Goodfellow

undo bad PR that violated the op contract

Parent 5ea4d31c
@@ -249,6 +249,8 @@ following methods:
     1) They must be Variable instances.
     2) When they are types that have dtypes, they must never have an integer dtype.
     The output gradients passed *to* Op.grad will also obey these constraints.
+    Integers are a tricky subject. Integers are the main reason for having
+    DisconnectedType, NullType or zero gradient. When you have an integer as an
+    argument to your grad method, recall the definition of a derivative to help
+    you decide what value to return:
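To make the new contract concrete: for an integer input the usual limit definition of the derivative, lim_{eps -> 0} (f(x + eps) - f(x)) / eps, has no meaning on a discrete domain, so a grad method must choose between a zero float gradient, DisconnectedType, or NullType. Below is a minimal sketch under that contract; the op ScaleByInt and its semantics are invented for illustration and are not part of this commit:

    import theano
    from theano import tensor

    class ScaleByInt(theano.Op):
        """Hypothetical op computing y = x * k for float x and integer k."""

        def make_node(self, x, k):
            x = tensor.as_tensor_variable(x)
            k = tensor.as_tensor_variable(k)
            return theano.Apply(self, [x, k], [x.type()])

        def perform(self, node, inputs, output_storage):
            x, k = inputs
            output_storage[0][0] = x * k

        def grad(self, inputs, output_grads):
            x, k = inputs
            g_y, = output_grads
            # dy/dx = k is well defined; cast so the gradient keeps x's
            # float dtype, as the contract requires.
            g_x = (g_y * k).astype(x.dtype)
            # k is discrete, so the limit definition does not apply: return
            # a zero gradient with a float dtype, never an integer dtype.
            # DisconnectedType or NullType would also be legal choices.
            g_k = tensor.zeros_like(k).astype(theano.config.floatX)
            return [g_x, g_k]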
@@ -444,38 +444,18 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
     # continuous dtype (float, for the moment, as complex are unsupported),
     # and should always be 0 if the cost is of discrete (integer) dtype.
     if cost is not None:
-        if getattr(cost.type, 'dtype', None) not in tensor.float_dtypes:
-            if g_cost is not None:
-                try:
-                    cval = theano.get_constant_value(g_cost)
-                    if cval == 0:
-                        g_cost_is_zero = True
-                    else:
-                        g_cost_is_zero = False
-                except TypeError:
-                    g_cost_is_zero = False
-                if not g_cost_is_zero:
-                    raise ValueError(
-                        "The gradient of a cost of non-continuous dtype "
-                        "(here, %s), if it is defined, should be 0. "
-                        "However, a value of %s was provided in the "
-                        "'g_cost' argument of theano.grad(). To remove "
-                        "this error, you can simply omit the 'g_cost' "
-                        "argument, or give it the default value of None." % (
-                            getattr(g_cost.type, 'dtype', 'no dtype defined'),
-                            g_cost))
-            g_cost = tensor.zeros_like(cost)
-        elif g_cost is None:
-            # cost.type.dtype is in tensor.float_dtypes at that point
-            g_cost = tensor.ones_like(cost)
-        else:
-            # Cast the provided gradient so that it has the same dtype
-            # as the cost.
-            g_cost = g_cost.astype(cost.type.dtype)
+        if g_cost is None:
+            g_cost = _float_ones_like(cost)
+        # g_cost may be Disconnected or NullType. A creative use of the
+        # function, sure, but nonetheless one we can and should support.
+        # So before we try to cast it, make sure it even has a dtype.
+        if hasattr(g_cost, 'dtype') and cost.dtype not in tensor.discrete_dtypes:
+            # Here we enforce the constraint that floating point variables
+            # have the same dtype as their gradient.
+            g_cost = g_cost.astype(cost.dtype)
+        # DO NOT enforce g_cost to be 0 if cost is an integer.
+        # This is to be enforced by the Op.grad method of the Op that
+        # outputs cost.
+        assert getattr(g_cost, 'dtype', None) not in tensor.discrete_dtypes
         grad_dict[cost] = g_cost
     else:
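The new code seeds the gradient with _float_ones_like instead of tensor.ones_like, so that even a discrete-dtyped cost gets a continuous seed. A minimal sketch of what that helper must do, reconstructed from its use here rather than taken from the commit:

    import theano
    from theano import tensor

    def float_ones_like(x):
        # Ones with x's shape, but never an integer dtype, so the seed
        # gradient already satisfies the "gradients are continuous" rule.
        dtype = getattr(x.type, 'dtype', None)
        if dtype not in tensor.float_dtypes:
            dtype = theano.config.floatX
        return tensor.ones_like(x).astype(dtype)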
@@ -495,7 +475,13 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
         if not isinstance(g_var.type, (NullType, DisconnectedType)) and \
                 'float' not in str(g_var.type.dtype):
             raise TypeError("Gradients must always be NullType, "
-                            "DisconnectedType, or continuous.")
+                            "DisconnectedType, or continuous, but grad was "
+                            "given a known_grad of type " + str(g_var.type))
+        # DO NOT check that these gradients are equal to 0 if var is an int.
+        # The gradient is allowed to be non-zero on var in that case.
+        # Ops outputting var should not backpropagate its gradient further,
+        # but that is enforced elsewhere (grep for only_connected_to_int).
         grad_dict[var] = g_var
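For illustration, a hedged usage sketch of the path this check guards, assuming the cost=None plus known_grads calling convention that this function's else: branch handles (variable names are invented):

    import theano
    from theano import tensor

    x = tensor.vector('x')
    y = tensor.sum(x ** 2)

    # A float-dtyped known gradient is accepted.
    gx = theano.grad(cost=None, wrt=x, known_grads={y: tensor.ones_like(y)})

    # An integer-dtyped one violates the contract; with this commit the
    # TypeError now names the offending type:
    bad = tensor.ones_like(y).astype('int64')
    # theano.grad(cost=None, wrt=x, known_grads={y: bad})  # raises TypeError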
@@ -529,11 +515,11 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
         cost_name = cost.name

-    # Make sure we didn't initialize the grad_dict with any ints
-    # for non-int outputs
+    # The gradient may NEVER be an int, even if the variable is an int.
+    # Read the Op contract and talk to Ian Goodfellow before changing this!
     for var in grad_dict:
         g = grad_dict[var]
-        if (hasattr(g.type, 'dtype') and
-                getattr(var.type, 'dtype', '') in tensor.float_dtypes):
+        if hasattr(g.type, 'dtype'):
             assert g.type.dtype in tensor.float_dtypes

     rval = _populate_grad_dict(var_to_node_to_idx,
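The invariant enforced by the loop above is observable from user code; a small sanity check (a sketch, with invented names):

    import theano
    from theano import tensor

    w = tensor.vector('w')
    cost = tensor.sum(w ** 2)
    g = theano.grad(cost, w)
    # Any returned gradient that has a dtype must have a float dtype,
    # never an integer one.
    assert g.dtype in tensor.float_dtypes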
@@ -816,38 +802,46 @@ def _populate_grad_dict(var_to_node_to_idx,
             inputs = [try_to_copy_if_needed(ipt) for ipt in inputs]

             # Build a list of output gradients with the same dtype as
             # the corresponding output variable.
             # If an output is of a float dtype, we want to cast the
             # output gradient to the same dtype, to avoid building a
             # gradient graph in double precision (taking more memory
             # and more computation).
-            # If an output is of an integer dtype, then we ensure the
-            # output gradient is zero, and that zero can be represented
-            # in the same int dtype.
             # If an output gradient is a NullType or DisconnectedType,
             # then it will not have a dtype, and it will not be changed.
+            # If an output is of an integer dtype, then we just leave it
+            # alone.
+            # DO NOT force integer variables to have zero grad. This causes
+            # bugs where we fail to detect disconnected or undefined
+            # gradients.
+            # DO NOT force the gradients of integer variables to have
+            # integer dtype. This is a violation of the op contract.
             new_output_grads = []
             for o, og in zip(node.outputs, output_grads):
                 o_dt = getattr(o.type, 'dtype', None)
                 og_dt = getattr(og.type, 'dtype', None)
-                if og_dt and o_dt in theano.tensor.discrete_dtypes:
-                    new_output_grads.append(o.zeros_like())
-                elif o_dt and og_dt and o_dt != og_dt:
+                if o_dt not in theano.tensor.discrete_dtypes and og_dt \
+                        and o_dt != og_dt:
                     new_output_grads.append(og.astype(o_dt))
                 else:
                     new_output_grads.append(og)

-            # Make sure that, if new_output_grads[i] has a dtype:
-            # - it is the same dtype as outputs[i]
-            # - if the dtype is an int, then new_output_grads[i] is 0.
+            # Make sure that, if new_output_grads[i] has a floating point
+            # dtype, it is the same dtype as outputs[i].
             for o, ng in zip(node.outputs, new_output_grads):
                 o_dt = getattr(o.type, 'dtype', None)
                 ng_dt = getattr(ng.type, 'dtype', None)
-                if ng_dt:
+                if ng_dt is not None and o_dt not in theano.tensor.discrete_dtypes:
                     assert ng_dt == o_dt
-                    if ng_dt in theano.tensor.discrete_dtypes:
-                        assert theano.get_constant_value(ng) == 0

+            # Someone who had obviously not read the Op contract tried
+            # to modify this part of the function.
+            # If you ever think it is a good idea to make an
+            # integer-valued gradient, please
+            # 1) Read the Op contract again
+            # 2) Talk to Ian Goodfellow
+            # (Both of these sources will tell you not to do it.)
+            for ng in new_output_grads:
+                assert getattr(ng.type, 'dtype', None) not in \
+                    theano.tensor.discrete_dtypes

             input_grads = node.op.grad(inputs, new_output_grads)
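As a complement to the final assertion above: an op whose output is integer-valued is expected to hold up its own end of the contract in its grad method, rather than relying on grad() to zero things out. A hedged sketch for an argmax-like op (illustrative only, not this commit's code):

    from theano.gradient import DisconnectedType

    def grad(self, inputs, output_grads):
        # The single output is integer-valued (an index), so it is
        # disconnected from any cost in the sense of the derivative.
        # Report that explicitly instead of returning an integer-dtyped
        # zero; an op doing this should also override connection_pattern
        # so grad() knows the input is disconnected.
        x, = inputs
        return [DisconnectedType()()]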