提交 bb7e191a authored 作者: Razvan Pascanu's avatar Razvan Pascanu

logic to deal with while in scan_op

Basically change the code to be aware that the last input can be a condition, and if so stop when that condition becomes True.
上级 6c8f69c3
......@@ -110,7 +110,7 @@ class Scan(Op):
TensorType(
broadcastable = (False,) + o.type.broadcastable
, dtype = o.type.dtype ))
# shared outputs
# shared outputs + possibly the ending condition
for o in outputs[end:]:
if cuda.cuda_available and isinstance(o.type,
cuda.CudaNdarrayType):
......@@ -120,7 +120,8 @@ class Scan(Op):
else:
self.output_types.append( o.type )
if self.as_while:
self.output_types = self.output_types[:-1]
self.destroy_map = {}
if hasattr(self,'inplace') and self.inplace:
......@@ -438,8 +439,12 @@ class Scan(Op):
for idx in xrange(len(other_args)):
input_storage[idx+offset].storage[0] = other_args[idx]
i = 0
cond = True
############## THE MAIN LOOP #########################
for i in xrange(n_steps):
#for i in xrange(n_steps):
while (i< n_steps) and cond:
# sequences over which scan iterates
# 3. collect input slices
for idx in xrange(self.n_seqs):
......@@ -496,11 +501,18 @@ class Scan(Op):
offset += self.n_outs+self.n_nit_sot - self.n_mit_mot
for idx in xrange(self.n_shared_outs):
output_storage[idx+offset].storage[0] = None
# If condition add it to the mix
if self.as_while:
pdx = offset + self.n_shared_outs
output_storage[pdx].storage[0] = None
# 5. compute outputs
t0_fn = time.time()
fn()
dt_fn = time.time() - t0_fn
if self.as_while:
pdx = offset + self.n_shared_outs
cond = output_storage[pdx].storage[0] == 0
t_fn += dt_fn
offset_out = 0
# 5.1 Copy over the values for mit_mot outputs
......@@ -558,13 +570,14 @@ class Scan(Op):
itertools.izip(pos, store_steps)
]
i = i+1
# 6. Check if you need to re-order output buffers
begin = self.n_mit_mot
end = self.n_outs + self.n_nit_sot
for idx in xrange(begin, end):
min_tap = self.mintaps[idx]
if ( store_steps[idx] < n_steps-self.mintaps[idx] and
if ( store_steps[idx] < i-self.mintaps[idx] and
pos[idx] < store_steps[idx] ):
pdx = pos[idx]
......@@ -594,8 +607,8 @@ class Scan(Op):
# backpropagation through time. In such a scenarion Scan is
# expected to return 0 for all entries for which the gradient is
# not actually computed
elif store_steps[idx] > n_steps - self.mintaps[idx]:
outs[idx][0][n_steps-self.mintaps[idx]:] = 0
elif store_steps[idx] > i - self.mintaps[idx]:
outs[idx][0][i-self.mintaps[idx]:] = 0
t_call = time.time() - t0_call
......@@ -644,7 +657,10 @@ class Scan(Op):
out_equivalent = {}
for in_ns, out_ns in zip(inner_non_sequences, node.inputs[offset:]):
out_equivalent[in_ns] = out_ns
if self.as_while:
self_outs = self.outputs[:-1]
else:
self_outs = self.outputs
outs_shape = scan_utils.infer_shape(
outs = self.outputs,
inputs = self.inputs,
......@@ -738,6 +754,7 @@ class Scan(Op):
+ self.n_mit_sot
+ self.n_nit_sot
+ self.n_sit_sot )
# shared variables as well as the condition as well as the condition
old_scan_shared_outs = self_outputs[out_offset:]
arg_offset = ( 1
+ self.n_seqs
......@@ -992,6 +1009,8 @@ class Scan(Op):
info['n_sit_sot'] = 0
info['n_shared_outs'] = n_shared_outs + self.n_shared_outs
info['n_nit_sot'] = n_nit_sot
info['as_while'] = self.as_while
info['profile'] = self.profile
if self.name:
info['name'] = 'grad_of_' + self.name
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论