提交 1f03c474 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

add logic to deal with scan as a while and profiling

上级 8c8cf101
......@@ -597,6 +597,8 @@ def compress_outs(op, not_required, inputs):
info['inplace'] = op.info['inplace']
info['gpu'] = op.info['gpu']
info['mode'] = op.info['mode']
info['as_while'] = op.info['as_while']
info['profile'] = op.info['profile']
op_inputs = op.inputs[:op.n_seqs]
op_outputs = []
......@@ -705,6 +707,10 @@ def compress_outs(op, not_required, inputs):
# other stuff
op_inputs += op.inputs[i_offset:]
node_inputs += inputs[ni_offset+op.n_shared_outs+op.n_nit_sot:]
if op.as_while:
op_outputs += [op.outputs[o_offset]]
map_old_new[o_offset] = len(op_outputs)-1
#map_old_new[len(op_outputs)-1] = o_offset
return (op_inputs, op_outputs, info, node_inputs, map_old_new)
......@@ -748,11 +754,11 @@ class scan_args(object):
_inner_inputs, _inner_outputs, info):
self.n_steps = outer_inputs[0]
rval = reconstruct_graph(_inner_inputs, _inner_outputs, '_merge')
#if info['as_while']:
# self.cond = [rval[1][-1]]
# inner_outputs = rval[1][:-1]
#else:
inner_outputs = rval[1]
if info['as_while']:
self.cond = [rval[1][-1]]
inner_outputs = rval[1][:-1]
else:
inner_outputs = rval[1]
inner_inputs = rval[0]
p = 1
......@@ -852,7 +858,7 @@ class scan_args(object):
self.other_info = dict()
for k in ('truncate_gradient', 'name', 'mode', 'inplace',
'gpu', 'profile'):
'gpu','as_while', 'profile'):
self.other_info[k] = info[k]
inner_inputs = property(lambda self: (self.inner_in_seqs +
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论