提交 6347cfc5 作者: Olivier Breuleux

optimization for OpWiseCLinker

上级 edad54d2
......@@ -140,7 +140,7 @@ class DimShuffle(Op):
and self.new_order == other.new_order \
and self.input_broadcastable == other.input_broadcastable
def __hash__(self):
    """Return a hash built from the attributes that define this op.

    XORs the hashes of ``inplace``, ``new_order`` and
    ``input_broadcastable``, so two ops with equal values for those three
    attributes hash identically.
    """
    # hash() invokes __hash__ with the instance only; declaring an extra
    # required parameter (the former ``other``) made every hash() call fail.
    return hash(self.inplace) ^ hash(self.new_order) ^ hash(self.input_broadcastable)
def __str__(self):
......@@ -276,7 +276,7 @@ class Elemwise(Op):
return type(self) == type(other) and self.scalar_op == other.scalar_op and self.inplace_pattern == other.inplace_pattern
def __hash__(self):
    """Return a hash combining ``scalar_op`` and ``inplace_pattern``.

    ``inplace_pattern`` is a mapping and therefore not hashable itself, so
    its (key, value) pairs are hashed instead.
    """
    # A frozenset of the items is independent of iteration order, so two
    # patterns that compare equal always produce the same hash (a tuple of
    # items would depend on the order in which the mapping yields them).
    return hash(self.scalar_op) ^ hash(frozenset(self.inplace_pattern.items()))
