提交 1882012f authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3516 from carriepl/various_scan_speedups

Various scan fixes and speedups
...@@ -1284,6 +1284,34 @@ class CLinker(link.Linker): ...@@ -1284,6 +1284,34 @@ class CLinker(link.Linker):
c_compiler=self.c_compiler(), c_compiler=self.c_compiler(),
) )
def cmodule_key_variables(self, inputs, outputs, no_recycling,
                          compile_args=None, libraries=None,
                          header_dirs=None, insert_config_md5=True,
                          c_compiler=None):
    """
    Compute the cmodule key of a graph given only by its variables.

    Instead of requiring a full FunctionGraph, this wraps `inputs` and
    `outputs` in a minimal stand-in object and hands it to
    `cmodule_key_()`.

    Parameters
    ----------
    inputs
        Input variables of the graph; stored as the stand-in's `inputs`.
    outputs
        Output variables of the graph; stored as the stand-in's `outputs`.
    no_recycling, compile_args, libraries, header_dirs, insert_config_md5, c_compiler
        Forwarded unchanged, in this order, to `cmodule_key_()`.

    Returns
    -------
    Whatever `cmodule_key_()` returns for the stand-in graph.
    """
    # Assemble a dummy fgraph using the provided inputs and outputs. It is
    # only used to compute the cmodule key so it only needs to expose an
    # `inputs` and an `outputs` attribute as well as a toposort() method
    # which returns a deterministic result.
    class FakeFunctionGraph(object):
        def __init__(self, inputs, outputs):
            self.inputs = inputs
            self.outputs = outputs

        def toposort(self):
            # Calling io_toposort() here is fine because the results will
            # only be used to compute the cmodule key which requires that
            # the result of the toposort be deterministic. The ordering
            # doesn't need to include information about inplace operations
            # because that information will be included explicitly in
            # cmodule_key_().
            return graph.io_toposort(self.inputs, self.outputs)

    fgraph = FakeFunctionGraph(inputs, outputs)
    return self.cmodule_key_(fgraph, no_recycling, compile_args,
                             libraries, header_dirs, insert_config_md5,
                             c_compiler)
