Commit 5f75d4a0 authored by lamblin

Merge pull request #1019 from lamblin/grad_downcast

Re-add part of the dtype constraint on out grads
@@ -465,9 +465,41 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
     # build a dict mapping var to the gradient of cost with respect to var
     grad_dict = {}
 
-    # by default, the gradient of the cost is 1
-    if g_cost is None:
-        g_cost = _float_ones_like(cost)
+    # The gradient of the cost should default to 1 if the cost is of a
+    # continuous dtype (float, for the moment, as complex are unsupported),
+    # and should always be 0 if the cost is of discrete (integer) dtype.
+    if getattr(cost.type, 'dtype', None) not in tensor.float_dtypes:
+        if g_cost is not None:
+            try:
+                cval = theano.get_constant_value(g_cost)
+                if cval == 0:
+                    g_cost_is_zero = True
+                else:
+                    g_cost_is_zero = False
+            except TypeError:
+                g_cost_is_zero = False
+            if not g_cost_is_zero:
+                raise ValueError("The gradient of a cost of non-continuous "
+                        "dtype (here, %s), if it is defined, should be 0. "
+                        "However, a value of %s was provided in the 'g_cost' "
+                        "argument of theano.grad(). To remove this error, "
+                        "you can simply omit the 'g_cost' argument, or "
+                        "give it the default value of None." % (
+                        getattr(g_cost.type, 'dtype', 'no dtype defined'),
+                        g_cost))
+        g_cost = tensor.zeros_like(cost)
+    elif g_cost is None:
+        # cost.type.dtype is in tensor.float_dtypes at that point
+        g_cost = tensor.ones_like(cost)
+    else:
+        # Cast the provided gradient so that it has the same dtype
+        # as the cost.
+        g_cost = g_cost.astype(cost.type.dtype)
 
     grad_dict[cost] = g_cost
 
     # the gradient of the constants is 0
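In practice, the new check behaves roughly as in the sketch below. This is ours, not an official test from the commit, and assumes a build that includes this change; variable names are hypothetical:

    import theano
    from theano import tensor

    x = tensor.dvector('x')
    cost = tensor.sum(x > 0)  # comparison ops yield an integer dtype

    # OK: g_cost is omitted, so it defaults to tensor.zeros_like(cost).
    g = theano.grad(cost, x, disconnected_inputs='ignore')

    # ValueError: a non-zero g_cost makes no sense for an integer cost.
    # theano.grad(cost, x, g_cost=tensor.as_tensor_variable(1.0),
    #             disconnected_inputs='ignore')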
@@ -501,10 +533,12 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
     cost_name = cost.name
 
     # Make sure we didn't initialize the grad_dict with any ints
    # for non-int outputs
     for var in grad_dict:
         g = grad_dict[var]
-        if hasattr(g.type, 'dtype'):
-            assert g.type.dtype.find('float') != -1
+        if (hasattr(g.type, 'dtype') and
+                getattr(var.type, 'dtype', '') in tensor.float_dtypes):
+            assert g.type.dtype in tensor.float_dtypes
 
     rval = _populate_grad_dict(var_to_node_to_idx,
                                grad_dict, wrt, cost_name)
@@ -739,7 +773,40 @@ def _populate_grad_dict(var_to_node_to_idx,
         inputs = [try_to_copy_if_needed(ipt) for ipt in inputs]
 
-        input_grads = node.op.grad(inputs, output_grads)
+        # Build a list of output gradients with the same dtype as
+        # the corresponding output variable.
+        # If an output is of a float dtype, we want to cast the
+        # output gradient into the same dtype, to avoid having a
+        # gradient graph with double precision (taking more memory,
+        # and more computation).
+        # If an output is of an integer dtype, then we ensure the
+        # output gradient is zero, and that zero can be represented
+        # in the same int dtype.
+        # If an output gradient is a NullType or DisconnectedType,
+        # then it will not have a dtype, and it will not be changed.
+        new_output_grads = []
+        for o, og in zip(node.outputs, output_grads):
+            o_dt = getattr(o.type, 'dtype', None)
+            og_dt = getattr(og.type, 'dtype', None)
+            if og_dt and o_dt in theano.tensor.discrete_dtypes:
+                new_output_grads.append(o.zeros_like())
+            elif o_dt and og_dt and o_dt != og_dt:
+                new_output_grads.append(og.astype(o_dt))
+            else:
+                new_output_grads.append(og)
+
+        # Make sure that, if new_output_grads[i] has a dtype:
+        # - it is the same dtype as outputs[i]
+        # - if the dtype is an int, then new_output_grads[i] is 0.
+        for o, ng in zip(node.outputs, new_output_grads):
+            o_dt = getattr(o.type, 'dtype', None)
+            ng_dt = getattr(ng.type, 'dtype', None)
+            if ng_dt:
+                assert ng_dt == o_dt
+                if ng_dt in theano.tensor.discrete_dtypes:
+                    assert theano.get_constant_value(ng) == 0
+
+        input_grads = node.op.grad(inputs, new_output_grads)
 
         if input_grads is None:
             raise TypeError("%s.grad returned NoneType, "
@@ -764,7 +831,7 @@ def _populate_grad_dict(var_to_node_to_idx,
         #List of bools indicating if each output is an integer dtype
         output_is_int = [hasattr(output.type, 'dtype') and
-                         output.type.dtype.find('int') != -1
+                         output.type.dtype in theano.tensor.discrete_dtypes
                          for output in node.outputs]
 
         #List of bools indicating if each input only has integer outputs
@@ -792,7 +859,7 @@ def _populate_grad_dict(var_to_node_to_idx,
             if not isinstance(term.type,
                               (NullType, DisconnectedType)):
-                if term.type.dtype.find('float') == -1:
+                if term.type.dtype not in theano.tensor.float_dtypes:
                     raise TypeError(str(node.op) + '.grad illegally '
                             ' returned an integer-valued variable.'
                             ' (Input index %d, dtype %s)' % (i,
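For context, this TypeError guards against Ops whose grad() returns an integer-dtype variable. A hypothetical offender (ours, not part of Theano), sketched under the assumption of the standard Op/Apply interfaces of this era:

    import theano
    from theano import tensor
    from theano.gof import Apply, Op

    class BadGradOp(Op):
        """Toy Op whose grad() illegally returns an integer-dtype variable."""
        def make_node(self, x):
            x = tensor.as_tensor_variable(x)
            return Apply(self, [x], [x.type()])

        def perform(self, node, inputs, output_storage):
            output_storage[0][0] = inputs[0].copy()

        def grad(self, inputs, output_grads):
            # Illegal: returned gradients must have a float dtype.
            return [output_grads[0].astype('int32')]

    x = tensor.dvector('x')
    # theano.grad(BadGradOp()(x).sum(), x)  # raises the TypeError above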
@@ -997,8 +1064,18 @@ def grad_sources_inputs(sources, graph_inputs):
     # build a dict mapping var to the gradient of cost with respect to var
     grad_dict = {}
 
-    # by default, the gradient of the cost is 1
     for output, output_grad in sources:
+        # The gradient of the cost should always be 0 if the cost is of
+        # discrete (integer) dtype.
+        if getattr(output.type, 'dtype', '') not in theano.tensor.float_dtypes:
+            output_grad = output.zeros_like()
+        else:
+            # Cast the provided gradient so that it has the same dtype
+            # as the cost.
+            output_grad = output_grad.astype(output.type.dtype)
         grad_dict[output] = output_grad
 
     # variables that do not influence the cost have zero gradient.
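As the loop above shows, grad_sources_inputs consumes (variable, gradient) pairs, so after this change an integer-dtype source is always forced to a zero gradient regardless of what was passed. A hedged usage sketch (this is an internal helper, so the call shape below is inferred from the loop, not from documentation):

    from theano import tensor
    from theano.gradient import grad_sources_inputs

    x = tensor.dvector('x')
    y = tensor.sum(x ** 2)
    # One (output, output_grad) pair; the float branch casts the grad to
    # y's dtype, the integer branch would replace it with y.zeros_like().
    grad_map = grad_sources_inputs([(y, tensor.ones_like(y))], [x])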
@@ -1369,12 +1446,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
     cost_fn = function(tensor_pt, cost)
 
-    # todo-- determine if this is actually needed
-    g_cost = as_tensor_variable(1.0, name='g_cost')
-    if cast_to_output_type:
-        g_cost = cast(g_cost, o_output.dtype)
-
-    symbolic_grad = grad(cost, tensor_pt, g_cost,
+    symbolic_grad = grad(cost, tensor_pt,
                          disconnected_inputs='ignore')
 
     grad_fn = function(tensor_pt, symbolic_grad)
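After this hunk, verify_grad no longer builds an explicit g_cost; grad() now picks the correct default on its own. Typical usage, which this commit leaves unchanged (a sketch assuming verify_grad is importable from theano.gradient, where this diff edits it):

    import numpy
    import theano
    from theano import tensor
    from theano.gradient import verify_grad

    rng = numpy.random.RandomState(42)
    pt = [rng.rand(4).astype(theano.config.floatX)]
    # Compares tanh's symbolic gradient against a finite-difference estimate.
    verify_grad(tensor.tanh, pt, rng=rng)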
@@ -1966,10 +1966,18 @@ class TensorFromScalar(Op):
     def grad(self, inp, grads):
         s, = inp
         dt, = grads
-        assert dt.type.dtype.find('float') != -1
-        if s.type.dtype.find('int') != -1:
-            return [scalar_from_tensor(dt)]
+        if s.type.dtype in float_dtypes:
+            assert dt.type.dtype in float_dtypes
+            return [scalar_from_tensor(dt)]
+
+        # If the input dtype is an integer, then so is the output dtype,
+        # and the "zero" gradient can be represented in that int dtype.
+        # Currently, theano.grad insists that the dtype of the returned
+        # gradient has a float dtype, so we use floatX.
+        if s.type.dtype in discrete_dtypes:
+            return [s.zeros_like().astype(theano.config.floatX)]
+
+        raise NotImplementedError("grad not implemented for complex dtypes")
 
     def __str__(self):
         return self.__class__.__name__
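A sketch of what the new integer branch produces, assuming the scalar/tensor helpers behave as in this revision (the variable names are ours):

    import theano
    from theano import scalar, tensor

    s = scalar.int32('s')                # a scalar (not tensor) int variable
    t = tensor.tensor_from_scalar(s)
    g, = t.owner.op.grad([s], [t.zeros_like()])
    # The "zero" gradient comes back in floatX, because theano.grad
    # insists on float dtypes for gradients.
    assert g.type.dtype == theano.config.floatX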