提交 6ab9b93f authored 作者: Razvan Pascanu's avatar Razvan Pascanu 提交者: Ian Goodfellow

Rop has as many outputs as f and Lop as wrt which needs to be reflected in the list behaviour

上级 2cdfc327
...@@ -50,16 +50,18 @@ def Rop(f, wrt, eval_points): ...@@ -50,16 +50,18 @@ def Rop(f, wrt, eval_points):
If `wrt` is a list/tuple, then return a list/tuple with the results. If `wrt` is a list/tuple, then return a list/tuple with the results.
""" """
using_list = isinstance(wrt, list) using_list = isinstance(f, list)
using_tuple = isinstance(wrt, tuple) using_tuple = isinstance(f, tuple)
if not (using_list or using_tuple):
if not isinstance(wrt, (list,tuple)):
wrt = [ wrt ] wrt = [ wrt ]
if not isinstance(eval_points, (list, tuple)): if not isinstance(eval_points, (list, tuple)):
eval_points = [ eval_points ] eval_points = [ eval_points ]
if not isinstance(f, (list,tuple)):
if not (using_list or using_tuple):
f = [f] f = [f]
assert len(wrt) == len(eval_points) assert len(wrt) == len(eval_points)
...@@ -175,9 +177,10 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False, ...@@ -175,9 +177,10 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
if type(eval_points) not in (list, tuple): if type(eval_points) not in (list, tuple):
eval_points = [eval_points] eval_points = [eval_points]
using_list = isinstance(f, list) using_list = isinstance(wrt, list)
using_tuple = isinstance(f, tuple) using_tuple = isinstance(wrt, tuple)
if not (using_list or using_tuple):
if not isinstance(f, (list, tuple)):
f = [f] f = [f]
inputs = gof.graph.inputs(f) inputs = gof.graph.inputs(f)
...@@ -193,7 +196,8 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False, ...@@ -193,7 +196,8 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
# such subtle cases can be fixed by a more careful implementation of the # such subtle cases can be fixed by a more careful implementation of the
# gradient, but for now Theano needs to throw an exception, and make the # gradient, but for now Theano needs to throw an exception, and make the
# user aware that it does not know how to compute that gradient # user aware that it does not know how to compute that gradient
if not isinstance(wrt, (list, tuple)):
if not (using_list or using_tuple):
wrt = [wrt] wrt = [wrt]
ret = [] ret = []
for p in wrt: for p in wrt:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论