Commit 6347cfc5 authored by Olivier Breuleux

optimization for OpWiseCLinker

Parent commit: edad54d2
...@@ -140,7 +140,7 @@ class DimShuffle(Op): ...@@ -140,7 +140,7 @@ class DimShuffle(Op):
and self.new_order == other.new_order \ and self.new_order == other.new_order \
and self.input_broadcastable == other.input_broadcastable and self.input_broadcastable == other.input_broadcastable
def __hash__(self, other): def __hash__(self):
return hash(self.inplace) ^ hash(self.new_order) ^ hash(self.input_broadcastable) return hash(self.inplace) ^ hash(self.new_order) ^ hash(self.input_broadcastable)
def __str__(self): def __str__(self):
...@@ -276,7 +276,7 @@ class Elemwise(Op): ...@@ -276,7 +276,7 @@ class Elemwise(Op):
return type(self) == type(other) and self.scalar_op == other.scalar_op and self.inplace_pattern == other.inplace_pattern return type(self) == type(other) and self.scalar_op == other.scalar_op and self.inplace_pattern == other.inplace_pattern
def __hash__(self): def __hash__(self):
return hash(self.scalar_op) ^ hash(self.inplace_pattern) return hash(self.scalar_op) ^ hash(tuple(self.inplace_pattern.items()))
def __str__(self): def __str__(self):
if self.name is None: if self.name is None:
......
...@@ -529,18 +529,6 @@ class CLinker(link.Linker): ...@@ -529,18 +529,6 @@ class CLinker(link.Linker):
# (basically, everything that the previous call to uniq eliminated) # (basically, everything that the previous call to uniq eliminated)
self.dupidx = [i for i, x in enumerate(all) if all.count(x) > 1 and all.index(x) != i] self.dupidx = [i for i, x in enumerate(all) if all.count(x) > 1 and all.index(x) != i]
return self.struct_code return self.struct_code
def find_task(self, failure_code):
    """
    Maps a failure code to the task that is associated to it.

    Failure codes start at 1 (0 means success).  For the first
    2 * len(self.init_tasks) codes, even/odd codes alternate between
    self.init_tasks and self.tasks; codes past that range index
    directly into self.tasks, offset by the number of init tasks.
    """
    failure_code -= 1
    n = len(self.init_tasks)
    # note that the failure code is distributed in two lists
    if failure_code < 2 * n:
        # even -> init_tasks, odd -> tasks; use // so the index stays an
        # integer (a plain / would produce a float under Python 3; under
        # Python 2 the two operators are identical for ints)
        return [self.init_tasks, self.tasks][failure_code % 2][failure_code // 2]
    else:
        return self.tasks[failure_code - n]
def support_code(self): def support_code(self):
""" """
...@@ -653,19 +641,7 @@ class CLinker(link.Linker): ...@@ -653,19 +641,7 @@ class CLinker(link.Linker):
first_output = ostor[0].data first_output = ostor[0].data
""" """
cthunk, in_storage, out_storage, error_storage = self.__compile__(input_storage, output_storage) cthunk, in_storage, out_storage, error_storage = self.__compile__(input_storage, output_storage)
def execute(): return _execute(cthunk, self.init_tasks, self.tasks), in_storage, out_storage
failure = cutils.run_cthunk(cthunk)
if failure:
task, taskname, id = self.find_task(failure)
try:
trace = task.trace
except AttributeError:
trace = ()
exc_type, _exc_value, exc_trace = error_storage
exc_value = exc_type(_exc_value, task)
exc_value.__thunk_trace__ = trace # this can be used to retrieve the location the Op was declared
raise exc_type, exc_value, exc_trace
return execute, in_storage, out_storage
def cthunk_factory(self, error_storage, in_storage, out_storage): def cthunk_factory(self, error_storage, in_storage, out_storage):
""" """
...@@ -762,6 +738,33 @@ class CLinker(link.Linker): ...@@ -762,6 +738,33 @@ class CLinker(link.Linker):
return ret return ret
def _execute(cthunk, init_tasks, tasks):
def find_task(self, failure_code):
"""
Maps a failure code to the task that is associated to it.
"""
failure_code -= 1
n = len(self.init_tasks)
# note that the failure code is distributed in two lists
if failure_code < 2 * n:
return [init_tasks, tasks][failure_code % 2][failure_code/2]
else:
return tasks[failure_code - n]
def execute():
failure = cutils.run_cthunk(cthunk)
if failure:
task, taskname, id = find_task(failure)
try:
trace = task.trace
except AttributeError:
trace = ()
exc_type, _exc_value, exc_trace = error_storage
exc_value = exc_type(_exc_value, task)
exc_value.__thunk_trace__ = trace # this can be used to retrieve the location the Op was declared
raise exc_type, exc_value, exc_trace
return execute
class OpWiseCLinker(link.LocalLinker): class OpWiseCLinker(link.LocalLinker):
""" """
...@@ -779,6 +782,8 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -779,6 +782,8 @@ class OpWiseCLinker(link.LocalLinker):
associated to it during the computation (to avoid reusing it). associated to it during the computation (to avoid reusing it).
""" """
__cache__ = {}
def __init__(self, fallback_on_perform = True): def __init__(self, fallback_on_perform = True):
self.env = None self.env = None
self.fallback_on_perform = fallback_on_perform self.fallback_on_perform = fallback_on_perform
...@@ -805,7 +810,20 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -805,7 +810,20 @@ class OpWiseCLinker(link.LocalLinker):
try: try:
e = Env(*graph.clone(node.inputs, node.outputs)) e = Env(*graph.clone(node.inputs, node.outputs))
e.toposort = lambda: e.nodes e.toposort = lambda: e.nodes
cl = CLinker().accept(e, [r for r, r2 in zip(e.outputs, node.outputs) if r2 in no_recycling]) if any(isinstance(input, graph.Value) for input in node.inputs):
desc = None
else:
desc = (node.op,
tuple(input.type for input in node.inputs),
tuple(input.type for input in node.inputs),
tuple(output in no_recycling for output in node.outputs))
if desc in self.__cache__:
cl = self.__cache__[desc]
else:
cl = CLinker().accept(e, [r for r, r2 in zip(e.outputs, node.outputs) if r2 in no_recycling])
if desc is not None:
self.__cache__[desc] = cl
thunk, node_input_filters, node_output_filters = cl.make_thunk( thunk, node_input_filters, node_output_filters = cl.make_thunk(
input_storage = node_input_storage, input_storage = node_input_storage,
output_storage = node_output_storage) output_storage = node_output_storage)
...@@ -828,9 +846,9 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -828,9 +846,9 @@ class OpWiseCLinker(link.LocalLinker):
no_recycling = utils.difference(no_recycling, input_storage) no_recycling = utils.difference(no_recycling, input_storage)
else: else:
no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs] no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs]
f = self.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler)
f = link.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler)
return f, [link.Filter(input, storage) for input, storage in zip(env.inputs, input_storage)], \ return f, [link.Filter(input, storage) for input, storage in zip(env.inputs, input_storage)], \
[link.Filter(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \ [link.Filter(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \
thunks, order thunks, order
......
...@@ -164,32 +164,32 @@ def map_storage(env, order, input_storage, output_storage): ...@@ -164,32 +164,32 @@ def map_storage(env, order, input_storage, output_storage):
return input_storage, output_storage, storage_map return input_storage, output_storage, storage_map
def streamline(env, thunks, order, no_recycling = [], profiler = None):
    """
    Build a callable that runs every thunk in `order`, after clearing
    the storage cells listed in `no_recycling`.

    :param env: the Env being executed (only used when profiling)
    :param thunks: callables, one per node in `order`
    :param order: nodes matching `thunks`, used for error reporting
    :param no_recycling: storage cells (one-element [data] lists) reset
        to None before each run.  The shared default [] is safe here
        because the argument is only iterated, never mutated.
    :param profiler: if not None, execution is routed through
        profiler.profile_env / profiler.profile_node, and the returned
        callable carries a `profiler` attribute.
    """
    if profiler is None:
        def f():
            for x in no_recycling:
                x[0] = None
            try:
                for thunk, node in zip(thunks, order):
                    thunk()
            except:
                # intentionally bare: raise_with_op re-raises with the
                # failing node's context attached
                raise_with_op(node)
    else:
        def f():
            for x in no_recycling:
                x[0] = None
            def g():
                for thunk, node in zip(thunks, order):
                    profiler.profile_node(thunk, node)
            profiler.profile_env(g, env)
        f.profiler = profiler
    return f
class LocalLinker(Linker): class LocalLinker(Linker):
""" """
Useful base class for L{Linker}s which keep all nodes in the graph, and run a Useful base class for L{Linker}s which keep all nodes in the graph, and run a
thunk associated with each node. thunk associated with each node.
""" """
def streamline(self, env, thunks, order, no_recycling = [], profiler = None):
    """
    Return a callable that executes all the node thunks in `order`,
    first resetting every storage cell in `no_recycling` to None.

    :param env: the Env being executed (only used when profiling)
    :param thunks: callables, one per node in `order`
    :param order: nodes matching `thunks`, used for error reporting
    :param no_recycling: storage cells (one-element [data] lists) to
        clear before each run; only iterated, never mutated, so the
        shared default [] is safe
    :param profiler: if not None, the run goes through
        profiler.profile_env / profiler.profile_node and the returned
        callable gets a `profiler` attribute
    """
    if profiler is not None:
        def f():
            for cell in no_recycling:
                cell[0] = None
            def run_all():
                for thunk, node in zip(thunks, order):
                    profiler.profile_node(thunk, node)
            profiler.profile_env(run_all, env)
        f.profiler = profiler
        return f
    def f():
        for cell in no_recycling:
            cell[0] = None
        try:
            for thunk, node in zip(thunks, order):
                thunk()
        except:
            # intentionally bare: raise_with_op re-raises with node context
            raise_with_op(node)
    return f
def make_thunk(self, profiler = None, input_storage = None, output_storage = None): def make_thunk(self, profiler = None, input_storage = None, output_storage = None):
return self.make_all(profiler = profiler, return self.make_all(profiler = profiler,
...@@ -248,7 +248,7 @@ class PerformLinker(LocalLinker): ...@@ -248,7 +248,7 @@ class PerformLinker(LocalLinker):
else: else:
no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs] no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs]
f = self.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler) f = streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler)
return f, [Filter(input, storage) for input, storage in zip(env.inputs, input_storage)], \ return f, [Filter(input, storage) for input, storage in zip(env.inputs, input_storage)], \
[Filter(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \ [Filter(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \
......
...@@ -323,10 +323,12 @@ class ScalarOp(Op): ...@@ -323,10 +323,12 @@ class ScalarOp(Op):
raise AbstractFunctionError() raise AbstractFunctionError()
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.output_types_preference == other.output_types_preference return type(self) == type(other) \
and getattr(self, 'output_types_preference', None) \
== getattr(other, 'output_types_preference', None)
def __hash__(self): def __hash__(self):
return hash(self.output_types_preference) return hash(getattr(self, 'output_types_preference', 0))
def __str__(self): def __str__(self):
if hasattr(self, 'name') and self.name: if hasattr(self, 'name') and self.name:
...@@ -805,3 +807,9 @@ class Composite(ScalarOp): ...@@ -805,3 +807,9 @@ class Composite(ScalarOp):
**sub) **sub)
d['name'] = name d['name'] = name
return self._c_code % d return self._c_code % d
def __eq__(self, other):
    """Composite ops compare by identity: an instance equals only itself."""
    return self is other

def __hash__(self):
    # id() is stable for the object's lifetime and consistent with the
    # identity-based __eq__ above (a == b implies hash(a) == hash(b))
    return id(self)
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Finish editing this comment first!
Register or sign in to post a comment