提交 1f03c474 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

add logic to deal with scan as a while and profiling

上级 8c8cf101
...@@ -597,6 +597,8 @@ def compress_outs(op, not_required, inputs): ...@@ -597,6 +597,8 @@ def compress_outs(op, not_required, inputs):
info['inplace'] = op.info['inplace'] info['inplace'] = op.info['inplace']
info['gpu'] = op.info['gpu'] info['gpu'] = op.info['gpu']
info['mode'] = op.info['mode'] info['mode'] = op.info['mode']
info['as_while'] = op.info['as_while']
info['profile'] = op.info['profile']
op_inputs = op.inputs[:op.n_seqs] op_inputs = op.inputs[:op.n_seqs]
op_outputs = [] op_outputs = []
...@@ -705,6 +707,10 @@ def compress_outs(op, not_required, inputs): ...@@ -705,6 +707,10 @@ def compress_outs(op, not_required, inputs):
# other stuff # other stuff
op_inputs += op.inputs[i_offset:] op_inputs += op.inputs[i_offset:]
node_inputs += inputs[ni_offset+op.n_shared_outs+op.n_nit_sot:] node_inputs += inputs[ni_offset+op.n_shared_outs+op.n_nit_sot:]
if op.as_while:
op_outputs += [op.outputs[o_offset]]
map_old_new[o_offset] = len(op_outputs)-1
#map_old_new[len(op_outputs)-1] = o_offset
return (op_inputs, op_outputs, info, node_inputs, map_old_new) return (op_inputs, op_outputs, info, node_inputs, map_old_new)
...@@ -748,10 +754,10 @@ class scan_args(object): ...@@ -748,10 +754,10 @@ class scan_args(object):
_inner_inputs, _inner_outputs, info): _inner_inputs, _inner_outputs, info):
self.n_steps = outer_inputs[0] self.n_steps = outer_inputs[0]
rval = reconstruct_graph(_inner_inputs, _inner_outputs, '_merge') rval = reconstruct_graph(_inner_inputs, _inner_outputs, '_merge')
#if info['as_while']: if info['as_while']:
# self.cond = [rval[1][-1]] self.cond = [rval[1][-1]]
# inner_outputs = rval[1][:-1] inner_outputs = rval[1][:-1]
#else: else:
inner_outputs = rval[1] inner_outputs = rval[1]
inner_inputs = rval[0] inner_inputs = rval[0]
...@@ -852,7 +858,7 @@ class scan_args(object): ...@@ -852,7 +858,7 @@ class scan_args(object):
self.other_info = dict() self.other_info = dict()
for k in ('truncate_gradient', 'name', 'mode', 'inplace', for k in ('truncate_gradient', 'name', 'mode', 'inplace',
'gpu', 'profile'): 'gpu','as_while', 'profile'):
self.other_info[k] = info[k] self.other_info[k] = info[k]
inner_inputs = property(lambda self: (self.inner_in_seqs + inner_inputs = property(lambda self: (self.inner_in_seqs +
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论