提交 6347cfc5 作者: Olivier Breuleux

optimization for OpWiseCLinker

上级 edad54d2
...@@ -140,7 +140,7 @@ class DimShuffle(Op): ...@@ -140,7 +140,7 @@ class DimShuffle(Op):
and self.new_order == other.new_order \ and self.new_order == other.new_order \
and self.input_broadcastable == other.input_broadcastable and self.input_broadcastable == other.input_broadcastable
def __hash__(self, other): def __hash__(self):
return hash(self.inplace) ^ hash(self.new_order) ^ hash(self.input_broadcastable) return hash(self.inplace) ^ hash(self.new_order) ^ hash(self.input_broadcastable)
def __str__(self): def __str__(self):
...@@ -276,7 +276,7 @@ class Elemwise(Op): ...@@ -276,7 +276,7 @@ class Elemwise(Op):
return type(self) == type(other) and self.scalar_op == other.scalar_op and self.inplace_pattern == other.inplace_pattern return type(self) == type(other) and self.scalar_op == other.scalar_op and self.inplace_pattern == other.inplace_pattern
def __hash__(self): def __hash__(self):
return hash(self.scalar_op) ^ hash(self.inplace_pattern) return hash(self.scalar_op) ^ hash(tuple(self.inplace_pattern.items()))
def __str__(self): def __str__(self):
if self.name is None: if self.name is None:
......
...@@ -530,18 +530,6 @@ class CLinker(link.Linker): ...@@ -530,18 +530,6 @@ class CLinker(link.Linker):
self.dupidx = [i for i, x in enumerate(all) if all.count(x) > 1 and all.index(x) != i] self.dupidx = [i for i, x in enumerate(all) if all.count(x) > 1 and all.index(x) != i]
return self.struct_code return self.struct_code
def find_task(self, failure_code):
    """
    Maps a failure code to the task that is associated to it.

    Failure codes are 1-based. With n = len(self.init_tasks), the first
    2*n codes alternate between the two lists (odd codes index
    self.init_tasks, even codes index self.tasks); any later code indexes
    self.tasks directly, offset by n.

    @param failure_code: non-zero integer code reported by the thunk.
    @return: the matching entry of self.init_tasks or self.tasks.
    """
    index = failure_code - 1
    n_init = len(self.init_tasks)
    # note that the failure code is distributed in two lists
    if index < 2 * n_init:
        interleaved = (self.init_tasks, self.tasks)
        # '//' keeps the index an integer even under true division
        # (Python 3 or `from __future__ import division`).
        return interleaved[index % 2][index // 2]
    return self.tasks[index - n_init]
def support_code(self): def support_code(self):
""" """
Returns a list of support code strings that are needed by Returns a list of support code strings that are needed by
...@@ -653,19 +641,7 @@ class CLinker(link.Linker): ...@@ -653,19 +641,7 @@ class CLinker(link.Linker):
first_output = ostor[0].data first_output = ostor[0].data
""" """
cthunk, in_storage, out_storage, error_storage = self.__compile__(input_storage, output_storage) cthunk, in_storage, out_storage, error_storage = self.__compile__(input_storage, output_storage)
def execute(): return _execute(cthunk, self.init_tasks, self.tasks), in_storage, out_storage
failure = cutils.run_cthunk(cthunk)
if failure:
task, taskname, id = self.find_task(failure)
try:
trace = task.trace
except AttributeError:
trace = ()
exc_type, _exc_value, exc_trace = error_storage
exc_value = exc_type(_exc_value, task)
exc_value.__thunk_trace__ = trace # this can be used to retrieve the location the Op was declared
raise exc_type, exc_value, exc_trace
return execute, in_storage, out_storage
def cthunk_factory(self, error_storage, in_storage, out_storage): def cthunk_factory(self, error_storage, in_storage, out_storage):
""" """
...@@ -762,6 +738,33 @@ class CLinker(link.Linker): ...@@ -762,6 +738,33 @@ class CLinker(link.Linker):
return ret return ret
def _execute(cthunk, init_tasks, tasks):
def find_task(self, failure_code):
"""
Maps a failure code to the task that is associated to it.
"""
failure_code -= 1
n = len(self.init_tasks)
# note that the failure code is distributed in two lists
if failure_code < 2 * n:
return [init_tasks, tasks][failure_code % 2][failure_code/2]
else:
return tasks[failure_code - n]
def execute():
failure = cutils.run_cthunk(cthunk)
if failure:
task, taskname, id = find_task(failure)
try:
trace = task.trace
except AttributeError:
trace = ()
exc_type, _exc_value, exc_trace = error_storage
exc_value = exc_type(_exc_value, task)
exc_value.__thunk_trace__ = trace # this can be used to retrieve the location the Op was declared
raise exc_type, exc_value, exc_trace
return execute
class OpWiseCLinker(link.LocalLinker): class OpWiseCLinker(link.LocalLinker):
""" """
...@@ -779,6 +782,8 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -779,6 +782,8 @@ class OpWiseCLinker(link.LocalLinker):
associated to it during the computation (to avoid reusing it). associated to it during the computation (to avoid reusing it).
""" """
__cache__ = {}
def __init__(self, fallback_on_perform = True): def __init__(self, fallback_on_perform = True):
self.env = None self.env = None
self.fallback_on_perform = fallback_on_perform self.fallback_on_perform = fallback_on_perform
...@@ -805,7 +810,20 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -805,7 +810,20 @@ class OpWiseCLinker(link.LocalLinker):
try: try:
e = Env(*graph.clone(node.inputs, node.outputs)) e = Env(*graph.clone(node.inputs, node.outputs))
e.toposort = lambda: e.nodes e.toposort = lambda: e.nodes
if any(isinstance(input, graph.Value) for input in node.inputs):
desc = None
else:
desc = (node.op,
tuple(input.type for input in node.inputs),
tuple(input.type for input in node.inputs),
tuple(output in no_recycling for output in node.outputs))
if desc in self.__cache__:
cl = self.__cache__[desc]
else:
cl = CLinker().accept(e, [r for r, r2 in zip(e.outputs, node.outputs) if r2 in no_recycling]) cl = CLinker().accept(e, [r for r, r2 in zip(e.outputs, node.outputs) if r2 in no_recycling])
if desc is not None:
self.__cache__[desc] = cl
thunk, node_input_filters, node_output_filters = cl.make_thunk( thunk, node_input_filters, node_output_filters = cl.make_thunk(
input_storage = node_input_storage, input_storage = node_input_storage,
output_storage = node_output_storage) output_storage = node_output_storage)
...@@ -829,7 +847,7 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -829,7 +847,7 @@ class OpWiseCLinker(link.LocalLinker):
else: else:
no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs] no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs]
f = self.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler) f = link.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler)
return f, [link.Filter(input, storage) for input, storage in zip(env.inputs, input_storage)], \ return f, [link.Filter(input, storage) for input, storage in zip(env.inputs, input_storage)], \
[link.Filter(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \ [link.Filter(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \
......
...@@ -164,13 +164,7 @@ def map_storage(env, order, input_storage, output_storage): ...@@ -164,13 +164,7 @@ def map_storage(env, order, input_storage, output_storage):
return input_storage, output_storage, storage_map return input_storage, output_storage, storage_map
def streamline(env, thunks, order, no_recycling = [], profiler = None):
class LocalLinker(Linker):
"""
Useful base class for L{Linker}s which keep all nodes in the graph, and run a
thunk associated with each node.
"""
def streamline(self, env, thunks, order, no_recycling = [], profiler = None):
if profiler is None: if profiler is None:
def f(): def f():
for x in no_recycling: for x in no_recycling:
...@@ -191,6 +185,12 @@ class LocalLinker(Linker): ...@@ -191,6 +185,12 @@ class LocalLinker(Linker):
f.profiler = profiler f.profiler = profiler
return f return f
class LocalLinker(Linker):
"""
Useful base class for L{Linker}s which keep all nodes in the graph, and run a
thunk associated with each node.
"""
def make_thunk(self, profiler = None, input_storage = None, output_storage = None): def make_thunk(self, profiler = None, input_storage = None, output_storage = None):
return self.make_all(profiler = profiler, return self.make_all(profiler = profiler,
input_storage = input_storage, input_storage = input_storage,
...@@ -248,7 +248,7 @@ class PerformLinker(LocalLinker): ...@@ -248,7 +248,7 @@ class PerformLinker(LocalLinker):
else: else:
no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs] no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs]
f = self.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler) f = streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler)
return f, [Filter(input, storage) for input, storage in zip(env.inputs, input_storage)], \ return f, [Filter(input, storage) for input, storage in zip(env.inputs, input_storage)], \
[Filter(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \ [Filter(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \
......
...@@ -323,10 +323,12 @@ class ScalarOp(Op): ...@@ -323,10 +323,12 @@ class ScalarOp(Op):
raise AbstractFunctionError() raise AbstractFunctionError()
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.output_types_preference == other.output_types_preference return type(self) == type(other) \
and getattr(self, 'output_types_preference', None) \
== getattr(other, 'output_types_preference', None)
def __hash__(self): def __hash__(self):
return hash(self.output_types_preference) return hash(getattr(self, 'output_types_preference', 0))
def __str__(self): def __str__(self):
if hasattr(self, 'name') and self.name: if hasattr(self, 'name') and self.name:
...@@ -805,3 +807,9 @@ class Composite(ScalarOp): ...@@ -805,3 +807,9 @@ class Composite(ScalarOp):
**sub) **sub)
d['name'] = name d['name'] = name
return self._c_code % d return self._c_code % d
def __eq__(self, other):
    """
    Identity comparison: an instance is equal only to itself.
    """
    same_object = other is self
    return same_object
def __hash__(self):
    """
    Returns id(self), matching the identity-based __eq__.
    """
    identity = id(self)
    return identity
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论