提交 3b546b82 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

made adding names optional

上级 0ec49804
...@@ -328,7 +328,7 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False, ...@@ -328,7 +328,7 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
######################### #########################
def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = False, def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = False,
disconnected_inputs = 'raise'): disconnected_inputs = 'raise', add_names = True):
""" """
:type cost: Scalar (0-dimensional) Variable. :type cost: Scalar (0-dimensional) Variable.
:type wrt: Variable or list of Variables. :type wrt: Variable or list of Variables.
...@@ -349,6 +349,11 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = False, ...@@ -349,6 +349,11 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = False,
- 'warn': consider the gradient zero, and print a warning. - 'warn': consider the gradient zero, and print a warning.
- 'raise': raise an exception. - 'raise': raise an exception.
:type add_names: bool
:param add_names: If True, variables generated by grad will be named
(d<cost.name>/d<wrt.name>) provided that both cost and wrt have
names
:rtype: Variable or list/tuple of Variables (depending upon `wrt`) :rtype: Variable or list/tuple of Variables (depending upon `wrt`)
:return: symbolic expression of gradient of `cost` with respect to `wrt`. :return: symbolic expression of gradient of `cost` with respect to `wrt`.
...@@ -426,9 +431,13 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = False, ...@@ -426,9 +431,13 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = False,
"'ignore', 'warn' and 'raise'.") "'ignore', 'warn' and 'raise'.")
grad_dict[elem] = DisconnectedType()() grad_dict[elem] = DisconnectedType()()
cost_name = None
if add_names:
cost_name = cost.name
rval = _populate_grad_dict(var_to_node_to_idx, rval = _populate_grad_dict(var_to_node_to_idx,
grad_dict, wrt, warn_type, grad_dict, wrt, warn_type,
cost.name) cost_name)
for i in xrange(len(rval)): for i in xrange(len(rval)):
if isinstance(rval[i].type, DisconnectedType): if isinstance(rval[i].type, DisconnectedType):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论