提交 e488499d authored 作者: Razvan Pascanu's avatar Razvan Pascanu

new optimization that merges outputs that compute similar things

Note: this optimization expects the state of scan to have the as_while flag and the profile flag.
上级 ce842d3d
......@@ -838,6 +838,156 @@ optdb.register('scanOp_merge',
'fast_run',
'scan')
def has_duplicates(l):
"""returns true if l has any duplicates (according to __eq__)."""
return len(set(l)) < len(l)
def make_equiv(lo, li):
"""builds a dictionary of equivalences between inner inputs based on the equivalence of their corresponding outer inputs."""
seeno = {}
left = []
right = []
for o, i in zip(lo, li):
if o in seeno:
left += [i]
right += [o]
else:
seeno[o] = i
return left, right
@gof.local_optimizer([None])
def scan_merge_inouts(node):
if not isinstance(node.op, scan_op.Scan):
return False
a = scan_args(node.inputs, node.outputs,
node.op.inputs, node.op.outputs, node.op.info)
inp_equiv = {}
if has_duplicates(a.outer_in_seqs):
new_outer_seqs = []
new_inner_seqs = []
for out_seq, in_seq in zip(a.outer_in_seqs, a.inner_in_seqs):
if out_seq in new_outer_seqs:
i = new_outer_seqs.index(out_seq)
inp_equiv[in_seq] = new_inner_seqs[i]
else:
new_outer_seqs.append(out_seq)
new_inner_seqs.append(in_seq)
a.outer_in_seqs = new_outer_seqs
a.inner_in_seqs = new_inner_seqs
if has_duplicates(a.outer_in_non_seqs):
new_outer_nseqs = []
new_inner_nseqs = []
for out_nseq, in_nseq in zip(a.outer_in_non_seqs, a.inner_in_non_seqs):
if out_nseq in new_outer_nseqs:
i = new_outer_nseqs.index(out_nseq)
inp_equiv[in_nseq] = new_inner_nseqs[i]
else:
new_outer_nseqs.append(out_nseq)
new_inner_nseqs.append(in_nseq)
a.outer_in_non_seqs = new_outer_nseqs
a.inner_in_non_seqs = new_inner_nseqs
if len(inp_equiv) > 0:
# do the replacement now. The rest will be left to ScanSaveMem
inner_inputs = a.inner_inputs
outer_inputs = a.outer_inputs
info = a.info
if info['as_while']:
a_inner_outs = a.inner_outputs + a.cond
else:
a_inner_outs = a.inner_outputs
inner_outputs = scan_utils.clone(a_inner_outs, replace=inp_equiv)
orig_outputs = a.outer_outputs
op = scan_op.Scan(inner_inputs, inner_outputs, info)
outputs = op(*outer_inputs)
if not isinstance(outputs, (list, tuple)):
outputs = [outputs]
na = scan_args(outer_inputs, outputs, op.inputs, op.outputs, op.info)
else:
na = a
# start again
left = []
right = []
#inp_equiv = [[],[]]
if has_duplicates(na.outer_in_shared):
_left, _right = make_equiv(na.outer_in_shared, na.inner_in_shared)
left += _left
right += _right
#inp_equiv.update(make_equiv(na.outer_in_shared, na.inner_in_shared))
if has_duplicates(na.outer_in_sit_sot):
_left, _right = make_equiv(na.outer_in_sit_sot, na.inner_in_sit_sot)
left += _left
right += _right
#inp_equiv.update(make_equiv(na.outer_in_sit_sot, na.inner_in_sit_sot))
if has_duplicates(na.outer_in_mit_mot):
seen = {}
for omm, imm, _sl in zip(na.outer_in_mit_mot, na.inner_in_mit_mot, na.mit_mot_in_slices):
sl = tuple(_sl)
if (omm, sl) in seen:
simm = seen[(omm, sl)]
left += imm
right += simm
#inp_equiv.update(zip(imm, simm))
else:
seen[(omm, sl)] = imm
if has_duplicates(na.outer_in_mit_sot):
seen = {}
for oms, ims, _sl in zip(na.outer_in_mit_sot, na.inner_in_mit_sot, na.mit_sot_in_slices):
sl = tuple(_sl)
if (oms, sl) in seen:
sims = seen[(oms, sl)]
left += ims
right += sims
#inp_equiv.update(zip(ims, sims))
else:
seen[(oms, sl)] = ims
def map_out(i, o, seen):
for si, so in seen:
if equal_computations([i], [si],left, right):
return so
seen.append((i, o))
return o
seen = []
na.outer_out_nit_sot = [map_out(i, o, seen) for i, o in zip(na.inner_out_nit_sot, na.outer_out_nit_sot)]
seen = []
na.outer_out_sit_sot = [map_out(i, o, seen) for i, o in zip(na.inner_out_sit_sot, na.outer_out_sit_sot)]
seen = []
na.outer_out_mit_sot = [map_out(i, o, seen) for i, o in zip(na.inner_out_mit_sot, na.outer_out_mit_sot)]
seen = []
new_outer_out_mit_mot = []
for imm, omm, osl in zip(na.inner_out_mit_mot, na.outer_out_mit_mot, na.mit_mot_out_slices):
for simm, somm, sosl in seen:
if osl == sosl and equal_computations(imm, simm, left, right):
new_outer_out_mit_mot.append(somm)
break
else:
seen.append((imm, omm, osl))
new_outer_out_mit_mot.append(omm)
na.outer_out_mit_mot = new_outer_out_mit_mot
return na.outer_outputs
optdb.register('scanOp_merge_inouts'
, opt.in2out(scan_merge_inouts,ignore_newtrees=True)
, 1.91
, 'fast_run'
, 'scan')
from theano.sandbox import cuda
if cuda.cuda_available:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论