提交 4102ff9a authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Add a base make_thunk for any node.

上级 2c86e528
...@@ -10,6 +10,11 @@ from theano import config ...@@ -10,6 +10,11 @@ from theano import config
import graph import graph
import numpy import numpy
import utils import utils
import logging
from theano import config
from env import Env
import graph
import cc
class CLinkerObject(object): class CLinkerObject(object):
...@@ -414,4 +419,77 @@ class PureOp(object): ...@@ -414,4 +419,77 @@ class PureOp(object):
class Op(utils.object2, PureOp, CLinkerOp): class Op(utils.object2, PureOp, CLinkerOp):
"""Convenience class to bundle `PureOp` and `CLinkerOp`""" """Convenience class to bundle `PureOp` and `CLinkerOp`"""
pass def __new__(cls, *args, **kwargs):
# this function exists to silently and transparently ensure that all
# existing Ops get a _op_use_c_code attribute
obj = object.__new__(cls, *args, **kwargs)
if not hasattr(obj, '_op_use_c_code'):
obj._op_use_c_code = True
return obj
def __init__(self, use_c_code=True):
self._op_use_c_code = use_c_code
def make_thunk(self, node, storage_map, compute_map, no_recycling):
"""
:param node: something previously returned by self.make_node
:param storage_map: dict variable -> one-element-list where a computed
value for this variable may be found.
:param compute_map: dict variable -> one-element-list where a boolean
value will be found. The boolean indicates whether the
variable's storage_map container contains a valid value (True)
or if it has not been computed yet (False).
:param no_recycling: list of variables for which it is forbidden to
reuse memory allocated by a previous call.
"""
logger = logging.getLogger('theano.Op')
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
node_input_compute = [compute_map[r] for r in node.inputs]
node_output_compute = [compute_map[r] for r in node.outputs]
#logger.debug('Compiling node %i of graph' % node_idx)
if self._op_use_c_code:
try:
e = Env(*graph.clone(node.inputs, node.outputs))
e_no_recycling = [new_o
for (new_o, old_o) in zip(e.outputs, node.outputs)
if old_o in no_recycling]
cl = cc.CLinker().accept(e,
no_recycling=e_no_recycling)
logger.debug('Trying CLinker.make_thunk')
fill_storage, node_input_filters, node_output_filters = cl.make_thunk(
input_storage = node_input_storage,
output_storage = node_output_storage)
def rval():
fill_storage()
for o in node.outputs:
compute_map[o][0] = True
rval.cthunk = fill_storage.cthunk
rval.inputs = node_input_storage
rval.outputs = node_output_storage
rval.lazy = False
return rval
except (NotImplementedError, utils.MethodNotDefined):
logger.debug('Falling back on perform')
# condition: either there was no c_code, or it failed
p = node.op.perform
# default arguments are stored in the closure of `rval`
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
r = p(n, [x[0] for x in i], o)
for o in node.outputs:
compute_map[o][0] = True
return r
rval.inputs = node_input_storage
rval.outputs = node_output_storage
rval.perform = p
rval.lazy = False
return rval
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论