提交 618a113e authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Factor out the code to make C and Python thunks and reuse it in DebugMode.

上级 0144d305
......@@ -1713,29 +1713,8 @@ class _Linker(gof.link.LocalLinker):
if not isinstance(node.op, gof.op.Op):
raise utils.MethodNotDefined()
# Don't try to test the C code for float16 if not
# tagged ok.
if (not getattr(node.op, '_f16_ok', False) and
(any(getattr(i, 'dtype', '') == 'float16'
for i in node.inputs) or
any(getattr(o, 'dtype', '') == 'float16'
for o in node.outputs))):
raise utils.MethodNotDefined()
e = FunctionGraph(node.inputs, node.outputs)
# The toposort isn't a stochastic order as it contains only one node.
e.toposort = lambda: list(e.apply_nodes)
# Specifically... e.nodes is a set, but of only 1 element
cl = CLinker().accept(e, [r for r, r2 in zip(e.outputs,
node.outputs)
if r2 in no_recycling])
thunk, node_input_filters, node_output_filters = cl.make_thunk(
input_storage=node_input_storage,
output_storage=node_output_storage)
thunk.inputs = node_input_storage
thunk.outputs = node_output_storage
thunk = node.op.make_c_thunk(node, storage_map, compute_map,
no_recycling)
thunks_c.append(thunk)
except (NotImplementedError, utils.MethodNotDefined):
thunks_c.append(None)
......@@ -1745,20 +1724,8 @@ class _Linker(gof.link.LocalLinker):
# consider that we don't have a python implementation
if ((self.maker.mode.check_py_code or thunks_c[-1] is None) and
node.op.perform.func_code != gof.op.PureOp.perform.func_code):
p = node.op.perform
ctx = node.run_context()
if ctx is graph.NoContext:
thunk = (lambda p=p, i=node_input_storage,
o=node_output_storage,
n=node: p(n, [x[0] for x in i], o))
else:
ctx_val = node.context_type.filter(ctx)
thunk = (lambda p=p, i=node_input_storage,
o=node_output_storage, ctx=ctx_val,
n=node: p(n, [x[0] for x in i], o, ctx))
thunk.inputs = node_input_storage
thunk.outputs = node_output_storage
thunk.perform = p
thunk = node.op.make_py_thunk(node, storage_map, compute_map,
no_recycling)
thunks_py.append(thunk)
else:
thunks_py.append(None)
......
......@@ -699,78 +699,55 @@ class Op(utils.object2, PureOp, CLinkerOp):
else:
return NotImplemented
def make_thunk(self, node, storage_map, compute_map, no_recycling):
def make_c_thunk(self, node, storage_map, compute_map, no_recycling):
"""
:param node: something previously returned by self.make_node
:param storage_map: dict variable -> one-element-list where a computed
value for this variable may be found.
:param compute_map: dict variable -> one-element-list where a boolean
value will be found. The boolean indicates whether the
variable's storage_map container contains a valid value (True)
or if it has not been computed yet (False).
:param no_recycling: list of variables for which it is forbidden to
reuse memory allocated by a previous call.
:note: If the thunk consults the storage_map on every call, it is safe
for it to ignore the no_recycling argument, because elements of the
no_recycling list will have a value of None in the storage map. If
the thunk can potentially cache return values (like CLinker does),
then it must not do so for variables in the no_recycling list.
Like make_thunk, but will only try to make a C thunk.
"""
logger = logging.getLogger('theano.gof.op.Op')
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
node_input_compute = [compute_map[r] for r in node.inputs]
node_output_compute = [compute_map[r] for r in node.outputs]
if self._op_use_c_code:
try:
# float16 get special treatment since running
# unprepared C code will get bad results.
if not getattr(self, '_f16_ok', False):
def is_f16(t):
return getattr(t, 'dtype', '') == 'float16'
if (any(is_f16(i.type) for i in node.inputs) or
any(is_f16(o.type) for o in node.outputs)):
print ("Disabling C code for %s due to unsupported "
"float16" % (self,))
raise NotImplementedError("float16")
e = FunctionGraph(node.inputs, node.outputs)
e_no_recycling = [new_o
for (new_o, old_o) in zip(e.outputs, node.outputs)
if old_o in no_recycling]
cl = theano.gof.cc.CLinker().accept(e,
no_recycling=e_no_recycling)
logger.debug('Trying CLinker.make_thunk')
outputs = cl.make_thunk(input_storage=node_input_storage,
output_storage=node_output_storage)
fill_storage, node_input_filters, node_output_filters = outputs
def rval():
fill_storage()
for o in node.outputs:
compute_map[o][0] = True
rval.cthunk = fill_storage.cthunk
rval.inputs = node_input_storage
rval.outputs = node_output_storage
rval.lazy = False
return rval
# the next line does nothing, but pyflakes is too
# stupid to realize the def rval below is not a
# redefinition unless I include this
del rval
except (NotImplementedError, utils.MethodNotDefined):
logger.debug('Falling back on perform')
# float16 gets special treatment since running
# unprepared C code will get bad results.
if not getattr(self, '_f16_ok', False):
def is_f16(t):
return getattr(t, 'dtype', '') == 'float16'
if (any(is_f16(i.type) for i in node.inputs) or
any(is_f16(o.type) for o in node.outputs)):
print ("Disabling C code for %s due to unsupported "
"float16" % (self,))
raise NotImplementedError("float16")
e = FunctionGraph(node.inputs, node.outputs)
e_no_recycling = [new_o
for (new_o, old_o) in zip(e.outputs, node.outputs)
if old_o in no_recycling]
cl = theano.gof.cc.CLinker().accept(e,
no_recycling=e_no_recycling)
logger.debug('Trying CLinker.make_thunk')
outputs = cl.make_thunk(input_storage=node_input_storage,
output_storage=node_output_storage)
fill_storage, node_input_filters, node_output_filters = outputs
def rval():
fill_storage()
for o in node.outputs:
compute_map[o][0] = True
rval.cthunk = fill_storage.cthunk
rval.inputs = node_input_storage
rval.outputs = node_output_storage
rval.lazy = False
return rval
# condition: either there was no c_code, or it failed
def make_py_thunk(self, node, storage_map, compute_map, no_recycling):
"""
Like make_thunk() but only makes python thunks.
"""
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
p = node.op.perform
......@@ -798,6 +775,39 @@ class Op(utils.object2, PureOp, CLinkerOp):
rval.lazy = False
return rval
def make_thunk(self, node, storage_map, compute_map, no_recycling):
    """
    Build a thunk for `node`, trying the C implementation first.

    When ``self._op_use_c_code`` is true, ``make_c_thunk`` is attempted;
    if it raises ``NotImplementedError`` or ``utils.MethodNotDefined``
    the failure is logged at debug level and the result of
    ``make_py_thunk`` is returned instead.  When the flag is false,
    ``make_py_thunk`` is used directly.

    :param node: something previously returned by self.make_node

    :param storage_map: dict variable -> one-element-list where a
        computed value for this variable may be found.

    :param compute_map: dict variable -> one-element-list where a
        boolean value will be found.  The boolean indicates whether the
        variable's storage_map container contains a valid value (True)
        or if it has not been computed yet (False).

    :param no_recycling: list of variables for which it is forbidden to
        reuse memory allocated by a previous call.

    :note: If the thunk consults the storage_map on every call, it is
        safe for it to ignore the no_recycling argument, because
        elements of the no_recycling list will have a value of None in
        the storage map.  If the thunk can potentially cache return
        values (like CLinker does), then it must not do so for
        variables in the no_recycling list.
    """
    log = logging.getLogger('theano.gof.op.Op')

    # C code disabled for this op: go straight to the python thunk.
    if not self._op_use_c_code:
        return self.make_py_thunk(node, storage_map, compute_map,
                                  no_recycling)

    try:
        c_thunk = self.make_c_thunk(node, storage_map, compute_map,
                                    no_recycling)
    except (NotImplementedError, utils.MethodNotDefined):
        # No usable C implementation; fall through to the python thunk.
        log.debug('Falling back on perform')
    else:
        return c_thunk

    return self.make_py_thunk(node, storage_map, compute_map,
                              no_recycling)
def get_test_value(v):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论