Added parameters for compile location, BLAS headers, and BLAS libraries (blas_libs)

上级 5c9aba65
import os # for building the location of the .omega/omega_compiled cache directory
import sys # for adding the inline code cache to the include path
import gof import gof
from gof import current_mode, set_mode, build_mode, eval_mode, build_eval_mode, pop_mode, UNCOMPUTED, UNDEFINED, PythonR from gof import current_mode, set_mode, build_mode, eval_mode, build_eval_mode, pop_mode, UNCOMPUTED, UNDEFINED, PythonR
...@@ -37,8 +39,22 @@ literals_db = {} ...@@ -37,8 +39,22 @@ literals_db = {}
literals_id_db = weakref.WeakValueDictionary() literals_id_db = weakref.WeakValueDictionary()
#input floating point scalars will be cast to arrays of this type #input floating point scalars will be cast to arrays of this type
# see TRAC(#31)
default_input_scalar_dtype = 'float64' default_input_scalar_dtype = 'float64'
# BLAS Support
# These should be used by dependent modules to link blas functions.
# - used by dot(), gemm()
_blas_headers = ['"/home/bergstra/cvs/lgcm/omega/cblas.h"']
_blas_libs = ['mkl', 'm']
# WEAVE CACHE
#_home_omega = os.path.join(os.getenv('HOME'), '.omega')
_home_omega = os.path.join('/home/bergstra/.omega')
_compiled = 'omega_compiled'
_home_omega_compiled = os.path.join(_home_omega, _compiled)
sys.path.append(_home_omega) # J - is this a good idea??
def input(x): def input(x):
#NB: #NB:
# - automatically casting int to float seems wrong. # - automatically casting int to float seems wrong.
...@@ -241,7 +257,7 @@ class omega_op(gof.PythonOp): ...@@ -241,7 +257,7 @@ class omega_op(gof.PythonOp):
type_converters = converters) type_converters = converters)
instantiate.customize.add_support_code(self.c_support_code() + struct) instantiate.customize.add_support_code(self.c_support_code() + struct)
instantiate.customize.add_extra_compile_arg("-O3") instantiate.customize.add_extra_compile_arg("-O3")
instantiate.customize.add_extra_compile_arg("-ffast-math") instantiate.customize.add_extra_compile_arg("-ffast-math") #TODO: make this optional, say by passing args to c_thunk_factory?
instantiate.customize.add_extra_compile_arg("-falign-loops=4") instantiate.customize.add_extra_compile_arg("-falign-loops=4")
# instantiate.customize.add_extra_compile_arg("-mfpmath=sse") # instantiate.customize.add_extra_compile_arg("-mfpmath=sse")
for header in self.c_headers(): for header in self.c_headers():
...@@ -250,9 +266,9 @@ class omega_op(gof.PythonOp): ...@@ -250,9 +266,9 @@ class omega_op(gof.PythonOp):
instantiate.customize.add_library(lib) instantiate.customize.add_library(lib)
mod.add_function(instantiate) mod.add_function(instantiate)
mod.compile(location = 'compiled') mod.compile(location = _home_omega_compiled)
module = __import__("compiled.%s" % module_name, {}, {}, [module_name]) module = __import__("%s.%s" % (_compiled, module_name), {}, {}, [module_name])
def creator(): def creator():
return module.instantiate(*[x.data for x in self.inputs + self.outputs]) return module.instantiate(*[x.data for x in self.inputs + self.outputs])
...@@ -397,8 +413,8 @@ class elemwise(omega_op): ...@@ -397,8 +413,8 @@ class elemwise(omega_op):
for oname in onames: for oname in onames:
if oname not in lonames: if oname not in lonames:
raise Exception("cannot infer a specification automatically for variable " \ raise Exception("cannot infer a specification automatically for variable " \
"%s because it is not part of the elementwise loop - "\ "%s.%s because it is not part of the elementwise loop - "\
"please override the specs method" % oname) "please override the specs method" % (self.__class__.__name__, oname))
shape, dtype = None, None shape, dtype = None, None
for iname, input in zip(inames, self.inputs): for iname, input in zip(inames, self.inputs):
if iname in linames: if iname in linames:
...@@ -879,9 +895,9 @@ class dot(omega_op): ...@@ -879,9 +895,9 @@ class dot(omega_op):
shape = (x[2][0], y[2][1]) shape = (x[2][0], y[2][1])
return (numpy.ndarray, upcast(x[1], y[1]), shape) return (numpy.ndarray, upcast(x[1], y[1]), shape)
def c_headers(self): def c_headers(self):
return ["<gsl/gsl_cblas.h>"] return _blas_headers
def c_libs(self): def c_libs(self):
return ["cblas", "atlas", "g2c"] return _blas_libs
def c_impl((_x, _y), (_z, )): def c_impl((_x, _y), (_z, )):
dtype = _x.spec[1] dtype = _x.spec[1]
if dtype.char == 'f': if dtype.char == 'f':
...@@ -989,9 +1005,9 @@ class gemm(omega_op, inplace): ...@@ -989,9 +1005,9 @@ class gemm(omega_op, inplace):
def alloc(self, except_list): def alloc(self, except_list):
self.outputs[0].data = self.inputs[0].data self.outputs[0].data = self.inputs[0].data
def c_headers(self): def c_headers(self):
return ["<gsl/gsl_cblas.h>"] return _blas_headers
def c_libs(self): def c_libs(self):
return ["cblas", "atlas", "g2c"] return _blas_libs
def c_impl((_z, _a, _x, _y, _b), (_zout,)): def c_impl((_z, _a, _x, _y, _b), (_zout,)):
dtype = _x.spec[1] dtype = _x.spec[1]
if dtype.char == 'f': if dtype.char == 'f':
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论