Commit a931005f authored by Ian Goodfellow

remove g_cost

Parent 40bbb7da
@@ -349,16 +349,13 @@ def Lop(f, wrt, eval_points, consider_constant=None,
 # Gradient
 #########################
-def grad(cost, wrt, g_cost=None, consider_constant=None,
+def grad(cost, wrt, consider_constant=None,
          disconnected_inputs='raise', add_names=True,
          known_grads=None, return_disconnected='zero'):
     """
     :type cost: Scalar (0-dimensional) Variable.
         May optionally be None if known_grads is provided.
     :type wrt: Variable or list of Variables.
-    :type g_cost: Scalar Variable, or None.
-    :param g_cost: an expression for the gradient through cost. The default is
-        ``ones_like(cost)``.
     :param consider_constant: a list of expressions not to backpropagate
         through
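
For callers of the old signature, the replacement for the removed argument is known_grads. A minimal migration sketch, not part of the commit (assumes the usual theano imports; names are illustrative):

    import theano
    import theano.tensor as T

    x = T.dvector('x')
    cost = (x ** 2).sum()
    seed = T.constant(0.5)  # custom gradient through the cost

    # Before this commit: gx = theano.grad(cost, x, g_cost=seed)
    # After it, the seed is routed through known_grads instead:
    gx = theano.grad(cost, x, known_grads={cost: seed})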
@@ -441,11 +438,11 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
     # build a dict mapping var to the gradient of cost with respect to var
     grad_dict = {}
-    # The gradient of the cost should default to 1 if the cost is of a
-    # continuous dtype (float, for the moment, as complex are unsupported),
-    # and should always be 0 if the cost is of discrete (integer) dtype.
+    # The gradient of the cost is 1 unless specified otherwise by known_grads.
     if cost is not None:
-        if g_cost is None:
+        if cost in known_grads:
+            g_cost = known_grads[cost]
+        else:
             g_cost = _float_ones_like(cost)
     # g_cost may be Disconnected or NullType. A creative use of the function,
     # sure, but nonetheless one we can and should support. So before we try
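
At the call site, the new branch means the cost is seeded with ones unless known_grads carries an entry for the cost itself. A hedged demonstration of both paths (illustrative only; assumes float64 variables):

    import numpy as np
    import theano
    import theano.tensor as T

    x = T.dvector('x')
    cost = (x ** 2).sum()

    gx_default = theano.grad(cost, x)  # seed 1, so gx == 2 * x
    gx_scaled = theano.grad(cost, x,
                            known_grads={cost: T.constant(2.0)})  # 4 * x

    f = theano.function([x], [gx_default, gx_scaled])
    # f(np.array([1., 3.])) should give roughly [2., 6.] and [4., 12.]
    print(f(np.array([1.0, 3.0])))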
@@ -459,10 +456,6 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
         assert g_cost not in tensor.discrete_dtypes
         grad_dict[cost] = g_cost
-    else:
-        if g_cost is not None:
-            raise ValueError("No cost node was specified, but a gradient"
-                             " on it was.")
 
     if known_grads is not None:
         for var in known_grads:
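
The deleted else branch guarded against passing g_cost without a cost; with the argument gone, the supported way to backpropagate without a scalar cost is cost=None plus known_grads. A hedged sketch of that pattern (names illustrative, not from the commit):

    import theano
    import theano.tensor as T

    x = T.dvector('x')
    y = T.tanh(x)         # intermediate variable; no scalar cost exists
    gy = T.dvector('gy')  # gradient on y supplied from outside

    # Start backprop from the externally provided gradient on y.
    gx, = theano.grad(cost=None, wrt=[x], known_grads={y: gy})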
@@ -737,14 +730,12 @@ def _populate_grad_dict(var_to_node_to_idx,
             this variable to the variable's index in the apply
             node's input list
-        grad_dict: a dictionary mapping variables to their gradients
-                   should be populated by grad function.
-
-                   grad should set gradients to DisconnectedType()() for
-                   variables to be considered constant, set the
-                   gradient for the cost variable to g_cost, etc.
-                   both should set the gradient for disconnected
-                   inputs to a variable with type DisconnectedType()
+        grad_dict: A dictionary mapping variables to their gradients.
+                   Should be populated by grad function, which should:
+                       -Set the gradient with respect to the cost to 1
+                       -Load all gradients from known_grads, possibly overriding
+                        the cost
+                       -Set the gradient for disconnected
+                        inputs to a variable with type DisconnectedType()
         wrt: the minimal set of variables that must be included in grad_dict
...
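
The rewritten contract also names disconnected inputs; at the public API this corresponds to the disconnected_inputs and return_disconnected arguments shown in the signature above. A hedged example of that behavior (assumes the defaults in this commit):

    import theano
    import theano.tensor as T

    x = T.dvector('x')
    z = T.dvector('z')  # never used by cost, hence disconnected
    cost = (x ** 2).sum()

    # 'ignore' silences the disconnected-input complaint; with the default
    # return_disconnected='zero', gz comes back as a zero tensor.
    gx, gz = theano.grad(cost, [x, z], disconnected_inputs='ignore')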