提交 51290164 authored 作者: Razvan Pascanu

Merge pull request #1073 from goodfeli/remove_g_cost

Ready to merge: remove g_cost
......@@ -349,16 +349,13 @@ def Lop(f, wrt, eval_points, consider_constant=None,
# Gradient
#########################
def grad(cost, wrt, g_cost=None, consider_constant=None,
def grad(cost, wrt, consider_constant=None,
disconnected_inputs='raise', add_names=True,
known_grads=None, return_disconnected='zero'):
"""
:type cost: Scalar (0-dimensional) Variable.
May optionally be None if known_grads is provided.
:type wrt: Variable or list of Variables.
:type g_cost: Scalar Variable, or None.
:param g_cost: an expression for the gradient through cost. The default is
``ones_like(cost)``.
:param consider_constant: a list of expressions not to backpropagate
through
......@@ -441,11 +438,14 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
# build a dict mapping var to the gradient of cost with respect to var
grad_dict = {}
# The gradient of the cost should default to 1 if the cost is of a
# continuous dtype (float, for the moment, as complex are unsupported),
# and should always be 0 if the cost is of discrete (integer) dtype.
if known_grads is None:
known_grads = {}
# The gradient of the cost is 1 unless specified otherwise by known_grads.
if cost is not None:
if g_cost is None:
if cost in known_grads:
g_cost = known_grads[cost]
else:
g_cost = _float_ones_like(cost)
# g_cost may be Disconnected or NullType. A creative use of the function,
# sure, but nonetheless one we can and should support. So before we try
......@@ -459,32 +459,27 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
assert g_cost not in tensor.discrete_dtypes
grad_dict[cost] = g_cost
else:
if g_cost is not None:
raise ValueError("No cost node was specified, but a gradient"
" on it was.")
if known_grads is not None:
for var in known_grads:
g_var = known_grads[var]
for var in known_grads:
g_var = known_grads[var]
if not hasattr(g_var, 'type'):
raise TypeError('output grads must be theano variables.'
'Ambiguous whether %s should be made into tensor'
' or sparse theano variable' % str(type(g_var)))
if not hasattr(g_var, 'type'):
raise TypeError('output grads must be theano variables.'
'Ambiguous whether %s should be made into tensor'
' or sparse theano variable' % str(type(g_var)))
if g_var.type not in [NullType, DisconnectedType] and 'float' \
not in str(g_var.type.dtype):
raise TypeError("Gradients must always be NullType, "
"DisconnectedType, or continuous, but grad was "
"given a known_grad of type "+str(g_var.type))
if g_var.type not in [NullType, DisconnectedType] and 'float' \
not in str(g_var.type.dtype):
raise TypeError("Gradients must always be NullType, "
"DisconnectedType, or continuous, but grad was "
"given a known_grad of type "+str(g_var.type))
# DO NOT check that these gradients are equal to 0 if var is int
# The gradient is allowed to be non-zero on var in that case
# Ops outputting var should not backpropagate its gradient further
# but that is enforced elsewhere (grep for only_connected_to_int)
# DO NOT check that these gradients are equal to 0 if var is int
# The gradient is allowed to be non-zero on var in that case
# Ops outputting var should not backpropagate its gradient further
# but that is enforced elsewhere (grep for only_connected_to_int)
grad_dict[var] = g_var
grad_dict[var] = g_var
......@@ -737,14 +732,12 @@ def _populate_grad_dict(var_to_node_to_idx,
this variable to the variable's index in the apply
node's input list
grad_dict: a dictionary mapping variables to their gradients
should be populated by grad function.
grad should set gradients to DisconnectedType()() for
variables to be considered constant, set the
gradient for the cost variable to g_cost, etc.
both should set the gradient for disconnected
grad_dict: A dictionary mapping variables to their gradients.
Should be populated by grad function, which should:
-Set the gradient with respect to the cost to 1
-Load all gradients from known_grads, possibly overriding
the cost
-Set the gradient for disconnected
inputs to a variable with type DisconnectedType()
wrt: the minimal set of variables that must be included in grad_dict
......
......@@ -1046,7 +1046,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
# Verify the gradient when providing output gradient
h = theano.function([x, y, a],
T.grad(expr, x, g_cost=a * x.sum()), mode=mode)
T.grad(expr, x, known_grads={expr:a * x.sum()}), mode=mode)
try:
assert 8 <= len(h.maker.fgraph.toposort()) <= 17
validate_grad_graph(h)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论