提交 89863001 authored 作者: nouiz's avatar nouiz

Merge pull request #931 from mrocklin/schedule

Schedule
......@@ -1547,10 +1547,12 @@ default_make_thunk = [theano.gof.Op.make_thunk.im_func,
class _Linker(gof.link.LocalLinker):
"""Special debugging linker"""
def __init__(self, maker):
def __init__(self, maker, schedule=None):
super(gof.LocalLinker, self).__init__()
self.fgraph = None
self.maker = maker
if schedule:
self.schedule = schedule
def accept(self, fgraph, no_recycling=None):
if no_recycling is None:
......@@ -1574,7 +1576,7 @@ class _Linker(gof.link.LocalLinker):
fgraph = self.fgraph
input_storage_ = input_storage
output_storage_ = output_storage
#order = fgraph.toposort()
#order = self.schedule(fgraph)
#Compute a topological ordering that IGNORES the destroy_map of destructive Ops.
#This will be OK, because every thunk is evaluated on a copy of its input.
......@@ -1582,7 +1584,7 @@ class _Linker(gof.link.LocalLinker):
order_outputs.reverse()
order = graph.io_toposort(fgraph.inputs, order_outputs)
active_order = fgraph.toposort() # an ordering of just the active nodes
active_order = self.schedule(fgraph) # an ordering of just the active nodes
active_order_set = set(active_order)
no_recycling = self.no_recycling
......
......@@ -538,7 +538,7 @@ class ProfileMode(Mode):
items.sort(key=lambda a: a[1])
items.reverse()
order = fgraph.toposort()
order = self.linker.schedule(fgraph)
computed, last_user = gof.link.gc_helper(order)
for node in order:
post_thunk_old_storage.append([ input_idx
......
......@@ -402,8 +402,10 @@ class CLinker(link.Linker):
associated to it during the computation (to avoid reusing it).
"""
def __init__(self):
def __init__(self, schedule=None):
self.fgraph = None
if schedule:
self.schedule = schedule
def accept(self, fgraph, no_recycling=None):
"""WRITEME"""
......@@ -436,7 +438,7 @@ class CLinker(link.Linker):
self.temps = list(set(self.variables).difference(
self.inputs).difference(self.outputs).difference(self.orphans))
self.consts = []
self.node_order = fgraph.toposort()
self.node_order = self.schedule(fgraph)
def code_gen(self):
"""WRITEME
......@@ -994,8 +996,7 @@ class CLinker(link.Linker):
c_compiler=self.c_compiler(),
)
@staticmethod
def cmodule_key_(fgraph, no_recycling, compile_args=None, libraries=None,
def cmodule_key_(self, fgraph, no_recycling, compile_args=None, libraries=None,
header_dirs=None, insert_config_md5=True, c_compiler=None):
"""
Do the actual computation of cmodule_key in a static method
......@@ -1007,7 +1008,7 @@ class CLinker(link.Linker):
libraries = []
if header_dirs is None:
header_dirs = []
order = list(fgraph.toposort())
order = self.schedule(fgraph)
#set of variables that have been computed by nodes we have
# seen 'so far' in the loop below
fgraph_computed_set = set()
......@@ -1397,13 +1398,16 @@ class OpWiseCLinker(link.LocalLinker):
def __init__(self,
fallback_on_perform=True,
allow_gc=None,
nice_errors=True):
nice_errors=True,
schedule=None):
if allow_gc is None:
allow_gc = config.allow_gc
self.fgraph = None
self.fallback_on_perform = fallback_on_perform
self.nice_errors = nice_errors
self.allow_gc = allow_gc
if schedule:
self.schedule = schedule
def accept(self, fgraph, no_recycling=None):
if no_recycling is None:
......@@ -1430,7 +1434,7 @@ class OpWiseCLinker(link.LocalLinker):
try:
fgraph = self.fgraph
order = fgraph.toposort()
order = self.schedule(fgraph)
no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage(
......@@ -1523,7 +1527,7 @@ class DualLinker(link.Linker):
function.
"""
def __init__(self, checker=_default_checker):
def __init__(self, checker=_default_checker, schedule=None):
"""
Initialize a DualLinker.
......@@ -1548,6 +1552,8 @@ class DualLinker(link.Linker):
"""
self.fgraph = None
self.checker = checker
if schedule:
self.schedule = schedule
def accept(self, fgraph, no_recycling=None):
if no_recycling is None:
......@@ -1565,11 +1571,13 @@ class DualLinker(link.Linker):
fgraph = self.fgraph
no_recycling = self.no_recycling
_f, i1, o1, thunks1, order1 = link.PerformLinker().accept(fgraph,
no_recycling=no_recycling).make_all(**kwargs)
_f, i1, o1, thunks1, order1 = (
link.PerformLinker(schedule=self.schedule).accept(fgraph,
no_recycling=no_recycling).make_all(**kwargs))
kwargs.pop('input_storage', None)
_f, i2, o2, thunks2, order2 = OpWiseCLinker().accept(fgraph,
no_recycling=no_recycling).make_all(**kwargs)
_f, i2, o2, thunks2, order2 = (
OpWiseCLinker(schedule=self.schedule).accept(fgraph,
no_recycling=no_recycling).make_all(**kwargs))
def f():
for input1, input2 in izip(i1, i2):
......
......@@ -52,6 +52,7 @@ def thunk_hook(type, value, trace):
sys.excepthook = thunk_hook
# TODO: Make this work with linker defined schedule
def raise_with_op(op, exc_info=None):
"""
Re-raise an exception while annotating the exception object with
......@@ -173,6 +174,8 @@ class Linker(object):
return execute
def schedule(self, fgraph):
return fgraph.toposort()
#TODO: Move this class to the compile module, where it is used (and for which it exists).
class Container(object):
......@@ -414,13 +417,15 @@ class PerformLinker(LocalLinker):
"""WRITEME
Basic L{Linker} subclass that calls the perform method on each L{Op} in
the L{FunctionGraph} in the order given by L{FunctionGraph.toposort}.
the L{FunctionGraph} in the order given by L{Linker.schedule}.
"""
def __init__(self, allow_gc=True):
def __init__(self, allow_gc=True, schedule=None):
#TODO: set allow_gc = True by default, when it works with the OpWiseCLinker
self.fgraph = None
self.allow_gc = allow_gc
if schedule:
self.schedule = schedule
def accept(self, fgraph, no_recycling=None):
"""
......@@ -449,7 +454,7 @@ class PerformLinker(LocalLinker):
"""
fgraph = self.fgraph
order = list(fgraph.toposort())
order = self.schedule(fgraph)
no_recycling = self.no_recycling
input_storage, output_storage, storage_map = map_storage(fgraph, order, input_storage, output_storage)
......
......@@ -155,3 +155,15 @@ def sort_apply_nodes(inputs, outputs, cmps):
"""
return posort(list_of_nodes(inputs, outputs), *cmps)
def sort_schedule_fn(*cmps):
""" Make a schedule function from comparators
See also:
sort_apply_nodes
"""
cmps = (dependence,) + cmps
def schedule(fgraph):
""" Order nodes in a FunctionGraph """
return sort_apply_nodes(fgraph.inputs, fgraph.outputs, cmps)
return schedule
......@@ -163,3 +163,18 @@ class TestWrapLinker(unittest.TestCase):
fn()
assert nodes == [div, add, mul]
assert o[0].data == 1.5
def test_sort_schedule_fn():
import theano
from theano.gof.sched import sort_schedule_fn, depends
x = theano.tensor.matrix('x')
y = theano.tensor.dot(x[:5]*2, x.T+1).T
str_cmp = lambda a, b: cmp(str(a), str(b)) # lexicographical sort
linker = theano.OpWiseCLinker(schedule=sort_schedule_fn(str_cmp))
mode = theano.Mode(linker=linker)
f = theano.function((x,), (y,), mode=mode)
nodes = f.maker.linker.make_all()[-1]
for a, b in zip(nodes[:-1], nodes[1:]):
if not depends((b,a)):
assert str(a) < str(b)
......@@ -508,7 +508,7 @@ class VM_Linker(link.LocalLinker):
"""
def __init__(self, allow_gc=None, use_cloop=False, callback=None,
lazy=None):
lazy=None, schedule=None):
"""
allow_gc - force the virtual machine to clean up unnecessary
references, in order to allow garbage collection on
......@@ -538,6 +538,8 @@ class VM_Linker(link.LocalLinker):
self.callback = callback
self.lazy = lazy
self.updated_vars = {}
if schedule:
self.schedule = schedule
def accept(self, fgraph, no_recycling=None):
"""
......@@ -559,7 +561,8 @@ class VM_Linker(link.LocalLinker):
allow_gc=self.allow_gc,
use_cloop=self.use_cloop,
callback=self.callback,
lazy=self.lazy
lazy=self.lazy,
schedule=self.schedule
).accept(fgraph, no_recycling)
self.fgraph = fgraph
self.no_recycling = no_recycling
......@@ -780,7 +783,7 @@ class VM_Linker(link.LocalLinker):
output_storage=None,
):
fgraph = self.fgraph
order = list(fgraph.toposort())
order = self.schedule(fgraph)
no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage(
......
......@@ -184,7 +184,7 @@ class Scan(PureOp):
tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs,
self.outputs)
local_fgraph = gof.FunctionGraph(tmp_in, tmp_out)
self._cmodule_key = gof.CLinker.cmodule_key_(local_fgraph, [])
self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, [])
self._hash_inner_graph = hash(self._cmodule_key)
else:
self._hash_inner_graph = self.info['gpu_hash']
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论