提交 89863001 authored 作者: nouiz's avatar nouiz

Merge pull request #931 from mrocklin/schedule

Schedule
...@@ -1547,10 +1547,12 @@ default_make_thunk = [theano.gof.Op.make_thunk.im_func, ...@@ -1547,10 +1547,12 @@ default_make_thunk = [theano.gof.Op.make_thunk.im_func,
class _Linker(gof.link.LocalLinker): class _Linker(gof.link.LocalLinker):
"""Special debugging linker""" """Special debugging linker"""
def __init__(self, maker): def __init__(self, maker, schedule=None):
super(gof.LocalLinker, self).__init__() super(gof.LocalLinker, self).__init__()
self.fgraph = None self.fgraph = None
self.maker = maker self.maker = maker
if schedule:
self.schedule = schedule
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
if no_recycling is None: if no_recycling is None:
...@@ -1574,7 +1576,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1574,7 +1576,7 @@ class _Linker(gof.link.LocalLinker):
fgraph = self.fgraph fgraph = self.fgraph
input_storage_ = input_storage input_storage_ = input_storage
output_storage_ = output_storage output_storage_ = output_storage
#order = fgraph.toposort() #order = self.schedule(fgraph)
#Compute a topological ordering that IGNORES the destroy_map of destructive Ops. #Compute a topological ordering that IGNORES the destroy_map of destructive Ops.
#This will be OK, because every thunk is evaluated on a copy of its input. #This will be OK, because every thunk is evaluated on a copy of its input.
...@@ -1582,7 +1584,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1582,7 +1584,7 @@ class _Linker(gof.link.LocalLinker):
order_outputs.reverse() order_outputs.reverse()
order = graph.io_toposort(fgraph.inputs, order_outputs) order = graph.io_toposort(fgraph.inputs, order_outputs)
active_order = fgraph.toposort() # an ordering of just the active nodes active_order = self.schedule(fgraph) # an ordering of just the active nodes
active_order_set = set(active_order) active_order_set = set(active_order)
no_recycling = self.no_recycling no_recycling = self.no_recycling
......
...@@ -538,7 +538,7 @@ class ProfileMode(Mode): ...@@ -538,7 +538,7 @@ class ProfileMode(Mode):
items.sort(key=lambda a: a[1]) items.sort(key=lambda a: a[1])
items.reverse() items.reverse()
order = fgraph.toposort() order = self.linker.schedule(fgraph)
computed, last_user = gof.link.gc_helper(order) computed, last_user = gof.link.gc_helper(order)
for node in order: for node in order:
post_thunk_old_storage.append([ input_idx post_thunk_old_storage.append([ input_idx
......
...@@ -402,8 +402,10 @@ class CLinker(link.Linker): ...@@ -402,8 +402,10 @@ class CLinker(link.Linker):
associated to it during the computation (to avoid reusing it). associated to it during the computation (to avoid reusing it).
""" """
def __init__(self): def __init__(self, schedule=None):
self.fgraph = None self.fgraph = None
if schedule:
self.schedule = schedule
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
"""WRITEME""" """WRITEME"""
...@@ -436,7 +438,7 @@ class CLinker(link.Linker): ...@@ -436,7 +438,7 @@ class CLinker(link.Linker):
self.temps = list(set(self.variables).difference( self.temps = list(set(self.variables).difference(
self.inputs).difference(self.outputs).difference(self.orphans)) self.inputs).difference(self.outputs).difference(self.orphans))
self.consts = [] self.consts = []
self.node_order = fgraph.toposort() self.node_order = self.schedule(fgraph)
def code_gen(self): def code_gen(self):
"""WRITEME """WRITEME
...@@ -994,8 +996,7 @@ class CLinker(link.Linker): ...@@ -994,8 +996,7 @@ class CLinker(link.Linker):
c_compiler=self.c_compiler(), c_compiler=self.c_compiler(),
) )
@staticmethod def cmodule_key_(self, fgraph, no_recycling, compile_args=None, libraries=None,
def cmodule_key_(fgraph, no_recycling, compile_args=None, libraries=None,
header_dirs=None, insert_config_md5=True, c_compiler=None): header_dirs=None, insert_config_md5=True, c_compiler=None):
""" """
Do the actual computation of cmodule_key in a static method Do the actual computation of cmodule_key in a static method
...@@ -1007,7 +1008,7 @@ class CLinker(link.Linker): ...@@ -1007,7 +1008,7 @@ class CLinker(link.Linker):
libraries = [] libraries = []
if header_dirs is None: if header_dirs is None:
header_dirs = [] header_dirs = []
order = list(fgraph.toposort()) order = self.schedule(fgraph)
#set of variables that have been computed by nodes we have #set of variables that have been computed by nodes we have
# seen 'so far' in the loop below # seen 'so far' in the loop below
fgraph_computed_set = set() fgraph_computed_set = set()
...@@ -1397,13 +1398,16 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1397,13 +1398,16 @@ class OpWiseCLinker(link.LocalLinker):
def __init__(self, def __init__(self,
fallback_on_perform=True, fallback_on_perform=True,
allow_gc=None, allow_gc=None,
nice_errors=True): nice_errors=True,
schedule=None):
if allow_gc is None: if allow_gc is None:
allow_gc = config.allow_gc allow_gc = config.allow_gc
self.fgraph = None self.fgraph = None
self.fallback_on_perform = fallback_on_perform self.fallback_on_perform = fallback_on_perform
self.nice_errors = nice_errors self.nice_errors = nice_errors
self.allow_gc = allow_gc self.allow_gc = allow_gc
if schedule:
self.schedule = schedule
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
if no_recycling is None: if no_recycling is None:
...@@ -1430,7 +1434,7 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1430,7 +1434,7 @@ class OpWiseCLinker(link.LocalLinker):
try: try:
fgraph = self.fgraph fgraph = self.fgraph
order = fgraph.toposort() order = self.schedule(fgraph)
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage( input_storage, output_storage, storage_map = link.map_storage(
...@@ -1523,7 +1527,7 @@ class DualLinker(link.Linker): ...@@ -1523,7 +1527,7 @@ class DualLinker(link.Linker):
function. function.
""" """
def __init__(self, checker=_default_checker): def __init__(self, checker=_default_checker, schedule=None):
""" """
Initialize a DualLinker. Initialize a DualLinker.
...@@ -1548,6 +1552,8 @@ class DualLinker(link.Linker): ...@@ -1548,6 +1552,8 @@ class DualLinker(link.Linker):
""" """
self.fgraph = None self.fgraph = None
self.checker = checker self.checker = checker
if schedule:
self.schedule = schedule
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
if no_recycling is None: if no_recycling is None:
...@@ -1565,11 +1571,13 @@ class DualLinker(link.Linker): ...@@ -1565,11 +1571,13 @@ class DualLinker(link.Linker):
fgraph = self.fgraph fgraph = self.fgraph
no_recycling = self.no_recycling no_recycling = self.no_recycling
_f, i1, o1, thunks1, order1 = link.PerformLinker().accept(fgraph, _f, i1, o1, thunks1, order1 = (
no_recycling=no_recycling).make_all(**kwargs) link.PerformLinker(schedule=self.schedule).accept(fgraph,
no_recycling=no_recycling).make_all(**kwargs))
kwargs.pop('input_storage', None) kwargs.pop('input_storage', None)
_f, i2, o2, thunks2, order2 = OpWiseCLinker().accept(fgraph, _f, i2, o2, thunks2, order2 = (
no_recycling=no_recycling).make_all(**kwargs) OpWiseCLinker(schedule=self.schedule).accept(fgraph,
no_recycling=no_recycling).make_all(**kwargs))
def f(): def f():
for input1, input2 in izip(i1, i2): for input1, input2 in izip(i1, i2):
......
...@@ -52,6 +52,7 @@ def thunk_hook(type, value, trace): ...@@ -52,6 +52,7 @@ def thunk_hook(type, value, trace):
sys.excepthook = thunk_hook sys.excepthook = thunk_hook
# TODO: Make this work with linker defined schedule
def raise_with_op(op, exc_info=None): def raise_with_op(op, exc_info=None):
""" """
Re-raise an exception while annotating the exception object with Re-raise an exception while annotating the exception object with
...@@ -173,6 +174,8 @@ class Linker(object): ...@@ -173,6 +174,8 @@ class Linker(object):
return execute return execute
def schedule(self, fgraph):
return fgraph.toposort()
#TODO: Move this class to the compile module, where it is used (and for which it exists). #TODO: Move this class to the compile module, where it is used (and for which it exists).
class Container(object): class Container(object):
...@@ -414,13 +417,15 @@ class PerformLinker(LocalLinker): ...@@ -414,13 +417,15 @@ class PerformLinker(LocalLinker):
"""WRITEME """WRITEME
Basic L{Linker} subclass that calls the perform method on each L{Op} in Basic L{Linker} subclass that calls the perform method on each L{Op} in
the L{FunctionGraph} in the order given by L{FunctionGraph.toposort}. the L{FunctionGraph} in the order given by L{Linker.schedule}.
""" """
def __init__(self, allow_gc=True): def __init__(self, allow_gc=True, schedule=None):
#TODO: set allow_gc = True by default, when it works with the OpWiseCLinker #TODO: set allow_gc = True by default, when it works with the OpWiseCLinker
self.fgraph = None self.fgraph = None
self.allow_gc = allow_gc self.allow_gc = allow_gc
if schedule:
self.schedule = schedule
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
""" """
...@@ -449,7 +454,7 @@ class PerformLinker(LocalLinker): ...@@ -449,7 +454,7 @@ class PerformLinker(LocalLinker):
""" """
fgraph = self.fgraph fgraph = self.fgraph
order = list(fgraph.toposort()) order = self.schedule(fgraph)
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = map_storage(fgraph, order, input_storage, output_storage) input_storage, output_storage, storage_map = map_storage(fgraph, order, input_storage, output_storage)
......
...@@ -155,3 +155,15 @@ def sort_apply_nodes(inputs, outputs, cmps): ...@@ -155,3 +155,15 @@ def sort_apply_nodes(inputs, outputs, cmps):
""" """
return posort(list_of_nodes(inputs, outputs), *cmps) return posort(list_of_nodes(inputs, outputs), *cmps)
def sort_schedule_fn(*cmps):
""" Make a schedule function from comparators
See also:
sort_apply_nodes
"""
cmps = (dependence,) + cmps
def schedule(fgraph):
""" Order nodes in a FunctionGraph """
return sort_apply_nodes(fgraph.inputs, fgraph.outputs, cmps)
return schedule
...@@ -163,3 +163,18 @@ class TestWrapLinker(unittest.TestCase): ...@@ -163,3 +163,18 @@ class TestWrapLinker(unittest.TestCase):
fn() fn()
assert nodes == [div, add, mul] assert nodes == [div, add, mul]
assert o[0].data == 1.5 assert o[0].data == 1.5
def test_sort_schedule_fn():
import theano
from theano.gof.sched import sort_schedule_fn, depends
x = theano.tensor.matrix('x')
y = theano.tensor.dot(x[:5]*2, x.T+1).T
str_cmp = lambda a, b: cmp(str(a), str(b)) # lexicographical sort
linker = theano.OpWiseCLinker(schedule=sort_schedule_fn(str_cmp))
mode = theano.Mode(linker=linker)
f = theano.function((x,), (y,), mode=mode)
nodes = f.maker.linker.make_all()[-1]
for a, b in zip(nodes[:-1], nodes[1:]):
if not depends((b,a)):
assert str(a) < str(b)
...@@ -508,7 +508,7 @@ class VM_Linker(link.LocalLinker): ...@@ -508,7 +508,7 @@ class VM_Linker(link.LocalLinker):
""" """
def __init__(self, allow_gc=None, use_cloop=False, callback=None, def __init__(self, allow_gc=None, use_cloop=False, callback=None,
lazy=None): lazy=None, schedule=None):
""" """
allow_gc - force the virtual machine to clean up unnecessary allow_gc - force the virtual machine to clean up unnecessary
references, in order to allow garbage collection on references, in order to allow garbage collection on
...@@ -538,6 +538,8 @@ class VM_Linker(link.LocalLinker): ...@@ -538,6 +538,8 @@ class VM_Linker(link.LocalLinker):
self.callback = callback self.callback = callback
self.lazy = lazy self.lazy = lazy
self.updated_vars = {} self.updated_vars = {}
if schedule:
self.schedule = schedule
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
""" """
...@@ -559,7 +561,8 @@ class VM_Linker(link.LocalLinker): ...@@ -559,7 +561,8 @@ class VM_Linker(link.LocalLinker):
allow_gc=self.allow_gc, allow_gc=self.allow_gc,
use_cloop=self.use_cloop, use_cloop=self.use_cloop,
callback=self.callback, callback=self.callback,
lazy=self.lazy lazy=self.lazy,
schedule=self.schedule
).accept(fgraph, no_recycling) ).accept(fgraph, no_recycling)
self.fgraph = fgraph self.fgraph = fgraph
self.no_recycling = no_recycling self.no_recycling = no_recycling
...@@ -780,7 +783,7 @@ class VM_Linker(link.LocalLinker): ...@@ -780,7 +783,7 @@ class VM_Linker(link.LocalLinker):
output_storage=None, output_storage=None,
): ):
fgraph = self.fgraph fgraph = self.fgraph
order = list(fgraph.toposort()) order = self.schedule(fgraph)
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage( input_storage, output_storage, storage_map = link.map_storage(
......
...@@ -184,7 +184,7 @@ class Scan(PureOp): ...@@ -184,7 +184,7 @@ class Scan(PureOp):
tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs, tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs,
self.outputs) self.outputs)
local_fgraph = gof.FunctionGraph(tmp_in, tmp_out) local_fgraph = gof.FunctionGraph(tmp_in, tmp_out)
self._cmodule_key = gof.CLinker.cmodule_key_(local_fgraph, []) self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, [])
self._hash_inner_graph = hash(self._cmodule_key) self._hash_inner_graph = hash(self._cmodule_key)
else: else:
self._hash_inner_graph = self.info['gpu_hash'] self._hash_inner_graph = self.info['gpu_hash']
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论