提交 f1670422 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

rewrote implementation that merges scan ops

The old implementation resulted in a stochastic-order error in DebugMode. After many attempts to fix it, I decided that it would be better and faster to simply rewrite it. The new implementation does not suffer from any known bug (i.e., all tests pass in debug mode).
上级 9c6da45e
......@@ -23,6 +23,7 @@ from theano import gof
from theano.compile import optdb
from theano.gof.opt import EquilibriumOptimizer
from theano import config
from theano.compile.function_module import deep_copy_op
import scan_op
import scan_utils
......@@ -772,82 +773,194 @@ class ScanMerge(gof.Optimizer):
def add_requirements(self,env):
env.extend(gof.toolbox.ReplaceValidate())
def merge(self, A,B, as_while):
Aargs = scan_args(A.inputs, A.outputs, A.op.inputs, A.op.outputs, A.op.info)
Bargs = scan_args(B.inputs, B.outputs, B.op.inputs, B.op.outputs, B.op.info)
Margs = Aargs.merge(Bargs)
def merge(self, nodes):
# fixup name
info = Margs.info
info['name'] = A.op.name+'&'+B.op.name
if nodes[0].op.as_while:
as_while = True
condition = nodes[0].op.outputs[-1]
else:
as_while = False
info = {}
info['tap_array'] = []
info['n_seqs'] = sum([nd.op.n_seqs for nd in nodes])
info['n_mit_mot'] = sum([nd.op.n_mit_mot for nd in nodes])
info['n_mit_mot_outs'] = sum([nd.op.n_mit_mot_outs for nd in nodes])
info['mit_mot_out_slices'] = []
info['n_mit_sot'] = sum([nd.op.n_mit_sot for nd in nodes])
info['n_sit_sot'] = sum([nd.op.n_sit_sot for nd in nodes])
info['n_shared_outs'] = sum([nd.op.n_shared_outs for nd in nodes])
info['n_nit_sot'] = sum([nd.op.n_nit_sot for nd in nodes])
info['truncate_gradient'] = nodes[0].op.truncate_gradient
info['name'] = '&'.join([nd.op.name for nd in nodes])
info['mode'] = nodes[0].op.mode
info['inplace'] = False
info['gpu'] = False
info['as_while'] = as_while
info['profile'] = nodes[0].op.profile
inner_ins = []
outer_ins = []
inner_outs = []
outer_outs = []
def rename(ls, suffix):
for k in ls:
if k.name:
k.name += str(suffix)
return ls
for idx,nd in enumerate(nodes):
# Seq
inner_ins += rename(nd.op.inner_seqs(),idx)
outer_ins += rename(nd.op.outer_seqs(nd),idx)
for idx,nd in enumerate(nodes):
# MitMot
inner_ins += rename(nd.op.inner_mitmot(),idx)
inner_outs += nd.op.inner_mitmot_outs()
info['tap_array'] += nd.op.mitmot_taps()
info['mit_mot_out_slices'] += nd.op.mitmot_out_taps()
outer_ins += rename(nd.op.outer_mitmot(nd),idx)
outer_outs += nd.op.outer_mitmot_outs(nd)
for idx,nd in enumerate(nodes):
# MitSot
inner_ins += rename(nd.op.inner_mitsot(),idx)
inner_outs += nd.op.inner_mitsot_outs()
info['tap_array'] += nd.op.mitsot_taps()
outer_ins += rename(nd.op.outer_mitsot(nd),idx)
outer_outs += nd.op.outer_mitsot_outs(nd)
for idx,nd in enumerate(nodes):
# SitSot
inner_ins += rename(nd.op.inner_sitsot(),idx)
info['tap_array'] += [[-1] for x in xrange(nd.op.n_sit_sot)]
inner_outs += nd.op.inner_sitsot_outs()
outer_ins += rename(nd.op.outer_sitsot(nd),idx)
outer_outs += nd.op.outer_sitsot_outs(nd)
for idx,nd in enumerate(nodes):
# Shared
inner_ins += rename(nd.op.inner_shared(),idx)
outer_ins += rename(nd.op.outer_shared(nd),idx)
for idx,nd in enumerate(nodes):
# NitSot
inner_outs += nd.op.inner_nitsot_outs()
outer_ins += rename(nd.op.outer_nitsot(nd),idx)
outer_outs += nd.op.outer_nitsot_outs(nd)
for idx,nd in enumerate(nodes):
# Shared
outer_outs += nd.op.outer_shared_outs(nd)
inner_outs += nd.op.inner_shared_outs()
for idx,nd in enumerate(nodes):
# Non Seqs
inner_ins += rename(nd.op.inner_non_seqs(),idx)
outer_ins += rename(nd.op.outer_non_seqs(nd),idx)
# Add back the number of steps
outer_ins = [nodes[0].inputs[0]] + outer_ins
#indicates that we have a stopping condition for scan
if as_while:
Margs_inner_outs = Margs.inner_outputs + Margs.cond
else:
Margs_inner_outs = Margs.inner_outputs
op = scan_op.Scan(Margs.inner_inputs, Margs_inner_outs, info)
# add the condition
inner_outs.append(condition)
inner_ins, inner_outs = scan_utils.reconstruct_graph(inner_ins,
inner_outs)
outputs = op(*Margs.outer_inputs)
new_op = scan_op.Scan(inner_ins, inner_outs, info)
new_outs = new_op(*outer_ins)
if type(outputs) not in (list, tuple):
outputs = [outputs]
if not isinstance(new_outs, (list, tuple)):
new_outs = [new_outs]
return zip(Margs.outer_outputs, outputs)
return zip(outer_outs, new_outs)
def apply(self, env):
nodelist = list(env.toposort())
scan_nodes = filter(lambda s: isinstance(s.op, scan_op.Scan), nodelist)
nscan = dict()
for snode in scan_nodes:
n_steps = snode.inputs[0]
try:
n_steps = int(get_constant_value(n_steps))
except TypeError:
pass
l = nscan.get(n_steps)
if l is None:
nscan[n_steps] = [snode]
def belongs_to_set(self, node, set_nodes):
    """
    Tell whether scan node `node` can be merged with every node already
    collected in `set_nodes`.

    Two scan nodes are considered mergeable when all of the following hold:

    * both are while-scans or both are fixed-length scans;
    * they have the same value for `truncate_gradient` and the same mode;
    * they iterate over the same number of steps (constant step counts
      are compared by value, anything else by `==` on the variables);
    * neither node is reachable from the other according to `find_up`;
    * for while-scans, the stopping conditions are equal computations.

    Questionable: should `profile` be taken into account as well?

    :param node: candidate scan apply node.
    :param set_nodes: non-empty list of scan apply nodes that were already
        found to be mutually mergeable; its first element is used as the
        representative of the set.
    :return: True if `node` can join `set_nodes`, False otherwise.
    """
    rep = set_nodes[0]

    # `merge` takes `as_while` and the stopping condition from the first
    # node of the set only, so a mismatch in *either* direction must be
    # rejected (the previous check only caught a while-scan joining a set
    # of fixed-length scans, not the reverse).
    if bool(rep.op.as_while) != bool(node.op.as_while):
        return False

    # Cheap attribute comparisons first; `not (a == b)` is used instead of
    # `a != b` so that only `__eq__` is relied upon, as in the original.
    if not (node.op.truncate_gradient == rep.op.truncate_gradient):
        return False
    if not (node.op.mode == rep.op.mode):
        return False

    def step_count(var):
        # Fold the number of steps to a Python int when
        # `get_constant_value` can evaluate it; otherwise keep the
        # variable itself so the comparison below falls back to `==`.
        try:
            return int(get_constant_value(var))
        except TypeError:
            return var

    if not (step_count(node.inputs[0]) == step_count(rep.inputs[0])):
        return False

    # Check to see if it is an input of a different node (in either
    # direction); one hit is enough to rule the merge out.
    # NOTE(review): assumes `find_up` has no side effects, so stopping at
    # the first hit is equivalent to scanning the whole set -- confirm.
    for nd in set_nodes:
        if find_up(node, nd) or find_up(nd, node):
            return False

    if not node.op.as_while:
        return True

    # While-scans additionally need the same stopping condition, which is
    # the last inner output of each op.
    return bool(scan_utils.equal_computations([node.op.outputs[-1]],
                                              [rep.op.outputs[-1]],
                                              node.op.inputs,
                                              rep.op.inputs))
