提交 618a113e authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Factor out the code to make C and python thunks and reuse it in DebugMode.

上级 0144d305
...@@ -1713,29 +1713,8 @@ class _Linker(gof.link.LocalLinker): ...@@ -1713,29 +1713,8 @@ class _Linker(gof.link.LocalLinker):
if not isinstance(node.op, gof.op.Op): if not isinstance(node.op, gof.op.Op):
raise utils.MethodNotDefined() raise utils.MethodNotDefined()
# Don't try to test the C code for float16 if not thunk = node.op.make_c_thunk(node, storage_map, compute_map,
# tagged ok. no_recycling)
if (not getattr(node.op, '_f16_ok', False) and
(any(getattr(i, 'dtype', '') == 'float16'
for i in node.inputs) or
any(getattr(o, 'dtype', '') == 'float16'
for o in node.outputs))):
raise utils.MethodNotDefined()
e = FunctionGraph(node.inputs, node.outputs)
# The toposort isn't a stochastic order as it contain only one node.
e.toposort = lambda: list(e.apply_nodes)
# Specifically... e.nodes is a set, but of only 1 element
cl = CLinker().accept(e, [r for r, r2 in zip(e.outputs,
node.outputs)
if r2 in no_recycling])
thunk, node_input_filters, node_output_filters = cl.make_thunk(
input_storage=node_input_storage,
output_storage=node_output_storage)
thunk.inputs = node_input_storage
thunk.outputs = node_output_storage
thunks_c.append(thunk) thunks_c.append(thunk)
except (NotImplementedError, utils.MethodNotDefined): except (NotImplementedError, utils.MethodNotDefined):
thunks_c.append(None) thunks_c.append(None)
...@@ -1745,20 +1724,8 @@ class _Linker(gof.link.LocalLinker): ...@@ -1745,20 +1724,8 @@ class _Linker(gof.link.LocalLinker):
# consider that we don't have a python implementation # consider that we don't have a python implementation
if ((self.maker.mode.check_py_code or thunks_c[-1] is None) and if ((self.maker.mode.check_py_code or thunks_c[-1] is None) and
node.op.perform.func_code != gof.op.PureOp.perform.func_code): node.op.perform.func_code != gof.op.PureOp.perform.func_code):
p = node.op.perform thunk = node.op.make_py_thunk(node, storage_map, compute_map,
ctx = node.run_context() no_recycling)
if ctx is graph.NoContext:
thunk = (lambda p=p, i=node_input_storage,
o=node_output_storage,
n=node: p(n, [x[0] for x in i], o))
else:
ctx_val = node.context_type.filter(ctx)
thunk = (lambda p=p, i=node_input_storage,
o=node_output_storage, ctx=ctx_val,
n=node: p(n, [x[0] for x in i], o, ctx))
thunk.inputs = node_input_storage
thunk.outputs = node_output_storage
thunk.perform = p
thunks_py.append(thunk) thunks_py.append(thunk)
else: else:
thunks_py.append(None) thunks_py.append(None)
......
...@@ -699,37 +699,16 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -699,37 +699,16 @@ class Op(utils.object2, PureOp, CLinkerOp):
else: else:
return NotImplemented return NotImplemented
def make_thunk(self, node, storage_map, compute_map, no_recycling): def make_c_thunk(self, node, storage_map, compute_map, no_recycling):
""" """
:param node: something previously returned by self.make_node Like make_thunk, but will only try to make a C thunk.
:param storage_map: dict variable -> one-element-list where a computed
value for this variable may be found.
:param compute_map: dict variable -> one-element-list where a boolean
value will be found. The boolean indicates whether the
variable's storage_map container contains a valid value (True)
or if it has not been computed yet (False).
:param no_recycling: list of variables for which it is forbidden to
reuse memory allocated by a previous call.
:note: If the thunk consults the storage_map on every call, it is safe
for it to ignore the no_recycling argument, because elements of the
no_recycling list will have a value of None in the storage map. If
the thunk can potentially cache return values (like CLinker does),
then it must not do so for variables in the no_recycling list.
""" """
logger = logging.getLogger('theano.gof.op.Op') logger = logging.getLogger('theano.gof.op.Op')
node_input_storage = [storage_map[r] for r in node.inputs] node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs] node_output_storage = [storage_map[r] for r in node.outputs]
node_input_compute = [compute_map[r] for r in node.inputs]
node_output_compute = [compute_map[r] for r in node.outputs]
if self._op_use_c_code: # float16 gets special treatment since running
try:
# float16 get special treatment since running
# unprepared C code will get bad results. # unprepared C code will get bad results.
if not getattr(self, '_f16_ok', False): if not getattr(self, '_f16_ok', False):
def is_f16(t): def is_f16(t):
...@@ -741,7 +720,6 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -741,7 +720,6 @@ class Op(utils.object2, PureOp, CLinkerOp):
"float16" % (self,)) "float16" % (self,))
raise NotImplementedError("float16") raise NotImplementedError("float16")
e = FunctionGraph(node.inputs, node.outputs) e = FunctionGraph(node.inputs, node.outputs)
e_no_recycling = [new_o e_no_recycling = [new_o
for (new_o, old_o) in zip(e.outputs, node.outputs) for (new_o, old_o) in zip(e.outputs, node.outputs)
if old_o in no_recycling] if old_o in no_recycling]
...@@ -763,14 +741,13 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -763,14 +741,13 @@ class Op(utils.object2, PureOp, CLinkerOp):
rval.outputs = node_output_storage rval.outputs = node_output_storage
rval.lazy = False rval.lazy = False
return rval return rval
# the next line does nothing, but pyflakes is too
# stupid to realize the def rval below is not a
# redefinition unless I include this
del rval
except (NotImplementedError, utils.MethodNotDefined):
logger.debug('Falling back on perform')
# condition: either there was no c_code, or it failed def make_py_thunk(self, node, storage_map, compute_map, no_recycling):
"""
Like make_thunk() but only makes python thunks.
"""
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
p = node.op.perform p = node.op.perform
...@@ -798,6 +775,39 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -798,6 +775,39 @@ class Op(utils.object2, PureOp, CLinkerOp):
rval.lazy = False rval.lazy = False
return rval return rval
def make_thunk(self, node, storage_map, compute_map, no_recycling):
"""
:param node: something previously returned by self.make_node
:param storage_map: dict variable -> one-element-list where a computed
value for this variable may be found.
:param compute_map: dict variable -> one-element-list where a boolean
value will be found. The boolean indicates whether the
variable's storage_map container contains a valid value (True)
or if it has not been computed yet (False).
:param no_recycling: list of variables for which it is forbidden to
reuse memory allocated by a previous call.
:note: If the thunk consults the storage_map on every call, it is safe
for it to ignore the no_recycling argument, because elements of the
no_recycling list will have a value of None in the storage map. If
the thunk can potentially cache return values (like CLinker does),
then it must not do so for variables in the no_recycling list.
"""
logger = logging.getLogger('theano.gof.op.Op')
if self._op_use_c_code:
try:
return self.make_c_thunk(node, storage_map, compute_map,
no_recycling)
except (NotImplementedError, utils.MethodNotDefined):
logger.debug('Falling back on perform')
# condition: either there was no c_code, or it failed
return self.make_py_thunk(node, storage_map, compute_map, no_recycling)
def get_test_value(v): def get_test_value(v):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论