def cmodule_key_(self, fgraph, no_recycling, compile_args=None, def cmodule_key_(self, fgraph, no_recycling, compile_args=None,
libraries=None, header_dirs=None, insert_config_md5=True, libraries=None, header_dirs=None, insert_config_md5=True,
c_compiler=None): c_compiler=None):
...@@ -1425,8 +1453,15 @@ class CLinker(link.Linker): ...@@ -1425,8 +1453,15 @@ class CLinker(link.Linker):
fgraph_computed_set.update(node.outputs) fgraph_computed_set.update(node.outputs)
# Add not used input in the key # Add not used input in the key
# If inputs don't define a 'clients' attribute (as is the case if
# fgraph is not a real FunctionGraph but a FakeFunctionGraph, a
# lightweight class designed to imitate FunctionGraph), pretend they
# have none. This is fine because the goal is only to have all of the
# graph's information used to compute the key. If we mistakenly
# pretend that inputs with clients don't have any, we are only using
# those inputs more than once to compute the key.
for ipos, var in [(i, var) for i, var in enumerate(fgraph.inputs) for ipos, var in [(i, var) for i, var in enumerate(fgraph.inputs)
if not len(var.clients)]: if not len(getattr(var, 'clients', []))]:
sig.append((var.type, in_sig(var, -1, ipos))) sig.append((var.type, in_sig(var, -1, ipos)))
# crystalize the signature and version # crystalize the signature and version
......
...@@ -221,7 +221,9 @@ class Scan(PureOp): ...@@ -221,7 +221,9 @@ class Scan(PureOp):
tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs, tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs,
self.outputs) self.outputs)
local_fgraph = gof.FunctionGraph(tmp_in, tmp_out, clone=False) local_fgraph = gof.FunctionGraph(tmp_in, tmp_out, clone=False)
self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, []) self._cmodule_key = gof.CLinker().cmodule_key_variables(self.inputs,
self.outputs,
[])
self._hash_inner_graph = hash(self._cmodule_key) self._hash_inner_graph = hash(self._cmodule_key)
# Compute mappings between outer inputs, outer outputs, inner # Compute mappings between outer inputs, outer outputs, inner
......
...@@ -240,14 +240,11 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -240,14 +240,11 @@ class PushOutNonSeqScan(gof.Optimizer):
clean_inputs, clean_outputs = scan_utils.reconstruct_graph( clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs) node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs, local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clean_outputs, clean_outputs)
clone=False) local_fgraph_outs_set = set(clean_outputs)
local_fgraph_topo = local_fgraph.toposort()
local_fgraph_outs_set = set(local_fgraph.outputs)
local_fgraph_outs_map = dict([(v, k) for k, v in \ local_fgraph_outs_map = dict([(v, k) for k, v in \
enumerate(local_fgraph.outputs)]) enumerate(clean_outputs)])
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
...@@ -364,11 +361,9 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -364,11 +361,9 @@ class PushOutNonSeqScan(gof.Optimizer):
nw_outer.append(repl_out) nw_outer.append(repl_out)
givens[to_repl] = repl_in givens[to_repl] = repl_in
_op_outs = scan_utils.clone(clean_outputs, op_outs = scan_utils.clone(clean_outputs, replace=givens)
replace=givens) op_ins = clean_inputs + nw_inner
_op_ins = clean_inputs + nw_inner
op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs)
# Reconstruct node # Reconstruct node
nwScan = scan_op.Scan(op_ins, op_outs, op.info) nwScan = scan_op.Scan(op_ins, op_outs, op.info)
...@@ -452,12 +447,11 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -452,12 +447,11 @@ class PushOutSeqScan(gof.Optimizer):
clean_inputs, clean_outputs = scan_utils.reconstruct_graph( clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs) node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, local_fgraph_topo = theano.gof.graph.io_toposort(clean_inputs,
clone=False) clean_outputs)
local_fgraph_topo = local_fgraph.toposort() local_fgraph_outs_set = set(clean_outputs)
local_fgraph_outs_set = set(local_fgraph.outputs)
local_fgraph_outs_map = dict([(v,k) for k,v in \ local_fgraph_outs_map = dict([(v,k) for k,v in \
enumerate(local_fgraph.outputs)]) enumerate(clean_outputs)])
to_remove_set = set() to_remove_set = set()
to_replace_set = set() to_replace_set = set()
...@@ -614,10 +608,9 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -614,10 +608,9 @@ class PushOutSeqScan(gof.Optimizer):
givens[to_repl] = repl_in givens[to_repl] = repl_in
_op_outs = scan_utils.clone(clean_outputs, op_outs = scan_utils.clone(clean_outputs, replace=givens)
replace=givens) op_ins = nw_inner + clean_inputs
_op_ins = nw_inner + clean_inputs
op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs)
# Reconstruct node # Reconstruct node
nw_info = op.info.copy() nw_info = op.info.copy()
nw_info['n_seqs'] += len(nw_inner) nw_info['n_seqs'] += len(nw_inner)
...@@ -640,7 +633,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -640,7 +633,7 @@ class PushOutSeqScan(gof.Optimizer):
if out in local_fgraph_outs_set: if out in local_fgraph_outs_set:
x = node.outputs[local_fgraph_outs_map[out]] x = node.outputs[local_fgraph_outs_map[out]]
_y = replace_with_out[idx] _y = replace_with_out[idx]
ls = local_fgraph.outputs ls = clean_outputs
if out in op.inner_mitsot_outs(ls): if out in op.inner_mitsot_outs(ls):
odx = op.inner_mitsot_outs(ls).index(out) odx = op.inner_mitsot_outs(ls).index(out)
inp = op.outer_mitsot(node)[odx] inp = op.outer_mitsot(node)[odx]
......
...@@ -121,7 +121,7 @@ class until(object): ...@@ -121,7 +121,7 @@ class until(object):
of code to fail). of code to fail).
""" """
def __init__(self, condition): def __init__(self, condition):
self.condition = tensor.as_tensor_variable(condition) self.condition = tensor.as_tensor_variable(condition)
assert self.condition.ndim == 0 assert self.condition.ndim == 0
...@@ -131,7 +131,7 @@ def traverse(out, x, x_copy, d, visited=None): ...@@ -131,7 +131,7 @@ def traverse(out, x, x_copy, d, visited=None):
""" """
Function used by scan to parse the tree and figure out which nodes Function used by scan to parse the tree and figure out which nodes
it needs to replace. it needs to replace.
There are two options : There are two options :
1) x and x_copy are on host, then you would replace x with x_copy 1) x and x_copy are on host, then you would replace x with x_copy
2) x is on gpu, x_copy on host, then you need to replace 2) x is on gpu, x_copy on host, then you need to replace
...@@ -201,7 +201,7 @@ def clone(output, ...@@ -201,7 +201,7 @@ def clone(output,
copy_inputs=DEPRECATED_ARG): copy_inputs=DEPRECATED_ARG):
""" """
Function that allows replacing subgraphs of a computational graph. Function that allows replacing subgraphs of a computational graph.
It returns a copy of the initial subgraph with the corresponding It returns a copy of the initial subgraph with the corresponding
substitutions. substitutions.
...@@ -1308,7 +1308,7 @@ def forced_replace(out, x, y): ...@@ -1308,7 +1308,7 @@ def forced_replace(out, x, y):
Check all internal values of the graph that compute the variable ``out`` Check all internal values of the graph that compute the variable ``out``
for occurrences of values identical with ``x``. If such occurrences are for occurrences of values identical with ``x``. If such occurrences are
encountered then they are replaced with variable ``y``. encountered then they are replaced with variable ``y``.
Parameters Parameters
---------- ----------
out : Theano Variable out : Theano Variable
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论