提交 4102ff9a，作者：Razvan Pascanu

Add a base make_thunk for any node.

上级 2c86e528
......@@ -10,6 +10,11 @@ from theano import config
import graph
import numpy
import utils
import logging
from theano import config
from env import Env
import graph
import cc
class CLinkerObject(object):
......@@ -414,4 +419,77 @@ class PureOp(object):
class Op(utils.object2, PureOp, CLinkerOp):
    """Convenience class to bundle `PureOp` and `CLinkerOp`.

    It also provides a default `make_thunk` that builds a callable for a
    single node: it first tries to compile the node with `cc.CLinker` and,
    if no C implementation is available, falls back on the Op's `perform`
    method.
    """

    def __new__(cls, *args, **kwargs):
        # This function exists to silently and transparently ensure that all
        # existing Ops get a `_op_use_c_code` attribute, including subclasses
        # whose own __init__ never calls Op.__init__.
        #
        # The constructor arguments are accepted but deliberately NOT
        # forwarded: once __new__ is overridden, object.__new__ rejects extra
        # arguments (DeprecationWarning on Python 2.6+, TypeError on
        # Python 3).  They still reach __init__ through the normal
        # instantiation protocol.
        obj = object.__new__(cls)
        if not hasattr(obj, '_op_use_c_code'):
            obj._op_use_c_code = True
        return obj

    def __init__(self, use_c_code=True):
        """
        :param use_c_code: if True, `make_thunk` tries the C linker before
            falling back on `perform`; if False it uses `perform` directly.
        """
        self._op_use_c_code = use_c_code

    def make_thunk(self, node, storage_map, compute_map, no_recycling):
        """Return a callable that computes the outputs of `node`.

        :param node: something previously returned by self.make_node

        :param storage_map: dict variable -> one-element-list where a computed
                value for this variable may be found.

        :param compute_map: dict variable -> one-element-list where a boolean
                value will be found.  The boolean indicates whether the
                variable's storage_map container contains a valid value (True)
                or if it has not been computed yet (False).

        :param no_recycling: list of variables for which it is forbidden to
                reuse memory allocated by a previous call.

        :returns: a function taking no required arguments.  Calling it runs
            the computation and sets ``compute_map[o][0] = True`` for every
            output ``o`` of `node`.  The function carries the attributes
            ``inputs`` and ``outputs`` (the storage cells of the node's inputs
            and outputs), ``lazy`` (always False here) and either ``cthunk``
            (C implementation) or ``perform`` (Python implementation).

        If `self._op_use_c_code` is true, the C implementation is tried
        first; a `NotImplementedError` or `utils.MethodNotDefined` raised
        while building it triggers the fallback on `node.op.perform`.
        """
        logger = logging.getLogger('theano.Op')

        node_input_storage = [storage_map[r] for r in node.inputs]
        node_output_storage = [storage_map[r] for r in node.outputs]
        # The two lists below are not used further down; the lookups are kept
        # because they raise KeyError right away when a variable of the node
        # is missing from compute_map.
        node_input_compute = [compute_map[r] for r in node.inputs]
        node_output_compute = [compute_map[r] for r in node.outputs]

        if self._op_use_c_code:
            try:
                # Compile a clone of the one-node graph rather than the node
                # itself.
                e = Env(*graph.clone(node.inputs, node.outputs))

                # Translate the no_recycling request from the original
                # outputs to the corresponding cloned outputs.
                e_no_recycling = [new_o
                                  for (new_o, old_o)
                                  in zip(e.outputs, node.outputs)
                                  if old_o in no_recycling]
                cl = cc.CLinker().accept(e,
                                         no_recycling=e_no_recycling)

                logger.debug('Trying CLinker.make_thunk')
                fill_storage, node_input_filters, node_output_filters = \
                    cl.make_thunk(input_storage=node_input_storage,
                                  output_storage=node_output_storage)

                def rval():
                    fill_storage()
                    for out_var in node.outputs:
                        compute_map[out_var][0] = True

                rval.cthunk = fill_storage.cthunk
                rval.inputs = node_input_storage
                rval.outputs = node_output_storage
                rval.lazy = False
                return rval

            except (NotImplementedError, utils.MethodNotDefined):
                logger.debug('Falling back on perform')

        # condition: either there was no c_code, or it failed

        p = node.op.perform

        # default arguments are stored in the closure of `rval`
        def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
            r = p(n, [x[0] for x in i], o)
            # Use a distinct loop name so the bound output storage `o` is not
            # shadowed.
            for out_var in node.outputs:
                compute_map[out_var][0] = True
            return r

        rval.inputs = node_input_storage
        rval.outputs = node_output_storage
        rval.perform = p
        rval.lazy = False
        return rval
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论