提交 0dfefe84 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

logic to deal with the condition passed to scan

上级 fd88c988
......@@ -716,7 +716,12 @@ def scan( fn
# when we apply the lambda expression we get a mixture of update rules
# and outputs that needs to be separated
outputs, updates = scan_utils.get_updates_and_outputs(fn(*args))
condition, outputs, updates = scan_utils.get_updates_and_outputs(fn(*args))
if condition is not None:
as_while = True
else:
as_while = False
##
### Step 3. Check if we actually need scan and remove it if we don't
##
......@@ -725,6 +730,10 @@ def scan( fn
if n_fixed_steps in [1, -1]:
# We do not need to use the scan op anymore, so we can just return
# the outputs and updates we have
if condition is not None:
warning( ('When the number of steps is fixed and equal to 1,'
' the provided stopping condition, ', str(condition),
' is ignored'))
for pos, inner_out in enumerate(outputs):
# we need to see if we need to pad our sequences with an
......@@ -770,8 +779,11 @@ def scan( fn
## in args is quite important
dummy_args += extra_inputs
dummy_outs = outputs
if condition is not None:
dummy_outs.append(condition)
dummy_f = function( dummy_args
, outputs
, dummy_outs
, updates = updates
, mode = compile.mode.Mode(linker='py',
optimizer=None) )
......@@ -789,13 +801,18 @@ def scan( fn
# assumed outputs until now (provided by the user) there can be
# only one explanation: No information is provided for any of the
# outputs (i.e. we are dealing with a map)
if not ( len(dummy_f.maker.outputs) == n_outs or outs_info == []):
tmp_dummy_f_outs = len(dummy_f.maker.outputs)
if as_while:
tmp_dummy_f_outs -= 1
if not ( tmp_dummy_f_outs == n_outs or outs_info == []):
raise ValueError('Please provide None as output_info for '
'any output that does not feed back into '
'scan (i.e. it behaves like a map) ')
if outs_info == []:
n_outs = len(dummy_f.maker.outputs)
if as_while:
n_outs = n_outs - 1
outs_info = [ dict() for x in xrange(n_outs) ]
......@@ -889,6 +906,8 @@ def scan( fn
sit_sot_inner_outputs +
nit_sot_inner_outputs +
shared_inner_outputs )
if condition is not None:
inner_outs.append(condition)
if cuda.cuda_available:
# very often we end up in this situation when we want to
# replace w with w_copy, where w is CudaNdarray
......@@ -930,6 +949,8 @@ def scan( fn
info['mode'] = mode
info['inplace'] = False
info['gpu'] = False
info['as_while'] = as_while
info['profile'] = profile
local_op = scan_op.Scan( inner_inputs, new_outs, info )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论