提交 bb7e191a 作者: Razvan Pascanu

Add logic to handle while loops in scan_op

Change the code to be aware that the last output of the inner function can be a stopping condition; if so, the main loop stops as soon as that condition becomes True, or after n_steps iterations, whichever comes first.
上级 6c8f69c3
...@@ -110,7 +110,7 @@ class Scan(Op): ...@@ -110,7 +110,7 @@ class Scan(Op):
TensorType( TensorType(
broadcastable = (False,) + o.type.broadcastable broadcastable = (False,) + o.type.broadcastable
, dtype = o.type.dtype )) , dtype = o.type.dtype ))
# shared outputs # shared outputs + possibly the ending condition
for o in outputs[end:]: for o in outputs[end:]:
if cuda.cuda_available and isinstance(o.type, if cuda.cuda_available and isinstance(o.type,
cuda.CudaNdarrayType): cuda.CudaNdarrayType):
...@@ -120,7 +120,8 @@ class Scan(Op): ...@@ -120,7 +120,8 @@ class Scan(Op):
else: else:
self.output_types.append( o.type ) self.output_types.append( o.type )
if self.as_while:
self.output_types = self.output_types[:-1]
self.destroy_map = {} self.destroy_map = {}
if hasattr(self,'inplace') and self.inplace: if hasattr(self,'inplace') and self.inplace:
...@@ -438,8 +439,12 @@ class Scan(Op): ...@@ -438,8 +439,12 @@ class Scan(Op):
for idx in xrange(len(other_args)): for idx in xrange(len(other_args)):
input_storage[idx+offset].storage[0] = other_args[idx] input_storage[idx+offset].storage[0] = other_args[idx]
i = 0
cond = True
############## THE MAIN LOOP ######################### ############## THE MAIN LOOP #########################
for i in xrange(n_steps): #for i in xrange(n_steps):
while (i< n_steps) and cond:
# sequences over which scan iterates # sequences over which scan iterates
# 3. collect input slices # 3. collect input slices
for idx in xrange(self.n_seqs): for idx in xrange(self.n_seqs):
...@@ -496,11 +501,18 @@ class Scan(Op): ...@@ -496,11 +501,18 @@ class Scan(Op):
offset += self.n_outs+self.n_nit_sot - self.n_mit_mot offset += self.n_outs+self.n_nit_sot - self.n_mit_mot
for idx in xrange(self.n_shared_outs): for idx in xrange(self.n_shared_outs):
output_storage[idx+offset].storage[0] = None output_storage[idx+offset].storage[0] = None
# If condition add it to the mix
if self.as_while:
pdx = offset + self.n_shared_outs
output_storage[pdx].storage[0] = None
# 5. compute outputs # 5. compute outputs
t0_fn = time.time() t0_fn = time.time()
fn() fn()
dt_fn = time.time() - t0_fn dt_fn = time.time() - t0_fn
if self.as_while:
pdx = offset + self.n_shared_outs
cond = output_storage[pdx].storage[0] == 0
t_fn += dt_fn t_fn += dt_fn
offset_out = 0 offset_out = 0
# 5.1 Copy over the values for mit_mot outputs # 5.1 Copy over the values for mit_mot outputs
...@@ -558,13 +570,14 @@ class Scan(Op): ...@@ -558,13 +570,14 @@ class Scan(Op):
itertools.izip(pos, store_steps) itertools.izip(pos, store_steps)
] ]
i = i+1
# 6. Check if you need to re-order output buffers # 6. Check if you need to re-order output buffers
begin = self.n_mit_mot begin = self.n_mit_mot
end = self.n_outs + self.n_nit_sot end = self.n_outs + self.n_nit_sot
for idx in xrange(begin, end): for idx in xrange(begin, end):
min_tap = self.mintaps[idx] min_tap = self.mintaps[idx]
if ( store_steps[idx] < n_steps-self.mintaps[idx] and if ( store_steps[idx] < i-self.mintaps[idx] and
pos[idx] < store_steps[idx] ): pos[idx] < store_steps[idx] ):
pdx = pos[idx] pdx = pos[idx]
...@@ -594,8 +607,8 @@ class Scan(Op): ...@@ -594,8 +607,8 @@ class Scan(Op):
# backpropagation through time. In such a scenario Scan is # backpropagation through time. In such a scenario Scan is
# expected to return 0 for all entries for which the gradient is # expected to return 0 for all entries for which the gradient is
# not actually computed # not actually computed
elif store_steps[idx] > n_steps - self.mintaps[idx]: elif store_steps[idx] > i - self.mintaps[idx]:
outs[idx][0][n_steps-self.mintaps[idx]:] = 0 outs[idx][0][i-self.mintaps[idx]:] = 0
t_call = time.time() - t0_call t_call = time.time() - t0_call
...@@ -644,7 +657,10 @@ class Scan(Op): ...@@ -644,7 +657,10 @@ class Scan(Op):
out_equivalent = {} out_equivalent = {}
for in_ns, out_ns in zip(inner_non_sequences, node.inputs[offset:]): for in_ns, out_ns in zip(inner_non_sequences, node.inputs[offset:]):
out_equivalent[in_ns] = out_ns out_equivalent[in_ns] = out_ns
if self.as_while:
self_outs = self.outputs[:-1]
else:
self_outs = self.outputs
outs_shape = scan_utils.infer_shape( outs_shape = scan_utils.infer_shape(
outs = self.outputs, outs = self.outputs,
inputs = self.inputs, inputs = self.inputs,
...@@ -738,6 +754,7 @@ class Scan(Op): ...@@ -738,6 +754,7 @@ class Scan(Op):
+ self.n_mit_sot + self.n_mit_sot
+ self.n_nit_sot + self.n_nit_sot
+ self.n_sit_sot ) + self.n_sit_sot )
# shared variables as well as the condition
old_scan_shared_outs = self_outputs[out_offset:] old_scan_shared_outs = self_outputs[out_offset:]
arg_offset = ( 1 arg_offset = ( 1
+ self.n_seqs + self.n_seqs
...@@ -992,6 +1009,8 @@ class Scan(Op): ...@@ -992,6 +1009,8 @@ class Scan(Op):
info['n_sit_sot'] = 0 info['n_sit_sot'] = 0
info['n_shared_outs'] = n_shared_outs + self.n_shared_outs info['n_shared_outs'] = n_shared_outs + self.n_shared_outs
info['n_nit_sot'] = n_nit_sot info['n_nit_sot'] = n_nit_sot
info['as_while'] = self.as_while
info['profile'] = self.profile
if self.name: if self.name:
info['name'] = 'grad_of_' + self.name info['name'] = 'grad_of_' + self.name
else: else:
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论