提交 e8a66108 作者: Ian Goodfellow

simplified Lop's handling of consider_constant

added some input validation to Lop
上级 6c07f25b
...@@ -318,9 +318,6 @@ def Lop(f, wrt, eval_points, consider_constant=None, ...@@ -318,9 +318,6 @@ def Lop(f, wrt, eval_points, consider_constant=None,
coordinates of the tensor element in the last coordinates of the tensor element in the last
If `f` is a list/tuple, then return a list/tuple with the results. If `f` is a list/tuple, then return a list/tuple with the results.
""" """
if consider_constant is None:
consider_constant = []
if type(eval_points) not in (list, tuple): if type(eval_points) not in (list, tuple):
eval_points = [eval_points] eval_points = [eval_points]
...@@ -334,48 +331,15 @@ def Lop(f, wrt, eval_points, consider_constant=None, ...@@ -334,48 +331,15 @@ def Lop(f, wrt, eval_points, consider_constant=None,
f = list(f) f = list(f)
grads = list(eval_points) grads = list(eval_points)
for elem in consider_constant:
assert elem not in f
f.append(elem)
grads.append(elem.zeros_like())
if not isinstance(wrt, (list, tuple)): if not isinstance(wrt, (list, tuple)):
wrt = [wrt] wrt = [wrt]
known = dict(zip(f, eval_points)) assert len(f) == len(grads)
known = dict(zip(f, grads))
gmap = dict(zip(wrt, grad(cost=None, known_grads=known,
consider_constant=wrt, wrt=wrt))) ret = grad(cost=None, known_grads=known,
consider_constant=consider_constant, wrt=wrt,
# Note : If p is not in gmap there can be several reasons, among which disconnected_inputs=disconnected_inputs)
# is the fact that p might not be part of the computational graph. A
# simple example is that for a+b for e.g. a[0] is not part of the graph,
# so Theano does not know how to compute TT.grad(TT.sum(a+b), a[0])
# such subtle cases can be fixed by a more careful implementation of the
# gradient, but for now Theano needs to throw an exception, and make the
# user aware that it does not know how to compute that gradient
ret = []
for p in wrt:
if p in gmap:
ret.append(gmap[p])
else:
message = (
"Lop method was asked to compute the gradient "
"with respect to a variable that is not part of "
"the computational graph of the cost, or is used "
"only by a non-differentiable operator: %s" % p)
if disconnected_inputs == 'ignore':
pass
elif disconnected_inputs == 'warn':
warnings.warn(message, stacklevel=1)
elif disconnected_inputs == 'raise':
raise ValueError(message)
else:
raise ValueError(
"Invalid value for keyword "
"'disconnected_inputs', valid values are "
"'ignore', 'warn' and 'raise'.")
ret.append(p.zeros_like())
return format_as(using_list, using_tuple, ret) return format_as(using_list, using_tuple, ret)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论