def __str__(self):
if self.name is None:
......
......@@ -529,18 +529,6 @@ class CLinker(link.Linker):
# (basically, everything that the previous call to uniq eliminated)
self.dupidx = [i for i, x in enumerate(all) if all.count(x) > 1 and all.index(x) != i]
return self.struct_code
def find_task(self, failure_code):
    """
    Maps a failure code to the task that is associated to it.

    failure_code: 1-based code; it is shifted to 0-based before lookup.
    Returns the matching entry of ``self.init_tasks`` or ``self.tasks``.
    """
    failure_code -= 1
    n = len(self.init_tasks)
    # note that the failure code is distributed in two lists: the first
    # 2 * n codes alternate between init_tasks (even) and tasks (odd)
    if failure_code < 2 * n:
        # floor division keeps the index an integer under every division mode
        return [self.init_tasks, self.tasks][failure_code % 2][failure_code // 2]
    else:
        # past the interleaved part, the codes index tasks directly
        return self.tasks[failure_code - n]
def support_code(self):
"""
......@@ -653,19 +641,7 @@ class CLinker(link.Linker):
first_output = ostor[0].data
"""
cthunk, in_storage, out_storage, error_storage = self.__compile__(input_storage, output_storage)
def execute():
failure = cutils.run_cthunk(cthunk)
if failure:
task, taskname, id = self.find_task(failure)
try:
trace = task.trace
except AttributeError:
trace = ()
exc_type, _exc_value, exc_trace = error_storage
exc_value = exc_type(_exc_value, task)
exc_value.__thunk_trace__ = trace # this can be used to retrieve the location the Op was declared
raise exc_type, exc_value, exc_trace
return execute, in_storage, out_storage
return _execute(cthunk, self.init_tasks, self.tasks), in_storage, out_storage
def cthunk_factory(self, error_storage, in_storage, out_storage):
"""
......@@ -762,6 +738,33 @@ class CLinker(link.Linker):
return ret
def _execute(cthunk, init_tasks, tasks):
def find_task(self, failure_code):
"""
Maps a failure code to the task that is associated to it.
"""
failure_code -= 1
n = len(self.init_tasks)
# note that the failure code is distributed in two lists
if failure_code < 2 * n:
return [init_tasks, tasks][failure_code % 2][failure_code/2]
else:
return tasks[failure_code - n]
def execute():
failure = cutils.run_cthunk(cthunk)
if failure:
task, taskname, id = find_task(failure)
try:
trace = task.trace
except AttributeError:
trace = ()
exc_type, _exc_value, exc_trace = error_storage
exc_value = exc_type(_exc_value, task)
exc_value.__thunk_trace__ = trace # this can be used to retrieve the location the Op was declared
raise exc_type, exc_value, exc_trace
return execute
class OpWiseCLinker(link.LocalLinker):
"""
......@@ -779,6 +782,8 @@ class OpWiseCLinker(link.LocalLinker):
associated to it during the computation (to avoid reusing it).
"""
__cache__ = {}
def __init__(self, fallback_on_perform = True):
self.env = None
self.fallback_on_perform = fallback_on_perform
......@@ -805,7 +810,20 @@ class OpWiseCLinker(link.LocalLinker):
try:
e = Env(*graph.clone(node.inputs, node.outputs))
e.toposort = lambda: e.nodes
cl = CLinker().accept(e, [r for r, r2 in zip(e.outputs, node.outputs) if r2 in no_recycling])
if any(isinstance(input, graph.Value) for input in node.inputs):
desc = None
else:
desc = (node.op,
tuple(input.type for input in node.inputs),
tuple(input.type for input in node.inputs),
tuple(output in no_recycling for output in node.outputs))
if desc in self.__cache__:
cl = self.__cache__[desc]
else:
cl = CLinker().accept(e, [r for r, r2 in zip(e.outputs, node.outputs) if r2 in no_recycling])
if desc is not None:
self.__cache__[desc] = cl
thunk, node_input_filters, node_output_filters = cl.make_thunk(
input_storage = node_input_storage,
output_storage = node_output_storage)
......@@ -828,9 +846,9 @@ class OpWiseCLinker(link.LocalLinker):
no_recycling = utils.difference(no_recycling, input_storage)
else:
no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs]
f = self.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler)
f = link.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler)
return f, [link.Filter(input, storage) for input, storage in zip(env.inputs, input_storage)], \
[link.Filter(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \
thunks, order
......
......@@ -164,32 +164,32 @@ def map_storage(env, order, input_storage, output_storage):
return input_storage, output_storage, storage_map
def streamline(env, thunks, order, no_recycling=(), profiler=None):
    """
    Bundle a sequence of thunks into a single zero-argument function.

    env: passed through to ``profiler.profile_env``; unused without a profiler.
    thunks: callables executed in order on every call of the result.
    order: nodes paired one-to-one with ``thunks``.
    no_recycling: iterable of one-element storage cells; element 0 of each
        cell is reset to None at the start of every call.
    profiler: optional object providing ``profile_node(thunk, node)`` and
        ``profile_env(g, env)``; when given, each thunk is run through it and
        it is attached to the returned function as ``f.profiler``.

    Returns the function ``f``.
    """
    # The default is an immutable empty tuple: the argument is only iterated,
    # and a shared mutable default list is an avoidable hazard.
    if profiler is None:
        def f():
            for x in no_recycling:
                x[0] = None
            try:
                for thunk, node in zip(thunks, order):
                    thunk()
            except:
                # Deliberately catches everything: ``node`` is the node whose
                # thunk was running, and raise_with_op is handed that node.
                raise_with_op(node)
    else:
        def f():
            for x in no_recycling:
                x[0] = None
            def g():
                for thunk, node in zip(thunks, order):
                    profiler.profile_node(thunk, node)
            profiler.profile_env(g, env)
        f.profiler = profiler
    return f
class LocalLinker(Linker):
"""
Useful base class for L{Linker}s which keep all nodes in the graph, and run a
thunk associated with each node.
"""
def streamline(self, env, thunks, order, no_recycling=(), profiler=None):
    """
    Bundle ``thunks`` into a single zero-argument function.

    Kept as a method so that existing ``self.streamline(...)`` callers keep
    working; the work is done by the module-level ``streamline`` function,
    which had the exact same body, so the two can no longer drift apart.
    Arguments and return value are those of the module-level function.
    """
    # Inside a method body the bare name resolves to the module-level
    # function, not to this method.
    return streamline(env, thunks, order,
                      no_recycling=no_recycling, profiler=profiler)
def make_thunk(self, profiler = None, input_storage = None, output_storage = None):
return self.make_all(profiler = profiler,
......@@ -248,7 +248,7 @@ class PerformLinker(LocalLinker):
else:
no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs]
f = self.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler)
f = streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler)
return f, [Filter(input, storage) for input, storage in zip(env.inputs, input_storage)], \
[Filter(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \
......
......@@ -323,10 +323,12 @@ class ScalarOp(Op):
raise AbstractFunctionError()
def __eq__(self, other):
    """Two scalar ops are equal when they have exactly the same type and
    the same ``output_types_preference``.

    The attribute is read with ``getattr`` and a ``None`` fallback, so ops
    that never set it can still be compared instead of raising
    AttributeError.
    """
    return type(self) == type(other) \
        and getattr(self, 'output_types_preference', None) \
        == getattr(other, 'output_types_preference', None)
def __hash__(self):
    """Hash of ``output_types_preference``, tolerating ops that never set it."""
    # The fallback is None, the same fallback __eq__ uses: an op without the
    # attribute compares equal to one whose attribute is None, so both must
    # produce the same hash.
    return hash(getattr(self, 'output_types_preference', None))
def __str__(self):
if hasattr(self, 'name') and self.name:
......@@ -805,3 +807,9 @@ class Composite(ScalarOp):
**sub)
d['name'] = name
return self._c_code % d
def __eq__(self, other):
    """Identity comparison: an instance is equal only to itself."""
    return other is self
def __hash__(self):
    """Identity-based hash: derived from ``id(self)`` only."""
    identity = id(self)
    return identity
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论