提交 b0050f2d 作者: Razvan Pascanu

new optimization that merges scan ops

上级 ecc69291
......@@ -26,7 +26,7 @@ from theano import config
import scan_op
import scan_utils
from scan_utils import clone, equal_computations
from scan_utils import clone, equal_computations, find_up, scan_args
from theano.gof.opt import pre_constant_merge, pre_greedy_local_optimizer
# Logging function for sending warning or info
......@@ -757,6 +757,92 @@ optdb.register( 'scanOp_save_mem'
, 'scan')
class ScanMerge(gof.Optimizer):
    """Graph optimizer that merges pairs of compatible scan ops.

    Scan nodes are grouped by their number of steps (the first outer
    input).  Inside every group the first node is paired with the first
    other node that has the same ``truncate_gradient`` and ``mode`` and for
    which ``find_up(snode, curnode)`` is false.  At most one pair is
    attempted per group and per call to ``apply``; this optimizer is meant
    to be run repeatedly so that remaining merges happen in later passes.
    """

    def add_requirements(self, env):
        """Attach the ``ReplaceValidate`` toolbox that ``apply`` relies on."""
        env.extend(gof.toolbox.ReplaceValidate())

    def merge(self, A, B, as_while):
        """Build a single scan op computing the outputs of both ``A`` and ``B``.

        :param A: first scan apply node.
        :param B: second scan apply node.
        :param as_while: when true, ``Margs.cond`` is appended to the inner
            outputs of the merged op (presumably the shared stop condition
            -- confirm against ``scan_args``).
        :return: list of ``(old_output, new_output)`` pairs, in the form
            passed to ``env.replace_all_validate`` by ``apply``.
        """
        Aargs = scan_args(A.inputs, A.outputs,
                          A.op.inputs, A.op.outputs, A.op.info)
        Bargs = scan_args(B.inputs, B.outputs,
                          B.op.inputs, B.op.outputs, B.op.info)
        Margs = Aargs.merge(Bargs)

        # Fix up the name.  Only names that are actually set are joined, so
        # that a missing (``None``) name does not raise a ``TypeError``.
        info = Margs.info
        names = [node.op.name for node in (A, B) if node.op.name is not None]
        if names:
            info['name'] = '&'.join(names)
        else:
            info['name'] = None

        if as_while:
            Margs_inner_outs = Margs.inner_outputs + Margs.cond
        else:
            Margs_inner_outs = Margs.inner_outputs

        op = scan_op.Scan(Margs.inner_inputs, Margs_inner_outs, info)
        outputs = op(*Margs.outer_inputs)
        # Calling the op may return a single variable instead of a sequence.
        if not isinstance(outputs, (list, tuple)):
            outputs = [outputs]
        return list(zip(Margs.outer_outputs, outputs))

    def apply(self, env):
        """Try to merge one pair of scan nodes per group of equal step count.

        :param env: the graph container being optimized; it must provide
            ``toposort`` and ``replace_all_validate``.
        """
        scan_nodes = [node for node in env.toposort()
                      if isinstance(node.op, scan_op.Scan)]

        # Group the scan nodes by number of steps.  A constant number of
        # steps is keyed by its integer value; otherwise the variable
        # itself is used as the key.
        nscan = dict()
        for snode in scan_nodes:
            n_steps = snode.inputs[0]
            try:
                n_steps = int(get_constant_value(n_steps))
            except TypeError:
                pass
            nscan.setdefault(n_steps, []).append(snode)

        for snodes in nscan.values():
            if len(snodes) < 2:
                continue
            # Amongst nodes that have the same number of steps, try to
            # find the ones that can be merged.
            curnode = snodes[0]
            for snode in snodes[1:]:
                if (snode.op.truncate_gradient == curnode.op.truncate_gradient
                        and snode.op.mode == curnode.op.mode
                        and not find_up(snode, curnode)):
                    if not snode.op.as_while and not curnode.op.as_while:
                        proposal = self.merge(curnode, snode, False)
                        env.replace_all_validate(proposal,
                                                 reason='scan merge')
                    elif snode.op.as_while and curnode.op.as_while:
                        # Both nodes carry a condition as their last inner
                        # output; merge only if the two conditions are
                        # equal computations.
                        if scan_utils.equal_computations(
                                [snode.op.outputs[-1]],
                                [curnode.op.outputs[-1]],
                                snode.op.inputs,
                                curnode.op.inputs):
                            proposal = self.merge(curnode, snode, True)
                            env.replace_all_validate(proposal,
                                                     reason='scan_merge')
                    # Other merges will be done in other passes.
                    break
# Scheduled after constant merging and before stabilization: equivalent
# nodes have been identified with each other by then, while later passes
# still get the chance to hoist computations out of the scan.
optdb.register(
    'scanOp_merge',
    EquilibriumOptimizer([ScanMerge()], max_use_ratio=11),
    1.90, 'fast_run', 'scan')
from theano.sandbox import cuda
if cuda.cuda_available:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论