提交 520f92a4 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add init_code mechanism to register code in module init fn

上级 82ccf7da
......@@ -57,6 +57,11 @@ There are less methods to define for an Op than for a Type:
Allows you to specify special g++ arguments to add/exclude
.. method:: c_init_code()
Allows you to specify code that will be executed once when the
module is initialized, before anything else is executed.
.. method:: c_support_code()
Allows you to specify helper functions/structs that the
......
......@@ -90,6 +90,11 @@ the most important ones:
Allows to specify special compiler arguments to add/exclude.
.. method:: c_init_code()
Allows you to specify code that will be executed once when the
module is initialized, before anything else is executed.
.. method:: c_support_code()
Allows to add helper functions/structs that the :ref:`type` needs.
......
......@@ -774,6 +774,21 @@ class CLinker(link.Linker):
pass
return utils.uniq(ret)
def init_code(self):
"""
Return a list of code snippets that have to be inserted
in the module initialization code.
The return value will not contain duplicates.
"""
ret = []
for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]:
try:
ret += x.c_init_code()
except utils.MethodNotDefined:
pass
return utils.uniq(ret)
def c_compiler(self):
c_compiler = None
for x in [y.type for y in self.variables] + [
......@@ -1277,6 +1292,8 @@ class CLinker(link.Linker):
mod.add_function(instantiate)
for header in self.headers():
mod.add_include(header)
for init_code_block in self.init_code():
mod.add_init_code(init_code_block)
return mod
......
......@@ -174,6 +174,17 @@ class CLinkerObject(object):
"""
raise utils.MethodNotDefined("c_no_compile_args", type(self), self.__class__.__name__)
def c_init_code(self):
"""
Optional: return a list of code snippets to be inserted in module
initialization.
:Exceptions:
- `MethodNotDefined`: the subclass does not override this method
"""
raise utils.MethodNotDefined("c_init_code", type(self),
self.__class__.__name__)
class CLinkerOp(CLinkerObject):
"""
......
......@@ -7,6 +7,7 @@ import theano
from theano import config
from theano.gof import Constant, hashtype, Type, Variable
from theano.gof.python25 import any
from theano.gof.utils import MethodNotDefined
from theano import scalar as scal
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论