提交 3b38e2c9 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

changed tensor.grad to have list-in, list-out behavior

上级 ce20b8ba
...@@ -268,7 +268,9 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False, ...@@ -268,7 +268,9 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
# such subtle cases can be fixed by a more careful implementation of the # such subtle cases can be fixed by a more careful implementation of the
# gradient, but for now Theano needs to throw an exception, and make the # gradient, but for now Theano needs to throw an exception, and make the
# user aware that it does not know how to compute that gradient # user aware that it does not know how to compute that gradient
if not isinstance(wrt, (list, tuple)): using_list = isinstance(wrt, list)
using_tuple = isinstance(list, tuple)
if not (using_list or using_tuple):
wrt = [wrt] wrt = [wrt]
ret = [] ret = []
for p in wrt: for p in wrt:
...@@ -292,8 +294,15 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False, ...@@ -292,8 +294,15 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
ret.append(zeros_like(p)) ret.append(zeros_like(p))
if len(ret) == 1: if len(ret) == 1:
return ret[0] if using_list:
return ret
elif using_tuple:
return tuple(ret)
else:
return ret[0]
else: else:
if using_tuple:
return tuple(ret)
return ret return ret
class numeric_grad: class numeric_grad:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论