Commit 7681f4c7 authored by Ian Goodfellow

added known_grads parameter to grad function

Parent 26520161
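A minimal usage sketch of the new parameter (not part of the commit; the variable names are hypothetical, and it assumes T.grad re-exports this function and accepts cost=None at this revision, as the docstring added below describes):

    import theano.tensor as T

    x = T.vector('x')
    y = x ** 2
    # Suppose the gradient on y is known from elsewhere (e.g. computed by a
    # separately differentiated downstream graph); there is no scalar cost.
    g_y = T.ones_like(y)
    # Backpropagates g_y through y = x ** 2, yielding a gradient of 2 * x.
    g_x = T.grad(cost=None, wrt=x, known_grads={y: g_y})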
@@ -384,9 +384,11 @@ def Lop(f, wrt, eval_points, consider_constant=None,
 #########################
 def grad(cost, wrt, g_cost=None, consider_constant=None,
-         disconnected_inputs='raise', add_names=True):
+         disconnected_inputs='raise', add_names=True,
+         known_grads=None):
     """
     :type cost: Scalar (0-dimensional) Variable.
+        May optionally be None if known_grads is provided.
     :type wrt: Variable or list of Variables.
     :type g_cost: Scalar Variable, or None.
     :param g_cost: an expression for the gradient through cost. The default is
@@ -407,6 +409,12 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
         (d<cost.name>/d<wrt.name>) provided that both cost and wrt have
         names
 
+    :type known_grads: dict
+    :param known_grads: If not None, a dictionary mapping variables to their
+        gradients. This is useful in the case where you know the
+        gradient on some variables but do not know the original
+        cost.
+
     :rtype: Variable or list/tuple of Variables (depending upon `wrt`)
 
     :return: symbolic expression of gradient of `cost` with respect to `wrt`.
@@ -420,12 +428,15 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
     if tensor is None:
         from theano import tensor
 
-    if isinstance(cost.type, NullType):
+    if cost is None:
+        assert known_grads is not None
+
+    if cost is not None and isinstance(cost.type, NullType):
         raise ValueError("Can't differentiate a NaN cost."
                          "cost is NaN because " + \
                          cost.type.why_null)
 
-    if cost.ndim != 0:
+    if cost is not None and cost.ndim != 0:
         raise TypeError("cost must be a scalar.")
@@ -444,7 +455,14 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
             raise TypeError("Expected Variable, got " + str(elem) +
                             " of type "+str(type(elem)))
 
-    var_to_node_to_idx = _populate_var_to_node_to_idx([cost], wrt, consider_constant)
+    outputs = []
+    if cost is not None:
+        outputs.append(cost)
+    if known_grads is not None:
+        outputs.extend(known_grads.keys())
+
+    var_to_node_to_idx = _populate_var_to_node_to_idx(
+        outputs, wrt, consider_constant)
 
     # build a dict mapping var to the gradient of cost with respect to var
     grad_dict = {}
@@ -452,6 +470,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
     # The gradient of the cost should default to 1 if the cost is of a
     # continuous dtype (float, for the moment, as complex are unsupported),
     # and should always be 0 if the cost is of discrete (integer) dtype.
+    if cost is not None:
         if getattr(cost.type, 'dtype', None) not in tensor.float_dtypes:
             if g_cost is not None:
                 try:
@@ -464,11 +483,13 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
                     g_cost_is_zero = False
 
                 if not g_cost_is_zero:
-                    raise ValueError("The gradient of a cost of non-continuous "
-                        "dtype (here, %s), if it is defined, should be 0. "
-                        "However, a value of %s was provided in the 'g_cost' "
-                        "argument of theano.grad(). To remove this error, "
-                        "you can simply omit the 'g_cost' argument, or "
+                    raise ValueError(
+                        "The gradient of a cost of non-continuous dtype "
+                        "(here, %s), if it is defined, should be 0. "
+                        "However, a value of %s was provided in the "
+                        "'g_cost' "
+                        "argument of theano.grad(). To remove this error,"
+                        " you can simply omit the 'g_cost' argument, or "
                         "give it the default value of None." % (
                         getattr(g_cost.type, 'dtype', 'no dtype defined'),
                         g_cost))
@@ -484,6 +505,28 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
             g_cost = g_cost.astype(cost.type.dtype)
 
         grad_dict[cost] = g_cost
+    else:
+        if g_cost is not None:
+            raise ValueError("No cost node was specified, but a gradient"
+                             " on it was.")
+
+    if known_grads is not None:
+        for var in known_grads:
+            g_var = known_grads[var]
+
+            if not hasattr(g_var, 'type'):
+                raise TypeError('output grads must be theano variables. '
+                                'Ambiguous whether %s should be made into tensor'
+                                ' or sparse theano variable' % str(type(g_var)))
+
+            if not isinstance(g_var.type, (NullType, DisconnectedType)) \
+                    and 'float' not in str(g_var.type.dtype):
+                raise TypeError("Gradients must always be NullType, "
+                                "DisconnectedType, or continuous.")
+
+            grad_dict[var] = g_var
+
     # variables that do not influence the cost have zero gradient.
     # if wrt is such a variable, populate the grad_dict with this info
@@ -573,7 +616,7 @@ def _node_to_pattern(node):
 def _populate_var_to_node_to_idx(outputs, wrt, consider_constant):
     """
-    Common code shared between grad and grad_sources_inputs
+    Helper function for grad function.
 
     outputs: a list of variables we want to take gradients of
@@ -713,7 +756,7 @@ def _populate_var_to_node_to_idx(outputs, wrt, consider_constant):
 def _populate_grad_dict(var_to_node_to_idx,
                         grad_dict, wrt, cost_name=None):
     """
-    Common code shared between grad_sources_inputs and grad
+    Helper function for grad function.
 
     var_to_node_to_idx: a dictionary mapping a variable to
         a second dictionary.
@@ -722,7 +765,7 @@ def _populate_grad_dict(var_to_node_to_idx,
             node's input list
 
     grad_dict: a dictionary mapping variables to their gradients
-        should be populated by grad or grad_sources_inputs
+        should be populated by grad function.
 
         grad should set gradients to DisconnectedType()() for
         variables to be considered constant, set the
@@ -873,6 +916,7 @@ def _populate_grad_dict(var_to_node_to_idx,
                                 'the grad_undefined or grad_unimplemented helper '
                                 'functions.') % node.op)
 
+                if not isinstance(term.type,
+                                  (NullType, DisconnectedType)):
                     if term.type.dtype not in theano.tensor.float_dtypes:
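For illustration of the validation paths added in this commit (hypothetical calls reusing x and y from the sketch above; each is shown commented out because it raises):

    import numpy as np
    import theano.tensor as T

    x = T.vector('x')
    y = x ** 2

    # ValueError: "No cost node was specified, but a gradient on it was."
    # T.grad(cost=None, wrt=x, g_cost=T.constant(1.0),
    #        known_grads={y: T.ones_like(y)})

    # TypeError: known_grads values must be theano Variables; a raw numpy
    # array has no 'type' attribute, so it is ambiguous whether it should
    # become a tensor or a sparse variable.
    # T.grad(cost=None, wrt=x, known_grads={y: np.ones(3)})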