Commit 7681f4c7, authored by Ian Goodfellow

added known_grads parameter to grad function

Parent 26520161
...@@ -384,9 +384,11 @@ def Lop(f, wrt, eval_points, consider_constant=None, ...@@ -384,9 +384,11 @@ def Lop(f, wrt, eval_points, consider_constant=None,
######################### #########################
def grad(cost, wrt, g_cost=None, consider_constant=None, def grad(cost, wrt, g_cost=None, consider_constant=None,
disconnected_inputs='raise', add_names=True): disconnected_inputs='raise', add_names=True,
known_grads=None):
""" """
:type cost: Scalar (0-dimensional) Variable. :type cost: Scalar (0-dimensional) Variable.
May optionally be None if known_grads is provided.
:type wrt: Variable or list of Variables. :type wrt: Variable or list of Variables.
:type g_cost: Scalar Variable, or None. :type g_cost: Scalar Variable, or None.
:param g_cost: an expression for the gradient through cost. The default is :param g_cost: an expression for the gradient through cost. The default is
...@@ -407,6 +409,12 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, ...@@ -407,6 +409,12 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
(d<cost.name>/d<wrt.name>) provided that both cost and wrt have (d<cost.name>/d<wrt.name>) provided that both cost and wrt have
names names
:type known_grads: dict
:param known_grads: If not None, a dictionary mapping variables to their
gradients. This is useful in the case where you know the
gradient on some variables but do not know the original
cost.
:rtype: Variable or list/tuple of Variables (depending upon `wrt`) :rtype: Variable or list/tuple of Variables (depending upon `wrt`)
:return: symbolic expression of gradient of `cost` with respect to `wrt`. :return: symbolic expression of gradient of `cost` with respect to `wrt`.
...@@ -420,12 +428,15 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, ...@@ -420,12 +428,15 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
if tensor is None: if tensor is None:
from theano import tensor from theano import tensor
if isinstance(cost.type, NullType): if cost is None:
assert known_grads is not None
if cost is not None and isinstance(cost.type, NullType):
raise ValueError("Can't differentiate a NaN cost." raise ValueError("Can't differentiate a NaN cost."
"cost is NaN because " + \ "cost is NaN because " + \
cost.type.why_null) cost.type.why_null)
if cost.ndim != 0: if cost is not None and cost.ndim != 0:
raise TypeError("cost must be a scalar.") raise TypeError("cost must be a scalar.")
...@@ -444,7 +455,14 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, ...@@ -444,7 +455,14 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
raise TypeError("Expected Variable, got " + str(elem) + raise TypeError("Expected Variable, got " + str(elem) +
" of type "+str(type(elem))) " of type "+str(type(elem)))
var_to_node_to_idx = _populate_var_to_node_to_idx([cost], wrt, consider_constant) outputs = []
if cost is not None:
outputs.append(cost)
if known_grads is not None:
outputs.extend(known_grads.keys())
var_to_node_to_idx = _populate_var_to_node_to_idx(
outputs, wrt, consider_constant)
# build a dict mapping var to the gradient of cost with respect to var # build a dict mapping var to the gradient of cost with respect to var
grad_dict = {} grad_dict = {}
...@@ -452,38 +470,63 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, ...@@ -452,38 +470,63 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
# The gradient of the cost should default to 1 if the cost is of a # The gradient of the cost should default to 1 if the cost is of a
# continuous dtype (float, for the moment, as complex are unsupported), # continuous dtype (float, for the moment, as complex are unsupported),
# and should always be 0 if the cost is of discrete (integer) dtype. # and should always be 0 if the cost is of discrete (integer) dtype.
if getattr(cost.type, 'dtype', None) not in tensor.float_dtypes: if cost is not None:
if g_cost is not None: if getattr(cost.type, 'dtype', None) not in tensor.float_dtypes:
try: if g_cost is not None:
cval = theano.get_constant_value(g_cost) try:
if cval == 0: cval = theano.get_constant_value(g_cost)
g_cost_is_zero = True if cval == 0:
else: g_cost_is_zero = True
else:
g_cost_is_zero = False
except TypeError:
g_cost_is_zero = False g_cost_is_zero = False
except TypeError:
g_cost_is_zero = False if not g_cost_is_zero:
raise ValueError(
if not g_cost_is_zero: "The gradient of a cost of non-continuous dtype "
raise ValueError("The gradient of a cost of non-continuous " "(here, %s), if it is defined, should be 0. "
"dtype (here, %s), if it is defined, should be 0. " "However, a value of %s was provided in the "
"However, a value of %s was provided in the 'g_cost' " "'g_cost' "
"argument of theano.grad(). To remove this error, " "argument of theano.grad(). To remove this error,"
"you can simply omit the 'g_cost' argument, or " " you can simply omit the 'g_cost' argument, or "
"give it the default value of None." % ( "give it the default value of None." % (
getattr(g_cost.type, 'dtype', 'no dtype defined'), getattr(g_cost.type, 'dtype', 'no dtype defined'),
g_cost)) g_cost))
g_cost = tensor.zeros_like(cost) g_cost = tensor.zeros_like(cost)
elif g_cost is None: elif g_cost is None:
# cost.type.dtype is in tensor.float_dtypes at that point # cost.type.dtype is in tensor.float_dtypes at that point
g_cost = tensor.ones_like(cost) g_cost = tensor.ones_like(cost)
else:
# Cast the provided gradient so that it has the same dtype
# as the cost.
g_cost = g_cost.astype(cost.type.dtype)
grad_dict[cost] = g_cost
else: else:
# Cast the provided gradient so that it has the same dtype if g_cost is not None:
# as the cost. raise ValueError("No cost node was specified, but a gradient"
g_cost = g_cost.astype(cost.type.dtype) " on it was.")
if known_grads is not None:
for var in known_grads:
g_var = known_grads[var]
if not hasattr(g_var, 'type'):
raise TypeError('output grads must be theano variables.'
'Ambiguous whether %s should be made into tensor'
' or sparse theano variable' % str(type(g_var)))
if g_var.type not in [NullType, DisconnectedType] and 'float' \
not in str(g_var.type.dtype):
raise TypeError("Gradients must always be NullType, "
"DisconnectedType, or continuous.")
grad_dict[var] = g_var
grad_dict[cost] = g_cost
# variables that do not influence the cost have zero gradient. # variables that do not influence the cost have zero gradient.
# if wrt is such a variable, populate the grad_dict with this info # if wrt is such a variable, populate the grad_dict with this info
...@@ -573,7 +616,7 @@ def _node_to_pattern(node): ...@@ -573,7 +616,7 @@ def _node_to_pattern(node):
def _populate_var_to_node_to_idx(outputs, wrt, consider_constant): def _populate_var_to_node_to_idx(outputs, wrt, consider_constant):
""" """
Common code shared between grad and grad_sources_inputs Helper function for grad function.
outputs: a list of variables we want to take gradients of outputs: a list of variables we want to take gradients of
...@@ -713,7 +756,7 @@ def _populate_var_to_node_to_idx(outputs, wrt, consider_constant): ...@@ -713,7 +756,7 @@ def _populate_var_to_node_to_idx(outputs, wrt, consider_constant):
def _populate_grad_dict(var_to_node_to_idx, def _populate_grad_dict(var_to_node_to_idx,
grad_dict, wrt, cost_name=None): grad_dict, wrt, cost_name=None):
""" """
Common code shared between grad_sources_inputs and grad Helper function for grad function.
var_to_node_to_idx: a dictionary mapping a variable to var_to_node_to_idx: a dictionary mapping a variable to
a second dictionary. a second dictionary.
...@@ -722,7 +765,7 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -722,7 +765,7 @@ def _populate_grad_dict(var_to_node_to_idx,
node's input list node's input list
grad_dict: a dictionary mapping variables to their gradients grad_dict: a dictionary mapping variables to their gradients
should be populated by grad or grad_sources_inputs should be populated by grad function.
grad should set gradients to DisconnectedType()() for grad should set gradients to DisconnectedType()() for
variables to be considered constant, set the variables to be considered constant, set the
...@@ -873,6 +916,7 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -873,6 +916,7 @@ def _populate_grad_dict(var_to_node_to_idx,
'the grad_undefined or grad_unimplemented helper ' 'the grad_undefined or grad_unimplemented helper '
'functions.') % node.op) 'functions.') % node.op)
if not isinstance(term.type, if not isinstance(term.type,
(NullType, DisconnectedType)): (NullType, DisconnectedType)):
if term.type.dtype not in theano.tensor.float_dtypes: if term.type.dtype not in theano.tensor.float_dtypes:
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment