提交 1882012f authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3516 from carriepl/various_scan_speedups

Various scan fixes and speedups
...@@ -1284,6 +1284,34 @@ class CLinker(link.Linker): ...@@ -1284,6 +1284,34 @@ class CLinker(link.Linker):
c_compiler=self.c_compiler(), c_compiler=self.c_compiler(),
) )
def cmodule_key_variables(self, inputs, outputs, no_recycling,
compile_args=None, libraries=None,
header_dirs=None, insert_config_md5=True,
c_compiler=None):
# Assemble a dummy fgraph using the provided inputs and outputs. It is
# only used to compute the cmodule key so it only need to expose an
# `inputs` and an `outputs` attribute as well as a toposort() method
# which returns a deterministic result.
class FakeFunctionGraph():
def __init__(self, inputs, outputs):
self.inputs = inputs
self.outputs = outputs
def toposort(self):
# Calling io_toposort() here is fine because the results will
# only be used to compute the cmodule key which requires that
# the result of the toposort be deterministic. The ordering
# doesn't need to include information about inplace operations
# because that information will be included explicitly in
# cmodule_key_().
return graph.io_toposort(self.inputs, self.outputs)
fgraph = FakeFunctionGraph(inputs, outputs)
return self.cmodule_key_(fgraph, no_recycling, compile_args,
libraries, header_dirs, insert_config_md5,
c_compiler)
def cmodule_key_(self, fgraph, no_recycling, compile_args=None, def cmodule_key_(self, fgraph, no_recycling, compile_args=None,
libraries=None, header_dirs=None, insert_config_md5=True, libraries=None, header_dirs=None, insert_config_md5=True,
c_compiler=None): c_compiler=None):
...@@ -1425,8 +1453,15 @@ class CLinker(link.Linker): ...@@ -1425,8 +1453,15 @@ class CLinker(link.Linker):
fgraph_computed_set.update(node.outputs) fgraph_computed_set.update(node.outputs)
# Add not used input in the key # Add not used input in the key
# If inputs don't define a 'clients' attribute (as is the case if
# fgraph is not a real FunctionGraph but a FakeFunctionGraph, a
# lightweight class designed to imitate FunctionGraph), pretend they
# have none. This if fine because the goal is only to have all of the
# graph's information used to compute the key. If we mistakenly
# pretend that inputs with clients don't have any, were are only using
# those inputs more than once to compute the key.
for ipos, var in [(i, var) for i, var in enumerate(fgraph.inputs) for ipos, var in [(i, var) for i, var in enumerate(fgraph.inputs)
if not len(var.clients)]: if not len(getattr(var, 'clients', []))]:
sig.append((var.type, in_sig(var, -1, ipos))) sig.append((var.type, in_sig(var, -1, ipos)))
# crystalize the signature and version # crystalize the signature and version
......
...@@ -221,7 +221,9 @@ class Scan(PureOp): ...@@ -221,7 +221,9 @@ class Scan(PureOp):
tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs, tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs,
self.outputs) self.outputs)
local_fgraph = gof.FunctionGraph(tmp_in, tmp_out, clone=False) local_fgraph = gof.FunctionGraph(tmp_in, tmp_out, clone=False)
self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, []) self._cmodule_key = gof.CLinker().cmodule_key_variables(self.inputs,
self.outputs,
[])
self._hash_inner_graph = hash(self._cmodule_key) self._hash_inner_graph = hash(self._cmodule_key)
# Compute mappings between outer inputs, outer outputs, inner # Compute mappings between outer inputs, outer outputs, inner
......
...@@ -240,14 +240,11 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -240,14 +240,11 @@ class PushOutNonSeqScan(gof.Optimizer):
clean_inputs, clean_outputs = scan_utils.reconstruct_graph( clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs) node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs, local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clean_outputs, clean_outputs)
clone=False) local_fgraph_outs_set = set(clean_outputs)
local_fgraph_topo = local_fgraph.toposort()
local_fgraph_outs_set = set(local_fgraph.outputs)
local_fgraph_outs_map = dict([(v, k) for k, v in \ local_fgraph_outs_map = dict([(v, k) for k, v in \
enumerate(local_fgraph.outputs)]) enumerate(clean_outputs)])
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
...@@ -364,11 +361,9 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -364,11 +361,9 @@ class PushOutNonSeqScan(gof.Optimizer):
nw_outer.append(repl_out) nw_outer.append(repl_out)
givens[to_repl] = repl_in givens[to_repl] = repl_in
_op_outs = scan_utils.clone(clean_outputs, op_outs = scan_utils.clone(clean_outputs, replace=givens)
replace=givens) op_ins = clean_inputs + nw_inner
_op_ins = clean_inputs + nw_inner
op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs)
# Reconstruct node # Reconstruct node
nwScan = scan_op.Scan(op_ins, op_outs, op.info) nwScan = scan_op.Scan(op_ins, op_outs, op.info)
...@@ -452,12 +447,11 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -452,12 +447,11 @@ class PushOutSeqScan(gof.Optimizer):
clean_inputs, clean_outputs = scan_utils.reconstruct_graph( clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs) node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clone=False) clean_outputs)
local_fgraph_topo = local_fgraph.toposort() local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_set = set(local_fgraph.outputs)
local_fgraph_outs_map = dict([(v,k) for k,v in \ local_fgraph_outs_map = dict([(v,k) for k,v in \
enumerate(local_fgraph.outputs)]) enumerate(clean_outputs)])
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
...@@ -614,10 +608,9 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -614,10 +608,9 @@ class PushOutSeqScan(gof.Optimizer):
givens[to_repl] = repl_in givens[to_repl] = repl_in
_op_outs = scan_utils.clone(clean_outputs, op_outs = scan_utils.clone(clean_outputs, replace=givens)
replace=givens) op_ins = nw_inner + clean_inputs
_op_ins = nw_inner + clean_inputs
op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs)
# Reconstruct node # Reconstruct node
nw_info = op.info.copy() nw_info = op.info.copy()
nw_info['n_seqs'] += len(nw_inner) nw_info['n_seqs'] += len(nw_inner)
...@@ -640,7 +633,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -640,7 +633,7 @@ class PushOutSeqScan(gof.Optimizer):
if out in local_fgraph_outs_set: if out in local_fgraph_outs_set:
x = node.outputs[local_fgraph_outs_map[out]] x = node.outputs[local_fgraph_outs_map[out]]
_y = replace_with_out[idx] _y = replace_with_out[idx]
ls = local_fgraph.outputs ls = clean_outputs
if out in op.inner_mitsot_outs(ls): if out in op.inner_mitsot_outs(ls):
odx = op.inner_mitsot_outs(ls).index(out) odx = op.inner_mitsot_outs(ls).index(out)
inp = op.outer_mitsot(node)[odx] inp = op.outer_mitsot(node)[odx]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论