提交 1994dbbc 作者: Frederic Bastien

Add return_list=True to scan() function

上级 0f5a06d7
...@@ -81,7 +81,8 @@ def scan(fn, ...@@ -81,7 +81,8 @@ def scan(fn,
name=None, name=None,
profile=False, profile=False,
allow_gc=None, allow_gc=None,
strict=False): strict=False,
return_list=False):
""" """
This function constructs and applies a Scan op to the provided This function constructs and applies a Scan op to the provided
arguments. arguments.
...@@ -333,6 +334,9 @@ def scan(fn, ...@@ -333,6 +334,9 @@ def scan(fn,
If true, all the shared variables used in ``fn`` must be provided as a If true, all the shared variables used in ``fn`` must be provided as a
part of ``non_sequences`` or ``sequences``. part of ``non_sequences`` or ``sequences``.
return_list
If True, will always return a list, even if there is only 1 output.
Returns Returns
------- -------
tuple tuple
...@@ -794,7 +798,8 @@ def scan(fn, ...@@ -794,7 +798,8 @@ def scan(fn,
return_steps.get(pos, 0) != 1): return_steps.get(pos, 0) != 1):
outputs[pos] = tensor.unbroadcast( outputs[pos] = tensor.unbroadcast(
tensor.shape_padleft(inner_out), 0) tensor.shape_padleft(inner_out), 0)
if len(outputs) == 1:
if return_list is not True and len(outputs) == 1:
outputs = outputs[0] outputs = outputs[0]
return (outputs, updates) return (outputs, updates)
...@@ -1134,8 +1139,9 @@ def scan(fn, ...@@ -1134,8 +1139,9 @@ def scan(fn,
# refers to update rule of index -1 - `pos`. # refers to update rule of index -1 - `pos`.
update_map[sit_sot_shared[abs(pos) - 1]] = _scan_out_list[idx][-1] update_map[sit_sot_shared[abs(pos) - 1]] = _scan_out_list[idx][-1]
scan_out_list = [x for x in scan_out_list if x is not None] scan_out_list = [x for x in scan_out_list if x is not None]
if len(scan_out_list) == 1: if return_list is not True and len(scan_out_list) == 1:
scan_out_list = scan_out_list[0] scan_out_list = scan_out_list[0]
elif len(scan_out_list) == 0: elif len(scan_out_list) == 0:
scan_out_list = None scan_out_list = None
return (scan_out_list, update_map) return (scan_out_list, update_map)
...@@ -318,12 +318,14 @@ class T_Scan(unittest.TestCase): ...@@ -318,12 +318,14 @@ class T_Scan(unittest.TestCase):
state = theano.tensor.scalar('state') state = theano.tensor.scalar('state')
n_steps = theano.tensor.iscalar('nsteps') n_steps = theano.tensor.iscalar('nsteps')
# Test return_list at the same time.
output, updates = theano.scan(f_pow2, output, updates = theano.scan(f_pow2,
[], [],
state, state,
[], [],
n_steps=n_steps, n_steps=n_steps,
truncate_gradient=-1, truncate_gradient=-1,
return_list=True,
go_backwards=False) go_backwards=False)
my_f = theano.function([state, n_steps], my_f = theano.function([state, n_steps],
output, output,
...@@ -337,7 +339,7 @@ class T_Scan(unittest.TestCase): ...@@ -337,7 +339,7 @@ class T_Scan(unittest.TestCase):
numpy_values = numpy.array([state * (2 ** (k + 1)) for k numpy_values = numpy.array([state * (2 ** (k + 1)) for k
in xrange(steps)]) in xrange(steps)])
theano_values = my_f(state, steps) theano_values = my_f(state, steps)
utt.assert_allclose(numpy_values, theano_values) utt.assert_allclose(numpy_values, theano_values[0])
def test_subtensor_multiple_slices(self): def test_subtensor_multiple_slices(self):
# This addresses a bug reported by Matthias Zoehrer # This addresses a bug reported by Matthias Zoehrer
...@@ -4416,16 +4418,17 @@ class T_Scan(unittest.TestCase): ...@@ -4416,16 +4418,17 @@ class T_Scan(unittest.TestCase):
n_steps=1, n_steps=1,
) )
return sum_outer + result_inner[-1] return sum_outer + result_inner[-1]
# Also test return_list for that case.
result_outer, _ = theano.scan( result_outer, _ = theano.scan(
fn=loss_outer, fn=loss_outer,
outputs_info=tensor.as_tensor_variable( outputs_info=tensor.as_tensor_variable(
numpy.asarray(0, dtype=numpy.float32)), numpy.asarray(0, dtype=numpy.float32)),
non_sequences=[W], non_sequences=[W],
n_steps=n_steps, n_steps=n_steps,
return_list=True,
) )
cost = result_outer[-1] cost = result_outer[0][-1]
H = theano.gradient.hessian(cost, W) H = theano.gradient.hessian(cost, W)
print(".", file=sys.stderr) print(".", file=sys.stderr)
f = theano.function([W, n_steps], H) f = theano.function([W, n_steps], H)
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论