提交 920d123d authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1517 from lamblin/init_code

Add init_code mechanism to register code in module init function
......@@ -57,6 +57,11 @@ There are less methods to define for an Op than for a Type:
Allows you to specify special g++ arguments to add/exclude
.. method:: c_init_code()
Allows you to specify code that will be executed once when the
module is initialized, before anything else is executed.
.. method:: c_support_code()
Allows you to specify helper functions/structs that the
......
......@@ -90,6 +90,14 @@ the most important ones:
Allows to specify special compiler arguments to add/exclude.
.. method:: c_init_code()
Allows you to specify code that will be executed once when the
module is initialized, before anything else is executed.
For instance, if a type depends on NumPy's C API, then
``'import_array();'`` has to be among the snippets returned
by ``c_init_code()``.
.. method:: c_support_code()
Allows to add helper functions/structs that the :ref:`type` needs.
......
......@@ -742,7 +742,7 @@ class CLinker(link.Linker):
c_compiler = self.c_compiler()
ret = list(set(ret)) # to remove duplicate
ret = utils.uniq(ret) # to remove duplicate
# The args set by the compiler include the user flags. We do not want
# to reorder them
ret += c_compiler.compile_args()
......@@ -772,7 +772,22 @@ class CLinker(link.Linker):
ret += x.c_headers()
except utils.MethodNotDefined:
pass
return list(set(ret))
return utils.uniq(ret)
def init_code(self):
"""
Return a list of code snippets that have to be inserted
in the module initialization code.
The return value will not contain duplicates.
"""
ret = []
for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]:
try:
ret += x.c_init_code()
except utils.MethodNotDefined:
pass
return utils.uniq(ret)
def c_compiler(self):
c_compiler = None
......@@ -809,7 +824,7 @@ class CLinker(link.Linker):
ret += x.c_header_dirs()
except utils.MethodNotDefined:
pass
return list(set(ret))
return utils.uniq(ret)
def libraries(self):
"""WRITEME
......@@ -825,7 +840,7 @@ class CLinker(link.Linker):
ret += x.c_libraries()
except utils.MethodNotDefined:
pass
return list(set(ret))
return utils.uniq(ret)
def lib_dirs(self):
"""WRITEME
......@@ -841,7 +856,7 @@ class CLinker(link.Linker):
ret += x.c_lib_dirs()
except utils.MethodNotDefined:
pass
return list(set(ret))
return utils.uniq(ret)
def __compile__(self, input_storage=None,
output_storage=None, keep_lock=False):
......@@ -1277,6 +1292,8 @@ class CLinker(link.Linker):
mod.add_function(instantiate)
for header in self.headers():
mod.add_include(header)
for init_code_block in self.init_code():
mod.add_init_code(init_code_block)
return mod
......
......@@ -144,12 +144,7 @@ class DynamicModule(object):
self.support_code = []
self.functions = []
self.includes = ["<Python.h>", "<iostream>"]
#TODO: this should come from TensorType
self.includes.append('<numpy/arrayobject.h>')
#TODO: from TensorType
self.init_blocks = ['import_array();']
self.init_blocks = []
def print_methoddef(self, stream):
print >> stream, "static PyMethodDef MyMethods[] = {"
......
......@@ -175,6 +175,17 @@ class CLinkerObject(object):
"""
raise utils.MethodNotDefined("c_no_compile_args", type(self), self.__class__.__name__)
def c_init_code(self):
"""
Optional: return a list of code snippets to be inserted in module
initialization.
:Exceptions:
- `MethodNotDefined`: the subclass does not override this method
"""
raise utils.MethodNotDefined("c_init_code", type(self),
self.__class__.__name__)
class CLinkerOp(CLinkerObject):
"""
......
......@@ -158,6 +158,9 @@ class Scalar(Type):
def c_headers(self):
l = ['<math.h>']
# These includes are needed by Scalar and TensorType,
# we declare them here and they will be re-used by TensorType
l.append('<numpy/arrayobject.h>')
l.append('<numpy/arrayscalars.h>')
if config.lib.amdlibm:
l += ['<amdlibm.h>']
......@@ -431,23 +434,11 @@ class Scalar(Type):
else:
return ""
def c_init_code(self):
return ["import_array();"]
def c_code_cache_version(self):
# Fix gh-1510, use half_nbits float instead of nbits
return (11, numpy.__version__)
# Use the correct type checking and conversion functions
return (10, numpy.__version__)
# Make operators work with 64 and 128 arguments at the same time
return (9, numpy.__version__)
# put const around operators and added unary '-' operator
return (8, numpy.__version__)
# no need to put lib.amdlibm here as c_compile_args() are put
# in the key.
return (7,) # make complex c code optional
return (6,) # added implemeentations of operators that work
# with scalar arguments
return (5,) # added constructors to theano_complex class
return (4,) # explicit T given in specialization of operator=
# lines. This makes it compile with open64
return (12, numpy.__version__)
def get_shape_info(self, obj):
return obj.itemsize
......
......@@ -7,6 +7,7 @@ import theano
from theano import config
from theano.gof import Constant, hashtype, Type, Variable
from theano.gof.python25 import any
from theano.gof.utils import MethodNotDefined
from theano import scalar as scal
......@@ -415,7 +416,7 @@ class TensorType(Type):
#"TensorType{%s, %s}" % (str(self.dtype), str(self.broadcastable))
def c_declare(self, name, sub):
"""Override `CLinkerOp.c_declare` """
"""Override `CLinkerType.c_declare` """
return """
PyArrayObject* %(name)s;
int type_num_%(name)s;
......@@ -423,14 +424,14 @@ class TensorType(Type):
""" % dict(sub, name=name, dtype=self.dtype_specs()[1])
def c_init(self, name, sub):
"""Override `CLinkerOp.c_init` """
"""Override `CLinkerType.c_init` """
return """
%(name)s = NULL;
type_num_%(name)s = %(type_num)s;
""" % dict(sub, name=name, type_num=self.dtype_specs()[2])
def c_extract(self, name, sub):
"""Override `CLinkerOp.c_extract` """
"""Override `CLinkerType.c_extract` """
return """
%(name)s = NULL;
if (py_%(name)s == Py_None) {
......@@ -484,7 +485,7 @@ class TensorType(Type):
""" % dict(sub, name=name, type_num=self.dtype_specs()[2])
def c_cleanup(self, name, sub):
"""Override `CLinkerOp.c_cleanup` """
"""Override `CLinkerType.c_cleanup` """
return """
if (%(name)s) {
Py_XDECREF(%(name)s);
......@@ -492,7 +493,7 @@ class TensorType(Type):
""" % locals()
def c_sync(self, name, sub):
"""Override `CLinkerOp.c_sync` """
"""Override `CLinkerType.c_sync` """
fail = sub['fail']
type_num = self.dtype_specs()[2]
return """
......@@ -535,7 +536,7 @@ class TensorType(Type):
""" % locals()
def c_headers(self):
"""Override `CLinkerOp.c_headers` """
"""Override `CLinkerObject.c_headers` """
return scal.Scalar(self.dtype).c_headers()
def c_libraries(self):
......@@ -545,13 +546,16 @@ class TensorType(Type):
return scal.Scalar(self.dtype).c_compile_args()
def c_support_code(self):
"""Override `CLinkerOp.c_support_code` """
"""Override `CLinkerObject.c_support_code` """
return scal.Scalar(self.dtype).c_support_code()
def c_init_code(self):
return scal.Scalar(self.dtype).c_init_code()
def c_code_cache_version(self):
scalar_version = scal.Scalar(self.dtype).c_code_cache_version()
if scalar_version:
return (9,) + scalar_version
return (10,) + scalar_version
else:
return ()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论