Commit 76e6de02 authored by Ian Goodfellow

simplify handling of known_grads

Parent a931005f
@@ -438,6 +438,9 @@ def grad(cost, wrt, consider_constant=None,
 
     # build a dict mapping var to the gradient of cost with respect to var
     grad_dict = {}
+    if known_grads is None:
+        known_grads = {}
+
     # The gradient of the cost is 1 unless specified otherwise by known_grads.
     if cost is not None:
         if cost in known_grads:
@@ -457,27 +460,26 @@ def grad(cost, wrt, consider_constant=None,
         grad_dict[cost] = g_cost
 
-    if known_grads is not None:
-        for var in known_grads:
-            g_var = known_grads[var]
-
-            if not hasattr(g_var, 'type'):
-                raise TypeError('output grads must be theano variables.'
-                                'Ambiguous whether %s should be made into tensor'
-                                ' or sparse theano variable' % str(type(g_var)))
-
-            if g_var.type not in [NullType, DisconnectedType] and 'float' \
-                    not in str(g_var.type.dtype):
-                raise TypeError("Gradients must always be NullType, "
-                                "DisconnectedType, or continuous, but grad was "
-                                "given a known_grad of type " + str(g_var.type))
-
-            # DO NOT check that these gradients are equal to 0 if var is int
-            # The gradient is allowed to be non-zero on var in that case
-            # Ops outputting var should not backpropagate its gradient further
-            # but that is enforced elsewhere (grep for only_connected_to_int)
-
-            grad_dict[var] = g_var
+    for var in known_grads:
+        g_var = known_grads[var]
+
+        if not hasattr(g_var, 'type'):
+            raise TypeError('output grads must be theano variables.'
+                            'Ambiguous whether %s should be made into tensor'
+                            ' or sparse theano variable' % str(type(g_var)))
+
+        if g_var.type not in [NullType, DisconnectedType] and 'float' \
+                not in str(g_var.type.dtype):
+            raise TypeError("Gradients must always be NullType, "
+                            "DisconnectedType, or continuous, but grad was "
+                            "given a known_grad of type " + str(g_var.type))
+
+        # DO NOT check that these gradients are equal to 0 if var is int
+        # The gradient is allowed to be non-zero on var in that case
+        # Ops outputting var should not backpropagate its gradient further
+        # but that is enforced elsewhere (grep for only_connected_to_int)
+
+        grad_dict[var] = g_var
...
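
For context, a minimal usage sketch of the known_grads mechanism this change touches. It is not part of the commit; the variables x, y, and g_y are illustrative, and it assumes a Theano version where grad() accepts cost=None together with known_grads (as the "if cost is not None:" branch above suggests).

import theano
import theano.tensor as T

x = T.vector('x')
y = x ** 2

# Pretend the gradient of some downstream cost with respect to y is
# already known (all ones here, as if that cost were y.sum()).
g_y = T.ones_like(y)

# grad() seeds backpropagation from the entries of known_grads. After
# this commit, known_grads=None is normalized to {} at the top of
# grad(), so the loop over known_grads needs no existence check.
g_x = theano.grad(cost=None, wrt=x, known_grads={y: g_y})

f = theano.function([x], g_x)
print(f([1.0, 2.0, 3.0]))  # d(sum(x**2))/dx = 2*x -> [2., 4., 6.]

Normalizing the optional dict argument to {} once at the top of the function is what lets the later "if known_grads is not None:" guard disappear: iterating over an empty dict is simply a no-op.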