提交 c53a8e84 作者: Ian Goodfellow

Fixed what looks like a bug in Lop, where it looks like it was never handling consider_constant correctly.

Made Lop work nicely with the new grad_sources_inputs so that it does not compute unnecessary values.
上级 5a04235b
...@@ -288,10 +288,23 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False, ...@@ -288,10 +288,23 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
if not isinstance(f, (list, tuple)): if not isinstance(f, (list, tuple)):
f = [f] f = [f]
inputs = gof.graph.inputs(f) f = [ elem for elem in f ]
grads = [ elem for elem in eval_points ]
for elem in consider_constant:
assert elem not in f
f.append(elem)
grads.append(elem.zeros_like())
if not isinstance(wrt, (list, tuple)):
wrt = [wrt]
arg1 = zip(f, eval_points)
arg2 = list(wrt)
gmap = grad_sources_inputs( gmap = grad_sources_inputs(
zip(f, eval_points), arg1,
list(inputs) + list(consider_constant), arg2,
warn_type=warn_type) warn_type=warn_type)
# Note : If p is not in gmap there can be several reasons, among which # Note : If p is not in gmap there can be several reasons, among which
...@@ -301,8 +314,6 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False, ...@@ -301,8 +314,6 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
# such subtle cases can be fixed by a more careful implementation of the # such subtle cases can be fixed by a more careful implementation of the
# gradient, but for now Theano needs to throw an exception, and make the # gradient, but for now Theano needs to throw an exception, and make the
# user aware that it does not know how to compute that gradient # user aware that it does not know how to compute that gradient
if not isinstance(wrt, (list, tuple)):
wrt = [wrt]
ret = [] ret = []
for p in wrt: for p in wrt:
if p in gmap: if p in gmap:
...@@ -400,6 +411,7 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = False, ...@@ -400,6 +411,7 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = False,
if not using_list and not using_tuple: if not using_list and not using_tuple:
wrt = [ wrt ] wrt = [ wrt ]
var_to_node_to_idx = _populate_var_to_node_to_idx([cost]) var_to_node_to_idx = _populate_var_to_node_to_idx([cost])
#build a dict mapping var to the gradient of cost with respect to var #build a dict mapping var to the gradient of cost with respect to var
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论