提交 e345e095 作者: Frédéric Bastien 提交者: GitHub

Merge pull request #5420 from nouiz/L_op

Use Scan.L_op instead of Scan.grad() to help speed up the second derivative.
...@@ -81,7 +81,8 @@ def scan(fn, ...@@ -81,7 +81,8 @@ def scan(fn,
name=None, name=None,
profile=False, profile=False,
allow_gc=None, allow_gc=None,
strict=False): strict=False,
return_list=False):
""" """
This function constructs and applies a Scan op to the provided This function constructs and applies a Scan op to the provided
arguments. arguments.
...@@ -333,6 +334,9 @@ def scan(fn, ...@@ -333,6 +334,9 @@ def scan(fn,
If true, all the shared variables used in ``fn`` must be provided as a If true, all the shared variables used in ``fn`` must be provided as a
part of ``non_sequences`` or ``sequences``. part of ``non_sequences`` or ``sequences``.
return_list
If True, will always return a list, even if there is only 1 output.
Returns Returns
------- -------
tuple tuple
...@@ -794,7 +798,8 @@ def scan(fn, ...@@ -794,7 +798,8 @@ def scan(fn,
return_steps.get(pos, 0) != 1): return_steps.get(pos, 0) != 1):
outputs[pos] = tensor.unbroadcast( outputs[pos] = tensor.unbroadcast(
tensor.shape_padleft(inner_out), 0) tensor.shape_padleft(inner_out), 0)
if len(outputs) == 1:
if return_list is not True and len(outputs) == 1:
outputs = outputs[0] outputs = outputs[0]
return (outputs, updates) return (outputs, updates)
...@@ -1134,8 +1139,9 @@ def scan(fn, ...@@ -1134,8 +1139,9 @@ def scan(fn,
# refers to update rule of index -1 - `pos`. # refers to update rule of index -1 - `pos`.
update_map[sit_sot_shared[abs(pos) - 1]] = _scan_out_list[idx][-1] update_map[sit_sot_shared[abs(pos) - 1]] = _scan_out_list[idx][-1]
scan_out_list = [x for x in scan_out_list if x is not None] scan_out_list = [x for x in scan_out_list if x is not None]
if len(scan_out_list) == 1: if return_list is not True and len(scan_out_list) == 1:
scan_out_list = scan_out_list[0] scan_out_list = scan_out_list[0]
elif len(scan_out_list) == 0: elif len(scan_out_list) == 0:
scan_out_list = None scan_out_list = None
return (scan_out_list, update_map) return (scan_out_list, update_map)
...@@ -1931,8 +1931,7 @@ class Scan(PureOp): ...@@ -1931,8 +1931,7 @@ class Scan(PureOp):
return mappings return mappings
# GRAD FUNCTION # GRAD FUNCTION
def grad(self, inputs, dC_douts): def L_op(self, inputs, outs, dC_douts):
outs = self(*inputs)
if not isinstance(outs, (list, tuple)): if not isinstance(outs, (list, tuple)):
outs = [outs] outs = [outs]
# `grad_step` equals the number of steps the original scan node has # `grad_step` equals the number of steps the original scan node has
......
...@@ -318,12 +318,14 @@ class T_Scan(unittest.TestCase): ...@@ -318,12 +318,14 @@ class T_Scan(unittest.TestCase):
state = theano.tensor.scalar('state') state = theano.tensor.scalar('state')
n_steps = theano.tensor.iscalar('nsteps') n_steps = theano.tensor.iscalar('nsteps')
# Test return_list at the same time.
output, updates = theano.scan(f_pow2, output, updates = theano.scan(f_pow2,
[], [],
state, state,
[], [],
n_steps=n_steps, n_steps=n_steps,
truncate_gradient=-1, truncate_gradient=-1,
return_list=True,
go_backwards=False) go_backwards=False)
my_f = theano.function([state, n_steps], my_f = theano.function([state, n_steps],
output, output,
...@@ -337,7 +339,7 @@ class T_Scan(unittest.TestCase): ...@@ -337,7 +339,7 @@ class T_Scan(unittest.TestCase):
numpy_values = numpy.array([state * (2 ** (k + 1)) for k numpy_values = numpy.array([state * (2 ** (k + 1)) for k
in xrange(steps)]) in xrange(steps)])
theano_values = my_f(state, steps) theano_values = my_f(state, steps)
utt.assert_allclose(numpy_values, theano_values) utt.assert_allclose(numpy_values, theano_values[0])
def test_subtensor_multiple_slices(self): def test_subtensor_multiple_slices(self):
# This addresses a bug reported by Matthias Zoehrer # This addresses a bug reported by Matthias Zoehrer
...@@ -4416,16 +4418,17 @@ class T_Scan(unittest.TestCase): ...@@ -4416,16 +4418,17 @@ class T_Scan(unittest.TestCase):
n_steps=1, n_steps=1,
) )
return sum_outer + result_inner[-1] return sum_outer + result_inner[-1]
# Also test return_list for that case.
result_outer, _ = theano.scan( result_outer, _ = theano.scan(
fn=loss_outer, fn=loss_outer,
outputs_info=tensor.as_tensor_variable( outputs_info=tensor.as_tensor_variable(
numpy.asarray(0, dtype=numpy.float32)), numpy.asarray(0, dtype=numpy.float32)),
non_sequences=[W], non_sequences=[W],
n_steps=n_steps, n_steps=n_steps,
return_list=True,
) )
cost = result_outer[-1] cost = result_outer[0][-1]
H = theano.gradient.hessian(cost, W) H = theano.gradient.hessian(cost, W)
print(".", file=sys.stderr) print(".", file=sys.stderr)
f = theano.function([W, n_steps], H) f = theano.function([W, n_steps], H)
......
...@@ -643,73 +643,70 @@ def test_scan_debugprint5(): ...@@ -643,73 +643,70 @@ def test_scan_debugprint5():
| |Subtensor{::int64} [id BL] '' | |Subtensor{::int64} [id BL] ''
| | |IncSubtensor{Inc;int64::} [id BM] '' | | |IncSubtensor{Inc;int64::} [id BM] ''
| | | |Elemwise{second,no_inplace} [id BN] '' | | | |Elemwise{second,no_inplace} [id BN] ''
| | | | |for{cpu,scan_fn} [id BO] '' | | | | |for{cpu,scan_fn} [id F] ''
| | | | | |k [id G] | | | | |InplaceDimShuffle{x,x} [id BO] ''
| | | | | |IncSubtensor{Set;:int64:} [id H] '' | | | | |TensorConstant{0.0} [id BP]
| | | | | |A [id P] | | | |IncSubtensor{Inc;int64} [id BQ] ''
| | | | |InplaceDimShuffle{x,x} [id BP] '' | | | | |Elemwise{second,no_inplace} [id BR] ''
| | | | |TensorConstant{0.0} [id BQ] | | | | | |Subtensor{int64::} [id BS] ''
| | | |IncSubtensor{Inc;int64} [id BR] '' | | | | | | |for{cpu,scan_fn} [id F] ''
| | | | |Elemwise{second,no_inplace} [id BS] '' | | | | | | |Constant{1} [id BT]
| | | | | |Subtensor{int64::} [id BT] '' | | | | | |InplaceDimShuffle{x,x} [id BU] ''
| | | | | | |for{cpu,scan_fn} [id BO] '' | | | | | |TensorConstant{0.0} [id BP]
| | | | | | |Constant{1} [id BU] | | | | |Elemwise{second} [id BV] ''
| | | | | |InplaceDimShuffle{x,x} [id BV] '' | | | | | |Subtensor{int64} [id BW] ''
| | | | | |TensorConstant{0.0} [id BQ] | | | | | | |Subtensor{int64::} [id BS] ''
| | | | |Elemwise{second} [id BW] '' | | | | | | |Constant{-1} [id BX]
| | | | | |Subtensor{int64} [id BX] '' | | | | | |InplaceDimShuffle{x} [id BY] ''
| | | | | | |Subtensor{int64::} [id BT] '' | | | | | |Elemwise{second,no_inplace} [id BZ] ''
| | | | | | |Constant{-1} [id BY] | | | | | |Sum{acc_dtype=float64} [id CA] ''
| | | | | |InplaceDimShuffle{x} [id BZ] '' | | | | | | |Subtensor{int64} [id BW] ''
| | | | | |Elemwise{second,no_inplace} [id CA] ''
| | | | | |Sum{acc_dtype=float64} [id CB] ''
| | | | | | |Subtensor{int64} [id BX] ''
| | | | | |TensorConstant{1.0} [id R] | | | | | |TensorConstant{1.0} [id R]
| | | | |Constant{-1} [id BY] | | | | |Constant{-1} [id BX]
| | | |Constant{1} [id BU] | | | |Constant{1} [id BT]
| | |Constant{-1} [id CC] | | |Constant{-1} [id CB]
| |Alloc [id CD] '' | |Alloc [id CC] ''
| | |TensorConstant{0.0} [id BQ] | | |TensorConstant{0.0} [id BP]
| | |Elemwise{add,no_inplace} [id CE] '' | | |Elemwise{add,no_inplace} [id CD] ''
| | | |Elemwise{sub,no_inplace} [id C] '' | | | |Elemwise{sub,no_inplace} [id C] ''
| | | |TensorConstant{1} [id Y] | | | |TensorConstant{1} [id Y]
| | |Subtensor{int64} [id CF] '' | | |Subtensor{int64} [id CE] ''
| | |Shape [id CG] '' | | |Shape [id CF] ''
| | | |A [id P] | | | |A [id P]
| | |Constant{0} [id CH] | | |Constant{0} [id CG]
| |A [id P] | |A [id P]
|Constant{-1} [id CI] |Constant{-1} [id CH]
Inner graphs of the scan ops: Inner graphs of the scan ops:
for{cpu,grad_of_scan_fn}.1 [id B] '' for{cpu,grad_of_scan_fn}.1 [id B] ''
>Elemwise{add,no_inplace} [id CJ] '' >Elemwise{add,no_inplace} [id CI] ''
> |Elemwise{mul} [id CK] '' > |Elemwise{mul} [id CJ] ''
> | |<TensorType(float64, vector)> [id CL] -> [id BL] > | |<TensorType(float64, vector)> [id CK] -> [id BL]
> | |A_copy [id CM] -> [id P] > | |A_copy [id CL] -> [id P]
> |<TensorType(float64, vector)> [id CN] -> [id BL] > |<TensorType(float64, vector)> [id CM] -> [id BL]
>Elemwise{add,no_inplace} [id CO] '' >Elemwise{add,no_inplace} [id CN] ''
> |Elemwise{mul} [id CP] '' > |Elemwise{mul} [id CO] ''
> | |<TensorType(float64, vector)> [id CL] -> [id BL] > | |<TensorType(float64, vector)> [id CK] -> [id BL]
> | |<TensorType(float64, vector)> [id CQ] -> [id Z] > | |<TensorType(float64, vector)> [id CP] -> [id Z]
> |<TensorType(float64, vector)> [id CR] -> [id CD] > |<TensorType(float64, vector)> [id CQ] -> [id CC]
for{cpu,scan_fn} [id F] '' for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id CS] '' >Elemwise{mul,no_inplace} [id CR] ''
> |<TensorType(float64, vector)> [id CT] -> [id H] > |<TensorType(float64, vector)> [id CS] -> [id H]
> |A_copy [id CU] -> [id P] > |A_copy [id CT] -> [id P]
for{cpu,scan_fn} [id F] '' for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id CS] '' >Elemwise{mul,no_inplace} [id CR] ''
for{cpu,scan_fn} [id F] '' for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id CS] '' >Elemwise{mul,no_inplace} [id CR] ''
for{cpu,scan_fn} [id BO] '' for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id CS] '' >Elemwise{mul,no_inplace} [id CR] ''
for{cpu,scan_fn} [id BO] '' for{cpu,scan_fn} [id F] ''
>Elemwise{mul,no_inplace} [id CS] ''""" >Elemwise{mul,no_inplace} [id CR] ''"""
for truth, out in zip(expected_output.split("\n"), lines): for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip() assert truth.strip() == out.strip()
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论