Commit 7681f4c7 authored by Ian Goodfellow

added known_grads parameter to grad function

Parent 26520161
@@ -384,9 +384,11 @@ def Lop(f, wrt, eval_points, consider_constant=None,
 #########################
 def grad(cost, wrt, g_cost=None, consider_constant=None,
-         disconnected_inputs='raise', add_names=True):
+         disconnected_inputs='raise', add_names=True,
+         known_grads=None):
     """
     :type cost: Scalar (0-dimensional) Variable.
+        May optionally be None if known_grads is provided.
     :type wrt: Variable or list of Variables.
     :type g_cost: Scalar Variable, or None.
     :param g_cost: an expression for the gradient through cost. The default is
@@ -407,6 +409,12 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
         (d<cost.name>/d<wrt.name>) provided that both cost and wrt have
         names
 
+    :type known_grads: dict
+    :param known_grads: If not None, a dictionary mapping variables to their
+        gradients. This is useful in the case where you know the
+        gradient on some variables but do not know the original
+        cost.
+
     :rtype: Variable or list/tuple of Variables (depending upon `wrt`)
     :return: symbolic expression of gradient of `cost` with respect to `wrt`.
@@ -420,12 +428,15 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
     if tensor is None:
         from theano import tensor
 
-    if isinstance(cost.type, NullType):
+    if cost is None:
+        assert known_grads is not None
+
+    if cost is not None and isinstance(cost.type, NullType):
         raise ValueError("Can't differentiate a NaN cost."
                          "cost is NaN because " + \
                          cost.type.why_null)
 
-    if cost.ndim != 0:
+    if cost is not None and cost.ndim != 0:
         raise TypeError("cost must be a scalar.")
@@ -444,7 +455,14 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
             raise TypeError("Expected Variable, got " + str(elem) +
                             " of type " + str(type(elem)))
 
-    var_to_node_to_idx = _populate_var_to_node_to_idx([cost], wrt, consider_constant)
+    outputs = []
+    if cost is not None:
+        outputs.append(cost)
+    if known_grads is not None:
+        outputs.extend(known_grads.keys())
+
+    var_to_node_to_idx = _populate_var_to_node_to_idx(
+        outputs, wrt, consider_constant)
 
     # build a dict mapping var to the gradient of cost with respect to var
     grad_dict = {}
@@ -452,6 +470,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
     # The gradient of the cost should default to 1 if the cost is of a
     # continuous dtype (float, for the moment, as complex are unsupported),
     # and should always be 0 if the cost is of discrete (integer) dtype.
+    if cost is not None:
     if getattr(cost.type, 'dtype', None) not in tensor.float_dtypes:
         if g_cost is not None:
             try:
@@ -464,11 +483,13 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
                 g_cost_is_zero = False
             if not g_cost_is_zero:
-                raise ValueError("The gradient of a cost of non-continuous "
-                    "dtype (here, %s), if it is defined, should be 0. "
-                    "However, a value of %s was provided in the 'g_cost' "
-                    "argument of theano.grad(). To remove this error, "
-                    "you can simply omit the 'g_cost' argument, or "
+                raise ValueError(
+                    "The gradient of a cost of non-continuous dtype "
+                    "(here, %s), if it is defined, should be 0. "
+                    "However, a value of %s was provided in the "
+                    "'g_cost' "
+                    "argument of theano.grad(). To remove this error,"
+                    " you can simply omit the 'g_cost' argument, or "
                     "give it the default value of None." % (
                         getattr(g_cost.type, 'dtype', 'no dtype defined'),
                         g_cost))
@@ -484,6 +505,28 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
             g_cost = g_cost.astype(cost.type.dtype)
         grad_dict[cost] = g_cost
+    else:
+        if g_cost is not None:
+            raise ValueError("No cost node was specified, but a gradient"
+                             " on it was.")
+
+    if known_grads is not None:
+        for var in known_grads:
+            g_var = known_grads[var]
+
+            if not hasattr(g_var, 'type'):
+                raise TypeError('output grads must be theano variables.'
+                                'Ambiguous whether %s should be made into tensor'
+                                ' or sparse theano variable' % str(type(g_var)))
+
+            if g_var.type not in [NullType, DisconnectedType] and 'float' \
+                    not in str(g_var.type.dtype):
+                raise TypeError("Gradients must always be NullType, "
+                                "DisconnectedType, or continuous.")
+
+            grad_dict[var] = g_var
 
     # variables that do not influence the cost have zero gradient.
     # if wrt is such a variable, populate the grad_dict with this info
@@ -573,7 +616,7 @@ def _node_to_pattern(node):
 def _populate_var_to_node_to_idx(outputs, wrt, consider_constant):
     """
-    Common code shared between grad and grad_sources_inputs
+    Helper function for grad function.
 
     outputs: a list of variables we want to take gradients of
@@ -713,7 +756,7 @@ def _populate_var_to_node_to_idx(outputs, wrt, consider_constant):
 def _populate_grad_dict(var_to_node_to_idx,
                         grad_dict, wrt, cost_name=None):
     """
-    Common code shared between grad_sources_inputs and grad
+    Helper function for grad function.
 
     var_to_node_to_idx: a dictionary mapping a variable to
         a second dictionary.
@@ -722,7 +765,7 @@ def _populate_grad_dict(var_to_node_to_idx,
         node's input list
 
     grad_dict: a dictionary mapping variables to their gradients
-        should be populated by grad or grad_sources_inputs
+        should be populated by grad function.
         grad should set gradients to DisconnectedType()() for
         variables to be considered constant, set the
@@ -873,6 +916,7 @@ def _populate_grad_dict(var_to_node_to_idx,
                         'the grad_undefined or grad_unimplemented helper '
                         'functions.') % node.op)
 
             if not isinstance(term.type,
                               (NullType, DisconnectedType)):
                 if term.type.dtype not in theano.tensor.float_dtypes:
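For reviewers, here is a minimal usage sketch (not part of this commit) of how the new known_grads argument is intended to be called. The variable names (x, h, g_h) are illustrative, and the sketch assumes the finished behavior described in the docstring: cost may be None when known_grads supplies the starting gradients.

    import theano
    import theano.tensor as T

    x = T.vector('x')
    h = T.tanh(x)           # intermediate variable whose gradient is already known
    g_h = T.vector('g_h')   # externally supplied gradient on h

    # cost=None is allowed because known_grads provides the output gradients
    g_x = theano.grad(cost=None, wrt=x, known_grads={h: g_h})

    f = theano.function([x, g_h], g_x)
    # chain rule: the result is g_h * (1 - tanh(x)**2)
    print(f([0.0, 1.0], [1.0, 1.0]))

This is the situation the docstring describes: the gradient on some variables (here h) is known, but the original cost is never constructed.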