提交 a931005f authored 作者: Ian Goodfellow's avatar Ian Goodfellow

remove g_cost

上级 40bbb7da
......@@ -349,16 +349,13 @@ def Lop(f, wrt, eval_points, consider_constant=None,
# Gradient
#########################
def grad(cost, wrt, g_cost=None, consider_constant=None,
def grad(cost, wrt, consider_constant=None,
disconnected_inputs='raise', add_names=True,
known_grads=None, return_disconnected='zero'):
"""
:type cost: Scalar (0-dimensional) Variable.
May optionally be None if known_grads is provided.
:type wrt: Variable or list of Variables.
:type g_cost: Scalar Variable, or None.
:param g_cost: an expression for the gradient through cost. The default is
``ones_like(cost)``.
:param consider_constant: a list of expressions not to backpropagate
through
......@@ -441,11 +438,11 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
# build a dict mapping var to the gradient of cost with respect to var
grad_dict = {}
# The gradient of the cost should default to 1 if the cost is of a
# continuous dtype (float, for the moment, as complex are unsupported),
# and should always be 0 if the cost is of discrete (integer) dtype.
# The gradient of the cost is 1 unless specified otherwise by known_grads.
if cost is not None:
if g_cost is None:
if cost in known_grads:
g_cost = known_grads[cost]
else:
g_cost = _float_ones_like(cost)
# g_cost may be Disconnected or NullType. A creative use of the function,
# sure, but nonetheless one we can and should support. So before we try
......@@ -459,10 +456,6 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
assert g_cost not in tensor.discrete_dtypes
grad_dict[cost] = g_cost
else:
if g_cost is not None:
raise ValueError("No cost node was specified, but a gradient"
" on it was.")
if known_grads is not None:
for var in known_grads:
......@@ -737,14 +730,12 @@ def _populate_grad_dict(var_to_node_to_idx,
this variable to the variable's index in the apply
node's input list
grad_dict: a dictionary mapping variables to their gradients
should be populated by grad function.
grad should set gradients to DisconnectedType()() for
variables to be considered constant, set the
gradient for the cost variable to g_cost, etc.
both should set the gradient for disconnected
grad_dict: A dictionary mapping variables to their gradients.
Should be populated by grad function, which should:
-Set the gradient with respect to the cost to 1
-Load all gradients from known_grads, possibly overriding
the cost
-Set the gradient for disconnected
inputs to a variable with type DisconnectedType()
wrt: the minimal set of variables that must be included in grad_dict
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论