def apply(self, env):
# Collect all scan nodes ordered according to toposort
scan_nodes = [ nd for nd in env.toposort()
if isinstance(nd.op, scan_op.Scan)]
# All sets of possibly mergeable nodes
all_sets = []
for nd in scan_nodes:
belongs_to_set_idx = -1
for pos,subset in enumerate(all_sets):
if self.belongs_to_set(nd, subset):
assert belongs_to_set_idx == -1
belongs_to_set_idx = pos
if belongs_to_set_idx == -1:
all_sets.append([nd])
else:
l.append(snode)
for snodes in nscan.values():
if len(snodes) > 1:
# amongst nodes that have the same number of steps
# try to find the ones that can be merged
curnode = snodes[0]
for snode in snodes[1:]:
if (snode.op.truncate_gradient == curnode.op.truncate_gradient and
snode.op.mode == curnode.op.mode and
not find_up(snode, curnode)):
if (not snode.op.as_while and
not curnode.op.as_while):
proposal = self.merge(curnode, snode, False)
env.replace_all_validate(proposal, reason='scan merge')
elif (snode.op.as_while and
curnode.op.as_while):
# check if equal computations
if scan_utils.equal_computations(
[snode.op.outputs[-1]],
[curnode.op.outputs[-1]],
snode.op.inputs,
curnode.op.inputs):
proposal = self.merge(curnode, snode, True)
env.replace_all_validate(proposal, reason =
'scan_merge')
else:
pass
else:
pass
# other merges will be done in other passes
break
all_sets[belongs_to_set_idx].append(nd)
for subset in all_sets:
if len(subset) > 1:
proposal = self.merge(subset)
env.replace_all_validate(proposal, reason = 'scan_merge')
# after const merge but before stabilize so that we can have identity
# for equivalent nodes but we still have the chance to hoist stuff out
# of the scan later.
optdb.register('scanOp_merge',
EquilibriumOptimizer([ScanMerge()],
max_use_ratio=11),
ScanMerge(),
1.90,
'fast_run',
'scan')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论