提交 89afb152 authored 作者: carriepl's avatar carriepl

Don't build FunctionGraph to compute cmodule_key

上级 7d749f70
......@@ -1284,6 +1284,34 @@ class CLinker(link.Linker):
c_compiler=self.c_compiler(),
)
def cmodule_key_fgraph(self, inputs, outputs, no_recycling,
compile_args=None, libraries=None,
header_dirs=None, insert_config_md5=True,
c_compiler=None):
# Assemble a dummy fgraph using the provided inputs and outputs. It is
# only used to compute the cmodule key so it only need to expose an
# `inputs` and an `outputs` attribute as well as a toposort() method
# which returns a deterministic result.
class FakeFunctionGraph():
def __init__(self, inputs, outputs):
self.inputs = inputs
self.outputs = outputs
def toposort(self):
# Calling io_toposort() here is fine because the results will
# only be used to compute the cmodule key which requires that
# the result of the toposort be deterministic. The ordering
# doesn't need to include information about inplace operations
# because that information will be included explicitly in
# cmodule_key_().
return graph.io_toposort(self.inputs, self.outputs)
fgraph = FakeFunctionGraph(inputs, outputs)
return self.cmodule_key_(fgraph, no_recycling, compile_args,
libraries, header_dirs, insert_config_md5,
c_compiler)
def cmodule_key_(self, fgraph, no_recycling, compile_args=None,
libraries=None, header_dirs=None, insert_config_md5=True,
c_compiler=None):
......@@ -1425,8 +1453,15 @@ class CLinker(link.Linker):
fgraph_computed_set.update(node.outputs)
# Add not used input in the key
# If inputs don't define a 'clients' attribute (as is the case if
# fgraph is not a real FunctionGraph but a FakeFunctionGraph, a
# lightweight class designed to imitate FunctionGraph), pretend they
# have none. This if fine because the goal is only to have all of the
# graph's information used to compute the key. If we mistakenly
# pretend that inputs with clients don't have any, were are only using
# those inputs more than once to compute the key.
for ipos, var in [(i, var) for i, var in enumerate(fgraph.inputs)
if not len(var.clients)]:
if not len(getattr(var, 'clients', []))]:
sig.append((var.type, in_sig(var, -1, ipos)))
# crystalize the signature and version
......
......@@ -221,7 +221,10 @@ class Scan(PureOp):
tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs,
self.outputs)
local_fgraph = gof.FunctionGraph(tmp_in, tmp_out, clone=False)
self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, [])
self._cmodule_key = gof.CLinker().cmodule_key_fgraph(self.inputs,
self.outputs,
[])
#self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, [])
self._hash_inner_graph = hash(self._cmodule_key)
# Compute mappings between outer inputs, outer outputs, inner
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论