提交 1882012f authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3516 from carriepl/various_scan_speedups

Various scan fixes and speedups
......@@ -1284,6 +1284,34 @@ class CLinker(link.Linker):
c_compiler=self.c_compiler(),
)
def cmodule_key_variables(self, inputs, outputs, no_recycling,
compile_args=None, libraries=None,
header_dirs=None, insert_config_md5=True,
c_compiler=None):
# Assemble a dummy fgraph using the provided inputs and outputs. It is
# only used to compute the cmodule key so it only need to expose an
# `inputs` and an `outputs` attribute as well as a toposort() method
# which returns a deterministic result.
class FakeFunctionGraph():
def __init__(self, inputs, outputs):
self.inputs = inputs
self.outputs = outputs
def toposort(self):
# Calling io_toposort() here is fine because the results will
# only be used to compute the cmodule key which requires that
# the result of the toposort be deterministic. The ordering
# doesn't need to include information about inplace operations
# because that information will be included explicitly in
# cmodule_key_().
return graph.io_toposort(self.inputs, self.outputs)
fgraph = FakeFunctionGraph(inputs, outputs)
return self.cmodule_key_(fgraph, no_recycling, compile_args,
libraries, header_dirs, insert_config_md5,
c_compiler)
def cmodule_key_(self, fgraph, no_recycling, compile_args=None,
libraries=None, header_dirs=None, insert_config_md5=True,
c_compiler=None):
......@@ -1425,8 +1453,15 @@ class CLinker(link.Linker):
fgraph_computed_set.update(node.outputs)
# Add not used input in the key
# If inputs don't define a 'clients' attribute (as is the case if
# fgraph is not a real FunctionGraph but a FakeFunctionGraph, a
# lightweight class designed to imitate FunctionGraph), pretend they
# have none. This if fine because the goal is only to have all of the
# graph's information used to compute the key. If we mistakenly
# pretend that inputs with clients don't have any, were are only using
# those inputs more than once to compute the key.
for ipos, var in [(i, var) for i, var in enumerate(fgraph.inputs)
if not len(var.clients)]:
if not len(getattr(var, 'clients', []))]:
sig.append((var.type, in_sig(var, -1, ipos)))
# crystalize the signature and version
......
......@@ -221,7 +221,9 @@ class Scan(PureOp):
tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs,
self.outputs)
local_fgraph = gof.FunctionGraph(tmp_in, tmp_out, clone=False)
self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, [])
self._cmodule_key = gof.CLinker().cmodule_key_variables(self.inputs,
self.outputs,
[])
self._hash_inner_graph = hash(self._cmodule_key)
# Compute mappings between outer inputs, outer outputs, inner
......
......@@ -240,14 +240,11 @@ class PushOutNonSeqScan(gof.Optimizer):
clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs,
clean_outputs,
clone=False)
local_fgraph_topo = local_fgraph.toposort()
local_fgraph_outs_set = set(local_fgraph.outputs)
local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clean_outputs)
local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_map = dict([(v, k) for k, v in \
enumerate(local_fgraph.outputs)])
enumerate(clean_outputs)])
to_remove_set = set()
to_replace_set = set()
......@@ -364,11 +361,9 @@ class PushOutNonSeqScan(gof.Optimizer):
nw_outer.append(repl_out)
givens[to_repl] = repl_in
_op_outs = scan_utils.clone(clean_outputs,
replace=givens)
op_outs = scan_utils.clone(clean_outputs, replace=givens)
op_ins = clean_inputs + nw_inner
_op_ins = clean_inputs + nw_inner
op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs)
# Reconstruct node
nwScan = scan_op.Scan(op_ins, op_outs, op.info)
......@@ -452,12 +447,11 @@ class PushOutSeqScan(gof.Optimizer):
clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs,
clone=False)
local_fgraph_topo = local_fgraph.toposort()
local_fgraph_outs_set = set(local_fgraph.outputs)
local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clean_outputs)
local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_map = dict([(v,k) for k,v in \
enumerate(local_fgraph.outputs)])
enumerate(clean_outputs)])
to_remove_set = set()
to_replace_set = set()
......@@ -614,10 +608,9 @@ class PushOutSeqScan(gof.Optimizer):
givens[to_repl] = repl_in
_op_outs = scan_utils.clone(clean_outputs,
replace=givens)
_op_ins = nw_inner + clean_inputs
op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs)
op_outs = scan_utils.clone(clean_outputs, replace=givens)
op_ins = nw_inner + clean_inputs
# Reconstruct node
nw_info = op.info.copy()
nw_info['n_seqs'] += len(nw_inner)
......@@ -640,7 +633,7 @@ class PushOutSeqScan(gof.Optimizer):
if out in local_fgraph_outs_set:
x = node.outputs[local_fgraph_outs_map[out]]
_y = replace_with_out[idx]
ls = local_fgraph.outputs
ls = clean_outputs
if out in op.inner_mitsot_outs(ls):
odx = op.inner_mitsot_outs(ls).index(out)
inp = op.outer_mitsot(node)[odx]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论