提交 520f92a4 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add init_code mechanism to register code in module init fn

上级 82ccf7da
...@@ -57,6 +57,11 @@ There are less methods to define for an Op than for a Type: ...@@ -57,6 +57,11 @@ There are less methods to define for an Op than for a Type:
Allows you to specify special g++ arguments to add/exclude Allows you to specify special g++ arguments to add/exclude
.. method:: c_init_code()
Allows you to specify code that will be executed once when the
module is initialized, before anything else is executed.
.. method:: c_support_code() .. method:: c_support_code()
Allows you to specify helper functions/structs that the Allows you to specify helper functions/structs that the
......
...@@ -90,6 +90,11 @@ the most important ones: ...@@ -90,6 +90,11 @@ the most important ones:
Allows to specify special compiler arguments to add/exclude. Allows to specify special compiler arguments to add/exclude.
.. method:: c_init_code()
Allows you to specify code that will be executed once when the
module is initialized, before anything else is executed.
.. method:: c_support_code() .. method:: c_support_code()
Allows to add helper functions/structs that the :ref:`type` needs. Allows to add helper functions/structs that the :ref:`type` needs.
......
...@@ -774,6 +774,21 @@ class CLinker(link.Linker): ...@@ -774,6 +774,21 @@ class CLinker(link.Linker):
pass pass
return utils.uniq(ret) return utils.uniq(ret)
def init_code(self):
"""
Return a list of code snippets that have to be inserted
in the module initialization code.
The return value will not contain duplicates.
"""
ret = []
for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]:
try:
ret += x.c_init_code()
except utils.MethodNotDefined:
pass
return utils.uniq(ret)
def c_compiler(self): def c_compiler(self):
c_compiler = None c_compiler = None
for x in [y.type for y in self.variables] + [ for x in [y.type for y in self.variables] + [
...@@ -1277,6 +1292,8 @@ class CLinker(link.Linker): ...@@ -1277,6 +1292,8 @@ class CLinker(link.Linker):
mod.add_function(instantiate) mod.add_function(instantiate)
for header in self.headers(): for header in self.headers():
mod.add_include(header) mod.add_include(header)
for init_code_block in self.init_code():
mod.add_init_code(init_code_block)
return mod return mod
......
...@@ -174,6 +174,17 @@ class CLinkerObject(object): ...@@ -174,6 +174,17 @@ class CLinkerObject(object):
""" """
raise utils.MethodNotDefined("c_no_compile_args", type(self), self.__class__.__name__) raise utils.MethodNotDefined("c_no_compile_args", type(self), self.__class__.__name__)
def c_init_code(self):
"""
Optional: return a list of code snippets to be inserted in module
initialization.
:Exceptions:
- `MethodNotDefined`: the subclass does not override this method
"""
raise utils.MethodNotDefined("c_init_code", type(self),
self.__class__.__name__)
class CLinkerOp(CLinkerObject): class CLinkerOp(CLinkerObject):
""" """
......
...@@ -7,6 +7,7 @@ import theano ...@@ -7,6 +7,7 @@ import theano
from theano import config from theano import config
from theano.gof import Constant, hashtype, Type, Variable from theano.gof import Constant, hashtype, Type, Variable
from theano.gof.python25 import any from theano.gof.python25 import any
from theano.gof.utils import MethodNotDefined
from theano import scalar as scal from theano import scalar as scal
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论