提交 920d123d authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1517 from lamblin/init_code

Add init_code mechanism to register code in module init function
...@@ -57,6 +57,11 @@ There are less methods to define for an Op than for a Type: ...@@ -57,6 +57,11 @@ There are less methods to define for an Op than for a Type:
Allows you to specify special g++ arguments to add/exclude Allows you to specify special g++ arguments to add/exclude
.. method:: c_init_code()
Allows you to specify code that will be executed once when the
module is initialized, before anything else is executed.
.. method:: c_support_code() .. method:: c_support_code()
Allows you to specify helper functions/structs that the Allows you to specify helper functions/structs that the
......
...@@ -90,6 +90,14 @@ the most important ones: ...@@ -90,6 +90,14 @@ the most important ones:
Allows to specify special compiler arguments to add/exclude. Allows to specify special compiler arguments to add/exclude.
.. method:: c_init_code()
Allows you to specify code that will be executed once when the
module is initialized, before anything else is executed.
For instance, if a type depends on NumPy's C API, then
``'import_array();'`` has to be among the snippets returned
by ``c_init_code()``.
.. method:: c_support_code() .. method:: c_support_code()
Allows to add helper functions/structs that the :ref:`type` needs. Allows to add helper functions/structs that the :ref:`type` needs.
......
...@@ -742,7 +742,7 @@ class CLinker(link.Linker): ...@@ -742,7 +742,7 @@ class CLinker(link.Linker):
c_compiler = self.c_compiler() c_compiler = self.c_compiler()
ret = list(set(ret)) # to remove duplicate ret = utils.uniq(ret) # to remove duplicate
# The args set by the compiler include the user flags. We do not want # The args set by the compiler include the user flags. We do not want
# to reorder them # to reorder them
ret += c_compiler.compile_args() ret += c_compiler.compile_args()
...@@ -772,7 +772,22 @@ class CLinker(link.Linker): ...@@ -772,7 +772,22 @@ class CLinker(link.Linker):
ret += x.c_headers() ret += x.c_headers()
except utils.MethodNotDefined: except utils.MethodNotDefined:
pass pass
return list(set(ret)) return utils.uniq(ret)
def init_code(self):
"""
Return a list of code snippets that have to be inserted
in the module initialization code.
The return value will not contain duplicates.
"""
ret = []
for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]:
try:
ret += x.c_init_code()
except utils.MethodNotDefined:
pass
return utils.uniq(ret)
def c_compiler(self): def c_compiler(self):
c_compiler = None c_compiler = None
...@@ -809,7 +824,7 @@ class CLinker(link.Linker): ...@@ -809,7 +824,7 @@ class CLinker(link.Linker):
ret += x.c_header_dirs() ret += x.c_header_dirs()
except utils.MethodNotDefined: except utils.MethodNotDefined:
pass pass
return list(set(ret)) return utils.uniq(ret)
def libraries(self): def libraries(self):
"""WRITEME """WRITEME
...@@ -825,7 +840,7 @@ class CLinker(link.Linker): ...@@ -825,7 +840,7 @@ class CLinker(link.Linker):
ret += x.c_libraries() ret += x.c_libraries()
except utils.MethodNotDefined: except utils.MethodNotDefined:
pass pass
return list(set(ret)) return utils.uniq(ret)
def lib_dirs(self): def lib_dirs(self):
"""WRITEME """WRITEME
...@@ -841,7 +856,7 @@ class CLinker(link.Linker): ...@@ -841,7 +856,7 @@ class CLinker(link.Linker):
ret += x.c_lib_dirs() ret += x.c_lib_dirs()
except utils.MethodNotDefined: except utils.MethodNotDefined:
pass pass
return list(set(ret)) return utils.uniq(ret)
def __compile__(self, input_storage=None, def __compile__(self, input_storage=None,
output_storage=None, keep_lock=False): output_storage=None, keep_lock=False):
...@@ -1277,6 +1292,8 @@ class CLinker(link.Linker): ...@@ -1277,6 +1292,8 @@ class CLinker(link.Linker):
mod.add_function(instantiate) mod.add_function(instantiate)
for header in self.headers(): for header in self.headers():
mod.add_include(header) mod.add_include(header)
for init_code_block in self.init_code():
mod.add_init_code(init_code_block)
return mod return mod
......
...@@ -144,12 +144,7 @@ class DynamicModule(object): ...@@ -144,12 +144,7 @@ class DynamicModule(object):
self.support_code = [] self.support_code = []
self.functions = [] self.functions = []
self.includes = ["<Python.h>", "<iostream>"] self.includes = ["<Python.h>", "<iostream>"]
self.init_blocks = []
#TODO: this should come from TensorType
self.includes.append('<numpy/arrayobject.h>')
#TODO: from TensorType
self.init_blocks = ['import_array();']
def print_methoddef(self, stream): def print_methoddef(self, stream):
print >> stream, "static PyMethodDef MyMethods[] = {" print >> stream, "static PyMethodDef MyMethods[] = {"
......
...@@ -175,6 +175,17 @@ class CLinkerObject(object): ...@@ -175,6 +175,17 @@ class CLinkerObject(object):
""" """
raise utils.MethodNotDefined("c_no_compile_args", type(self), self.__class__.__name__) raise utils.MethodNotDefined("c_no_compile_args", type(self), self.__class__.__name__)
def c_init_code(self):
"""
Optional: return a list of code snippets to be inserted in module
initialization.
:Exceptions:
- `MethodNotDefined`: the subclass does not override this method
"""
raise utils.MethodNotDefined("c_init_code", type(self),
self.__class__.__name__)
class CLinkerOp(CLinkerObject): class CLinkerOp(CLinkerObject):
""" """
......
...@@ -158,6 +158,9 @@ class Scalar(Type): ...@@ -158,6 +158,9 @@ class Scalar(Type):
def c_headers(self): def c_headers(self):
l = ['<math.h>'] l = ['<math.h>']
# These includes are needed by Scalar and TensorType,
# we declare them here and they will be re-used by TensorType
l.append('<numpy/arrayobject.h>')
l.append('<numpy/arrayscalars.h>') l.append('<numpy/arrayscalars.h>')
if config.lib.amdlibm: if config.lib.amdlibm:
l += ['<amdlibm.h>'] l += ['<amdlibm.h>']
...@@ -431,23 +434,11 @@ class Scalar(Type): ...@@ -431,23 +434,11 @@ class Scalar(Type):
else: else:
return "" return ""
def c_init_code(self):
return ["import_array();"]
def c_code_cache_version(self): def c_code_cache_version(self):
# Fix gh-1510, use half_nbits float instead of nbits return (12, numpy.__version__)
return (11, numpy.__version__)
# Use the correct type checking and conversion functions
return (10, numpy.__version__)
# Make operators work with 64 and 128 arguments at the same time
return (9, numpy.__version__)
# put const around operators and added unary '-' operator
return (8, numpy.__version__)
# no need to put lib.amdlibm here as c_compile_args() are put
# in the key.
return (7,) # make complex c code optional
return (6,) # added implemeentations of operators that work
# with scalar arguments
return (5,) # added constructors to theano_complex class
return (4,) # explicit T given in specialization of operator=
# lines. This makes it compile with open64
def get_shape_info(self, obj): def get_shape_info(self, obj):
return obj.itemsize return obj.itemsize
......
...@@ -7,6 +7,7 @@ import theano ...@@ -7,6 +7,7 @@ import theano
from theano import config from theano import config
from theano.gof import Constant, hashtype, Type, Variable from theano.gof import Constant, hashtype, Type, Variable
from theano.gof.python25 import any from theano.gof.python25 import any
from theano.gof.utils import MethodNotDefined
from theano import scalar as scal from theano import scalar as scal
...@@ -415,7 +416,7 @@ class TensorType(Type): ...@@ -415,7 +416,7 @@ class TensorType(Type):
#"TensorType{%s, %s}" % (str(self.dtype), str(self.broadcastable)) #"TensorType{%s, %s}" % (str(self.dtype), str(self.broadcastable))
def c_declare(self, name, sub): def c_declare(self, name, sub):
"""Override `CLinkerOp.c_declare` """ """Override `CLinkerType.c_declare` """
return """ return """
PyArrayObject* %(name)s; PyArrayObject* %(name)s;
int type_num_%(name)s; int type_num_%(name)s;
...@@ -423,14 +424,14 @@ class TensorType(Type): ...@@ -423,14 +424,14 @@ class TensorType(Type):
""" % dict(sub, name=name, dtype=self.dtype_specs()[1]) """ % dict(sub, name=name, dtype=self.dtype_specs()[1])
def c_init(self, name, sub): def c_init(self, name, sub):
"""Override `CLinkerOp.c_init` """ """Override `CLinkerType.c_init` """
return """ return """
%(name)s = NULL; %(name)s = NULL;
type_num_%(name)s = %(type_num)s; type_num_%(name)s = %(type_num)s;
""" % dict(sub, name=name, type_num=self.dtype_specs()[2]) """ % dict(sub, name=name, type_num=self.dtype_specs()[2])
def c_extract(self, name, sub): def c_extract(self, name, sub):
"""Override `CLinkerOp.c_extract` """ """Override `CLinkerType.c_extract` """
return """ return """
%(name)s = NULL; %(name)s = NULL;
if (py_%(name)s == Py_None) { if (py_%(name)s == Py_None) {
...@@ -484,7 +485,7 @@ class TensorType(Type): ...@@ -484,7 +485,7 @@ class TensorType(Type):
""" % dict(sub, name=name, type_num=self.dtype_specs()[2]) """ % dict(sub, name=name, type_num=self.dtype_specs()[2])
def c_cleanup(self, name, sub): def c_cleanup(self, name, sub):
"""Override `CLinkerOp.c_cleanup` """ """Override `CLinkerType.c_cleanup` """
return """ return """
if (%(name)s) { if (%(name)s) {
Py_XDECREF(%(name)s); Py_XDECREF(%(name)s);
...@@ -492,7 +493,7 @@ class TensorType(Type): ...@@ -492,7 +493,7 @@ class TensorType(Type):
""" % locals() """ % locals()
def c_sync(self, name, sub): def c_sync(self, name, sub):
"""Override `CLinkerOp.c_sync` """ """Override `CLinkerType.c_sync` """
fail = sub['fail'] fail = sub['fail']
type_num = self.dtype_specs()[2] type_num = self.dtype_specs()[2]
return """ return """
...@@ -535,7 +536,7 @@ class TensorType(Type): ...@@ -535,7 +536,7 @@ class TensorType(Type):
""" % locals() """ % locals()
def c_headers(self): def c_headers(self):
"""Override `CLinkerOp.c_headers` """ """Override `CLinkerObject.c_headers` """
return scal.Scalar(self.dtype).c_headers() return scal.Scalar(self.dtype).c_headers()
def c_libraries(self): def c_libraries(self):
...@@ -545,13 +546,16 @@ class TensorType(Type): ...@@ -545,13 +546,16 @@ class TensorType(Type):
return scal.Scalar(self.dtype).c_compile_args() return scal.Scalar(self.dtype).c_compile_args()
def c_support_code(self): def c_support_code(self):
"""Override `CLinkerOp.c_support_code` """ """Override `CLinkerObject.c_support_code` """
return scal.Scalar(self.dtype).c_support_code() return scal.Scalar(self.dtype).c_support_code()
def c_init_code(self):
return scal.Scalar(self.dtype).c_init_code()
def c_code_cache_version(self): def c_code_cache_version(self):
scalar_version = scal.Scalar(self.dtype).c_code_cache_version() scalar_version = scal.Scalar(self.dtype).c_code_cache_version()
if scalar_version: if scalar_version:
return (9,) + scalar_version return (10,) + scalar_version
else: else:
return () return ()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论