提交 b3c62bb6 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Move header and init_code from DynamicModule to Scalar

上级 520f92a4
......@@ -144,12 +144,7 @@ class DynamicModule(object):
self.support_code = []
self.functions = []
self.includes = ["<Python.h>", "<iostream>"]
#TODO: this should come from TensorType
self.includes.append('<numpy/arrayobject.h>')
#TODO: from TensorType
self.init_blocks = ['import_array();']
self.init_blocks = []
def print_methoddef(self, stream):
print >> stream, "static PyMethodDef MyMethods[] = {"
......
......@@ -158,6 +158,7 @@ class Scalar(Type):
def c_headers(self):
l = ['<math.h>']
l.append('<numpy/arrayobject.h>')
l.append('<numpy/arrayscalars.h>')
if config.lib.amdlibm:
l += ['<amdlibm.h>']
......@@ -413,23 +414,11 @@ class Scalar(Type):
else:
return ""
def c_init_code(self):
return ["import_array();"]
def c_code_cache_version(self):
# Fix gh-1510, use half_nbits float instead of nbits
return (11, numpy.__version__)
# Use the correct type checking and conversion functions
return (10, numpy.__version__)
# Make operators work with 64 and 128 arguments at the same time
return (9, numpy.__version__)
# put const around operators and added unary '-' operator
return (8, numpy.__version__)
# no need to put lib.amdlibm here as c_compile_args() are put
# in the key.
return (7,) # make complex c code optional
return (6,) # added implemeentations of operators that work
# with scalar arguments
return (5,) # added constructors to theano_complex class
return (4,) # explicit T given in specialization of operator=
# lines. This makes it compile with open64
return (12, numpy.__version__)
def get_shape_info(self, obj):
return obj.itemsize
......
......@@ -536,7 +536,7 @@ class TensorType(Type):
""" % locals()
def c_headers(self):
"""Override `CLinkerOp.c_headers` """
"""Override `CLinkerObject.c_headers` """
return scal.Scalar(self.dtype).c_headers()
def c_libraries(self):
......@@ -549,10 +549,13 @@ class TensorType(Type):
"""Override `CLinkerOp.c_support_code` """
return scal.Scalar(self.dtype).c_support_code()
def c_init_code(self):
return scal.Scalar(self.dtype).c_init_code()
def c_code_cache_version(self):
scalar_version = scal.Scalar(self.dtype).c_code_cache_version()
if scalar_version:
return (9,) + scalar_version
return (10,) + scalar_version
else:
return ()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论