提交 572c7878 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

added function that reformats the output

Following olivier's suggestion, I refractored the code that retransforms the output in either a list or tuple
上级 a018cdea
...@@ -23,6 +23,23 @@ from theano import compile ...@@ -23,6 +23,23 @@ from theano import compile
_logger = logging.getLogger('theano.tensor.tensor_grad') _logger = logging.getLogger('theano.tensor.tensor_grad')
def format_as(use_list, use_tuple, outputs):
if (use_list or use_tuple) and not isinstance(outputs, (list, tuple)):
if use_list:
return [outputs]
else:
return (outputs,)
elif not (use_list or use_tuple) and isinstance(outputs, (list, tuple)):
return outputs[0]
elif use_list or use_tuple:
if use_list:
return list(outputs)
else:
return tuple(outputs)
else:
return outputs
######################## ########################
# R Operator # R Operator
######################## ########################
...@@ -138,16 +155,7 @@ def Rop(f, wrt, eval_points): ...@@ -138,16 +155,7 @@ def Rop(f, wrt, eval_points):
else: else:
rval.append(seen_nodes[out.owner][out.owner.outputs.index(out)]) rval.append(seen_nodes[out.owner][out.owner.outputs.index(out)])
if len(rval) == 1: return format_as(using_list, using_tuple, rval)
if using_list:
return rval
if using_tuple:
return tuple(rval)
return rval[0]
else:
if using_tuple:
return tuple(rval)
return rval
def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False, def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
...@@ -227,17 +235,7 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False, ...@@ -227,17 +235,7 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
"'ignore', 'warn' and 'raise'.") "'ignore', 'warn' and 'raise'.")
ret.append(zeros_like(p)) ret.append(zeros_like(p))
if len(ret) == 1: return format_as(using_list, using_tuple, ret)
if using_list:
return ret
elif using_tuple:
return tuple(ret)
else:
return ret[0]
else:
if using_tuple:
return tuple(ret)
return ret
######################### #########################
...@@ -344,16 +342,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False, ...@@ -344,16 +342,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
"'ignore', 'warn' and 'raise'.") "'ignore', 'warn' and 'raise'.")
ret.append(zeros_like(p)) ret.append(zeros_like(p))
if len(ret) == 1 and not (using_list or using_tuple): return format_as(using_list, using_tuple, ret)
# `wrt` was a single Variable, so we return a single Variable too.
return ret[0]
else:
# Ensure we preserve the original type of `wrt`.
if using_tuple:
return tuple(ret)
else:
assert using_list
return ret
class numeric_grad(object): class numeric_grad(object):
...@@ -771,12 +760,7 @@ def jacobian(expression, wrt, consider_constant=None, warn_type=False, ...@@ -771,12 +760,7 @@ def jacobian(expression, wrt, consider_constant=None, warn_type=False,
jacobs, _ = theano.scan(inner_function, jacobs, _ = theano.scan(inner_function,
sequences=arange(expression.shape[0]), sequences=arange(expression.shape[0]),
non_sequences=[expression] + wrt) non_sequences=[expression] + wrt)
if use_list and not isinstance(jacobs, (list, tuple)): return format_as(using_list, using_tuple, jacobs)
return [jacobs]
elif not use_list and isinstance(jacobs, (list, tuple)):
return jacobs[0]
else:
return jacobs
def hessian(cost, wrt, consider_constant=None, warn_type=False, def hessian(cost, wrt, consider_constant=None, warn_type=False,
...@@ -811,6 +795,9 @@ def hessian(cost, wrt, consider_constant=None, warn_type=False, ...@@ -811,6 +795,9 @@ def hessian(cost, wrt, consider_constant=None, warn_type=False,
assert cost.ndim == 0, \ assert cost.ndim == 0, \
"tensor.hessian expects a 0 dimensional variable as `cost`" "tensor.hessian expects a 0 dimensional variable as `cost`"
using_list = isinstance(wrt, list)
using_tuple = isinstance(wrt, tuple)
if isinstance(wrt, (list, tuple)): if isinstance(wrt, (list, tuple)):
use_list = True use_list = True
wrt = list(wrt) wrt = list(wrt)
...@@ -835,9 +822,4 @@ def hessian(cost, wrt, consider_constant=None, warn_type=False, ...@@ -835,9 +822,4 @@ def hessian(cost, wrt, consider_constant=None, warn_type=False,
sequences=arange(expr.shape[0]), sequences=arange(expr.shape[0]),
non_sequences=[expr, input]) non_sequences=[expr, input])
hessians.append(hess) hessians.append(hess)
if use_list and not isinstance(hessians, (list, tuple)): return format_as(using_list, using_tuple, hessians)
return [hessians]
elif not use_list and isinstance(hessians, (list, tuple)):
return hessians[0]
else:
return hessians
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论