提交 572c7878 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

added function that reformats the output

Following olivier's suggestion, I refractored the code that retransforms the output in either a list or tuple
上级 a018cdea
......@@ -23,6 +23,23 @@ from theano import compile
_logger = logging.getLogger('theano.tensor.tensor_grad')
def format_as(use_list, use_tuple, outputs):
if (use_list or use_tuple) and not isinstance(outputs, (list, tuple)):
if use_list:
return [outputs]
else:
return (outputs,)
elif not (use_list or use_tuple) and isinstance(outputs, (list, tuple)):
return outputs[0]
elif use_list or use_tuple:
if use_list:
return list(outputs)
else:
return tuple(outputs)
else:
return outputs
########################
# R Operator
########################
......@@ -138,16 +155,7 @@ def Rop(f, wrt, eval_points):
else:
rval.append(seen_nodes[out.owner][out.owner.outputs.index(out)])
if len(rval) == 1:
if using_list:
return rval
if using_tuple:
return tuple(rval)
return rval[0]
else:
if using_tuple:
return tuple(rval)
return rval
return format_as(using_list, using_tuple, rval)
def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
......@@ -227,17 +235,7 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
"'ignore', 'warn' and 'raise'.")
ret.append(zeros_like(p))
if len(ret) == 1:
if using_list:
return ret
elif using_tuple:
return tuple(ret)
else:
return ret[0]
else:
if using_tuple:
return tuple(ret)
return ret
return format_as(using_list, using_tuple, ret)
#########################
......@@ -344,16 +342,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
"'ignore', 'warn' and 'raise'.")
ret.append(zeros_like(p))
if len(ret) == 1 and not (using_list or using_tuple):
# `wrt` was a single Variable, so we return a single Variable too.
return ret[0]
else:
# Ensure we preserve the original type of `wrt`.
if using_tuple:
return tuple(ret)
else:
assert using_list
return ret
return format_as(using_list, using_tuple, ret)
class numeric_grad(object):
......@@ -771,12 +760,7 @@ def jacobian(expression, wrt, consider_constant=None, warn_type=False,
jacobs, _ = theano.scan(inner_function,
sequences=arange(expression.shape[0]),
non_sequences=[expression] + wrt)
if use_list and not isinstance(jacobs, (list, tuple)):
return [jacobs]
elif not use_list and isinstance(jacobs, (list, tuple)):
return jacobs[0]
else:
return jacobs
return format_as(using_list, using_tuple, jacobs)
def hessian(cost, wrt, consider_constant=None, warn_type=False,
......@@ -811,6 +795,9 @@ def hessian(cost, wrt, consider_constant=None, warn_type=False,
assert cost.ndim == 0, \
"tensor.hessian expects a 0 dimensional variable as `cost`"
using_list = isinstance(wrt, list)
using_tuple = isinstance(wrt, tuple)
if isinstance(wrt, (list, tuple)):
use_list = True
wrt = list(wrt)
......@@ -835,9 +822,4 @@ def hessian(cost, wrt, consider_constant=None, warn_type=False,
sequences=arange(expr.shape[0]),
non_sequences=[expr, input])
hessians.append(hess)
if use_list and not isinstance(hessians, (list, tuple)):
return [hessians]
elif not use_list and isinstance(hessians, (list, tuple)):
return hessians[0]
else:
return hessians
return format_as(using_list, using_tuple, hessians)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论