提交 bf96ef6d authored 作者: Razvan Pascanu's avatar Razvan Pascanu

new class used by the merge optimization wrote by Arnaud

上级 c451cfb7
......@@ -740,3 +740,178 @@ def reconstruct_graph(inputs, outputs, tag):
nw_outputs = clone( outputs, replace=givens)
return (nw_inputs, nw_outputs)
class scan_args(object):
"""Parses the inputs and outputs of scan in an easy to manipulate format"""
def __init__(self, outer_inputs, outer_outputs,
_inner_inputs, _inner_outputs, info):
self.n_steps = outer_inputs[0]
rval = reconstruct_graph(_inner_inputs, _inner_outputs, '_merge')
#if info['as_while']:
# self.cond = [rval[1][-1]]
# inner_outputs = rval[1][:-1]
#else:
inner_outputs = rval[1]
inner_inputs = rval[0]
p = 1
q = 0
n_seqs = info['n_seqs']
self.outer_in_seqs = outer_inputs[p:p+n_seqs]
self.inner_in_seqs = inner_inputs[q:q+n_seqs]
p += n_seqs
q += n_seqs
n_mit_mot = info['n_mit_mot']
n_mit_sot = info['n_mit_sot']
self.mit_mot_in_slices = info['tap_array'][:n_mit_mot]
self.mit_sot_in_slices = info['tap_array'][n_mit_mot:n_mit_mot+n_mit_sot]
n_mit_mot_ins = sum(len(s) for s in self.mit_mot_in_slices)
n_mit_sot_ins = sum(len(s) for s in self.mit_sot_in_slices)
iimm = inner_inputs[q:q+n_mit_mot_ins]
self.inner_in_mit_mot = []
qq = 0
for sl in self.mit_mot_in_slices:
self.inner_in_mit_mot.append(iimm[qq:qq+len(sl)])
qq += len(sl)
q += n_mit_mot_ins
iims = inner_inputs[q:q+n_mit_sot_ins]
self.inner_in_mit_sot = []
qq = 0
for sl in self.mit_sot_in_slices:
self.inner_in_mit_sot.append(iims[qq:qq+len(sl)])
qq += len(sl)
q += n_mit_sot_ins
self.outer_in_mit_mot = outer_inputs[p:p+n_mit_mot]
p += n_mit_mot
self.outer_in_mit_sot = outer_inputs[p:p+n_mit_sot]
p += n_mit_sot
n_sit_sot = info['n_sit_sot']
self.outer_in_sit_sot = outer_inputs[p:p+n_sit_sot]
self.inner_in_sit_sot = inner_inputs[q:q+n_sit_sot]
p += n_sit_sot
q += n_sit_sot
n_shared_outs = info['n_shared_outs']
self.outer_in_shared = outer_inputs[p:p+n_shared_outs]
self.inner_in_shared = inner_inputs[q:q+n_shared_outs]
p += n_shared_outs
q += n_shared_outs
n_nit_sot = info['n_nit_sot']
self.outer_in_nit_sot = outer_inputs[p:p+n_nit_sot]
p += n_nit_sot
self.outer_in_non_seqs = outer_inputs[p:]
self.inner_in_non_seqs = inner_inputs[q:]
# now for the outputs
p = 0
q = 0
self.mit_mot_out_slices = info['mit_mot_out_slices']
n_mit_mot_outs = info['n_mit_mot_outs']
self.outer_out_mit_mot = outer_outputs[p:p+n_mit_mot]
iomm = inner_outputs[q:q+n_mit_mot_outs]
self.inner_out_mit_mot = []
qq = 0
for sl in self.mit_mot_out_slices:
self.inner_out_mit_mot.append(iomm[qq:qq+len(sl)])
qq += len(sl)
p += n_mit_mot
q += n_mit_mot_outs
self.outer_out_mit_sot = outer_outputs[p:p+n_mit_sot]
self.inner_out_mit_sot = inner_outputs[q:q+n_mit_sot]
p += n_mit_sot
q += n_mit_sot
self.outer_out_sit_sot = outer_outputs[p:p+n_sit_sot]
self.inner_out_sit_sot = inner_outputs[q:q+n_sit_sot]
p += n_sit_sot
q += n_sit_sot
self.outer_out_nit_sot = outer_outputs[p:p+n_nit_sot]
self.inner_out_nit_sot = inner_outputs[q:q+n_nit_sot]
p += n_nit_sot
q += n_nit_sot
self.outer_out_shared = outer_outputs[p:p+n_shared_outs]
self.inner_out_shared = inner_outputs[q:q+n_shared_outs]
p += n_shared_outs
q += n_shared_outs
self.other_info = dict()
for k in ('truncate_gradient', 'name', 'mode', 'inplace',
'gpu', 'profile'):
self.other_info[k] = info[k]
inner_inputs = property(lambda self: (self.inner_in_seqs +
flatten(self.inner_in_mit_mot) +
flatten(self.inner_in_mit_sot) +
self.inner_in_sit_sot +
self.inner_in_shared +
self.inner_in_non_seqs))
outer_inputs = property(lambda self: ([self.n_steps] +
self.outer_in_seqs +
self.outer_in_mit_mot +
self.outer_in_mit_sot +
self.outer_in_sit_sot +
self.outer_in_shared +
self.outer_in_nit_sot +
self.outer_in_non_seqs))
inner_outputs = property(lambda self: (flatten(self.inner_out_mit_mot) +
self.inner_out_mit_sot +
self.inner_out_sit_sot +
self.inner_out_nit_sot +
self.inner_out_shared))
outer_outputs = property(lambda self: (self.outer_out_mit_mot +
self.outer_out_mit_sot +
self.outer_out_sit_sot +
self.outer_out_nit_sot +
self.outer_out_shared))
info = property(lambda self: dict(n_seqs=len(self.outer_in_seqs),
n_mit_mot=len(self.outer_in_mit_mot),
n_mit_sot=len(self.outer_in_mit_sot),
tap_array=(self.mit_mot_in_slices +
self.mit_sot_in_slices +
[[-1]] * len(self.inner_in_sit_sot)),
n_sit_sot=len(self.outer_in_sit_sot),
n_nit_sot=len(self.outer_in_nit_sot),
n_shared_outs=len(self.outer_in_shared),
n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices),
mit_mot_out_slices=self.mit_mot_out_slices,
**self.other_info))
def __copy__(self):
res = object.__new__(type(self))
res.__dict__.update(self.__dict__)
# also copy mutable attrs
for attr in self.__dict__:
if (attr.startswith('inner_in') or attr.startswith('inner_out') or
attr.startswith('outer_in') or attr.startswith('outer_out') or
attr in ('mit_mot_out_slices', 'mit_mot_in_slices',
'mit_sot_in_slices', 'other_info')):
setattr(res, attr, copy.copy(getattr(self, attr)))
return res
def merge(self, other):
res = copy.copy(self)
for attr in self.__dict__:
if (attr.startswith('inner_in') or attr.startswith('inner_out') or
attr.startswith('outer_in') or attr.startswith('outer_out') or
attr in ('mit_mot_out_slices', 'mit_mot_in_slices',
'mit_sot_in_slices')):
getattr(res, attr).extend(getattr(other, attr))
return res
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论