提交 520fd3b2 authored 作者: Matthew Rocklin's avatar Matthew Rocklin

replace toposort with schedule in Linker code

replace code like order = fgraph.toposort() with code like order = self.schedule(fgraph) in Linker code
上级 e76dbade
...@@ -1574,7 +1574,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1574,7 +1574,7 @@ class _Linker(gof.link.LocalLinker):
fgraph = self.fgraph fgraph = self.fgraph
input_storage_ = input_storage input_storage_ = input_storage
output_storage_ = output_storage output_storage_ = output_storage
#order = fgraph.toposort() #order = self.schedule(fgraph)
#Compute a topological ordering that IGNORES the destroy_map of destructive Ops. #Compute a topological ordering that IGNORES the destroy_map of destructive Ops.
#This will be OK, because every thunk is evaluated on a copy of its input. #This will be OK, because every thunk is evaluated on a copy of its input.
...@@ -1582,7 +1582,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1582,7 +1582,7 @@ class _Linker(gof.link.LocalLinker):
order_outputs.reverse() order_outputs.reverse()
order = graph.io_toposort(fgraph.inputs, order_outputs) order = graph.io_toposort(fgraph.inputs, order_outputs)
active_order = fgraph.toposort() # an ordering of just the active nodes active_order = self.schedule(fgraph) # an ordering of just the active nodes
active_order_set = set(active_order) active_order_set = set(active_order)
no_recycling = self.no_recycling no_recycling = self.no_recycling
......
...@@ -538,7 +538,7 @@ class ProfileMode(Mode): ...@@ -538,7 +538,7 @@ class ProfileMode(Mode):
items.sort(key=lambda a: a[1]) items.sort(key=lambda a: a[1])
items.reverse() items.reverse()
order = fgraph.toposort() order = self.linker.schedule(fgraph)
computed, last_user = gof.link.gc_helper(order) computed, last_user = gof.link.gc_helper(order)
for node in order: for node in order:
post_thunk_old_storage.append([ input_idx post_thunk_old_storage.append([ input_idx
......
...@@ -436,7 +436,7 @@ class CLinker(link.Linker): ...@@ -436,7 +436,7 @@ class CLinker(link.Linker):
self.temps = list(set(self.variables).difference( self.temps = list(set(self.variables).difference(
self.inputs).difference(self.outputs).difference(self.orphans)) self.inputs).difference(self.outputs).difference(self.orphans))
self.consts = [] self.consts = []
self.node_order = fgraph.toposort() self.node_order = self.schedule(fgraph)
def code_gen(self): def code_gen(self):
"""WRITEME """WRITEME
...@@ -994,8 +994,7 @@ class CLinker(link.Linker): ...@@ -994,8 +994,7 @@ class CLinker(link.Linker):
c_compiler=self.c_compiler(), c_compiler=self.c_compiler(),
) )
@staticmethod def cmodule_key_(self, fgraph, no_recycling, compile_args=None, libraries=None,
def cmodule_key_(fgraph, no_recycling, compile_args=None, libraries=None,
header_dirs=None, insert_config_md5=True, c_compiler=None): header_dirs=None, insert_config_md5=True, c_compiler=None):
""" """
Do the actual computation of cmodule_key in a static method Do the actual computation of cmodule_key in a static method
...@@ -1007,7 +1006,7 @@ class CLinker(link.Linker): ...@@ -1007,7 +1006,7 @@ class CLinker(link.Linker):
libraries = [] libraries = []
if header_dirs is None: if header_dirs is None:
header_dirs = [] header_dirs = []
order = list(fgraph.toposort()) order = self.schedule(fgraph)
#set of variables that have been computed by nodes we have #set of variables that have been computed by nodes we have
# seen 'so far' in the loop below # seen 'so far' in the loop below
fgraph_computed_set = set() fgraph_computed_set = set()
...@@ -1430,7 +1429,7 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1430,7 +1429,7 @@ class OpWiseCLinker(link.LocalLinker):
try: try:
fgraph = self.fgraph fgraph = self.fgraph
order = fgraph.toposort() order = self.schedule(fgraph)
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage( input_storage, output_storage, storage_map = link.map_storage(
......
...@@ -52,6 +52,7 @@ def thunk_hook(type, value, trace): ...@@ -52,6 +52,7 @@ def thunk_hook(type, value, trace):
sys.excepthook = thunk_hook sys.excepthook = thunk_hook
# TODO: Make this work with linker defined schedule
def raise_with_op(op, exc_info=None): def raise_with_op(op, exc_info=None):
""" """
Re-raise an exception while annotating the exception object with Re-raise an exception while annotating the exception object with
...@@ -173,6 +174,8 @@ class Linker(object): ...@@ -173,6 +174,8 @@ class Linker(object):
return execute return execute
def schedule(self, fgraph):
return fgraph.toposort()
#TODO: Move this class to the compile module, where it is used (and for which it exists). #TODO: Move this class to the compile module, where it is used (and for which it exists).
class Container(object): class Container(object):
...@@ -414,7 +417,7 @@ class PerformLinker(LocalLinker): ...@@ -414,7 +417,7 @@ class PerformLinker(LocalLinker):
"""WRITEME """WRITEME
Basic L{Linker} subclass that calls the perform method on each L{Op} in Basic L{Linker} subclass that calls the perform method on each L{Op} in
the L{FunctionGraph} in the order given by L{FunctionGraph.toposort}. the L{FunctionGraph} in the order given by L{Linker.schedule}.
""" """
def __init__(self, allow_gc=True): def __init__(self, allow_gc=True):
...@@ -449,7 +452,7 @@ class PerformLinker(LocalLinker): ...@@ -449,7 +452,7 @@ class PerformLinker(LocalLinker):
""" """
fgraph = self.fgraph fgraph = self.fgraph
order = list(fgraph.toposort()) order = self.schedule(fgraph)
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = map_storage(fgraph, order, input_storage, output_storage) input_storage, output_storage, storage_map = map_storage(fgraph, order, input_storage, output_storage)
......
...@@ -780,7 +780,7 @@ class VM_Linker(link.LocalLinker): ...@@ -780,7 +780,7 @@ class VM_Linker(link.LocalLinker):
output_storage=None, output_storage=None,
): ):
fgraph = self.fgraph fgraph = self.fgraph
order = list(fgraph.toposort()) order = self.schedule(fgraph)
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage( input_storage, output_storage, storage_map = link.map_storage(
......
...@@ -184,7 +184,7 @@ class Scan(PureOp): ...@@ -184,7 +184,7 @@ class Scan(PureOp):
tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs, tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs,
self.outputs) self.outputs)
local_fgraph = gof.FunctionGraph(tmp_in, tmp_out) local_fgraph = gof.FunctionGraph(tmp_in, tmp_out)
self._cmodule_key = gof.CLinker.cmodule_key_(local_fgraph, []) self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, [])
self._hash_inner_graph = hash(self._cmodule_key) self._hash_inner_graph = hash(self._cmodule_key)
else: else:
self._hash_inner_graph = self.info['gpu_hash'] self._hash_inner_graph = self.info['gpu_hash']
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论