提交 f9b8b7ab 作者: bergstrj@iro.umontreal.ca

merge

...@@ -89,8 +89,8 @@ _gemm_code = { 'f': _gemm_code_template % { 'gemm':'cblas_sgemm', 'dtype':'float ...@@ -89,8 +89,8 @@ _gemm_code = { 'f': _gemm_code_template % { 'gemm':'cblas_sgemm', 'dtype':'float
def _gemm_rank2(a, x, y, b, z): def _gemm_rank2(a, x, y, b, z):
weave.inline(_gemm_code[z.dtype.char], weave.inline(_gemm_code[z.dtype.char],
['a', 'x', 'y', 'b', 'z'], ['a', 'x', 'y', 'b', 'z'],
headers=['<gsl/gsl_cblas.h>'], headers=['"/home/bergstra/cvs/lgcm/omega/cblas.h"'],
libraries=['cblas','goto', 'g2c']) libraries=['mkl', 'm'])
def _gemm(a, x, y, b, z): def _gemm(a, x, y, b, z):
if len(x.shape) == 2 and len(y.shape) == 2: if len(x.shape) == 2 and len(y.shape) == 2:
......
import os # for building the location of the .omega/omega_compiled cache directory
import sys # for adding the inline code cache to the include path
import os import os
import sys import sys
...@@ -40,8 +42,22 @@ literals_db = {} ...@@ -40,8 +42,22 @@ literals_db = {}
literals_id_db = weakref.WeakValueDictionary() literals_id_db = weakref.WeakValueDictionary()
#input floating point scalars will be cast to arrays of this type #input floating point scalars will be cast to arrays of this type
# see TRAC(#31)
default_input_scalar_dtype = 'float64' default_input_scalar_dtype = 'float64'
# BLAS Support
# These should be used by dependent modules to link blas functions.
# - used by dot(), gemm()
_blas_headers = ['"/home/bergstra/cvs/lgcm/omega/cblas.h"']
_blas_libs = ['mkl', 'm']
# WEAVE CACHE
#_home_omega = os.path.join(os.getenv('HOME'), '.omega')
_home_omega = os.path.join('/home/bergstra/.omega')
_compiled = 'omega_compiled'
_home_omega_compiled = os.path.join(_home_omega, _compiled)
sys.path.append(_home_omega) # J - is this a good idea??
def input(x): def input(x):
#NB: #NB:
# - automatically casting int to float seems wrong. # - automatically casting int to float seems wrong.
...@@ -244,7 +260,7 @@ class omega_op(gof.PythonOp): ...@@ -244,7 +260,7 @@ class omega_op(gof.PythonOp):
type_converters = converters) type_converters = converters)
instantiate.customize.add_support_code(self.c_support_code() + struct) instantiate.customize.add_support_code(self.c_support_code() + struct)
instantiate.customize.add_extra_compile_arg("-O3") instantiate.customize.add_extra_compile_arg("-O3")
instantiate.customize.add_extra_compile_arg("-ffast-math") instantiate.customize.add_extra_compile_arg("-ffast-math") #TODO: make this optional, say by passing args to c_thunk_factory?
instantiate.customize.add_extra_compile_arg("-falign-loops=4") instantiate.customize.add_extra_compile_arg("-falign-loops=4")
# instantiate.customize.add_extra_compile_arg("-mfpmath=sse") # instantiate.customize.add_extra_compile_arg("-mfpmath=sse")
for header in self.c_headers(): for header in self.c_headers():
...@@ -253,12 +269,8 @@ class omega_op(gof.PythonOp): ...@@ -253,12 +269,8 @@ class omega_op(gof.PythonOp):
instantiate.customize.add_library(lib) instantiate.customize.add_library(lib)
mod.add_function(instantiate) mod.add_function(instantiate)
module_dir = os.path.expanduser('~/.omega/compiled') mod.compile(location = _home_omega_compiled)
sys.path.insert(0, module_dir) module = __import__("%s.%s" % (_compiled, module_name), {}, {}, [module_name])
mod.compile(location = module_dir)
module = __import__("%s" % module_name) #, {}, {}, [module_name])
sys.path = sys.path[1:]
def creator(): def creator():
return module.instantiate(*[x.data for x in self.inputs + self.outputs]) return module.instantiate(*[x.data for x in self.inputs + self.outputs])
...@@ -381,8 +393,8 @@ class elemwise(omega_op): ...@@ -381,8 +393,8 @@ class elemwise(omega_op):
for oname in onames: for oname in onames:
if oname not in lonames: if oname not in lonames:
raise Exception("cannot infer a specification automatically for variable " \ raise Exception("cannot infer a specification automatically for variable " \
"%s because it is not part of the elementwise loop - "\ "%s.%s because it is not part of the elementwise loop - "\
"please override the specs method" % oname) "please override the specs method" % (self.__class__.__name__, oname))
shape, dtype = None, None shape, dtype = None, None
for iname, input in zip(inames, self.inputs): for iname, input in zip(inames, self.inputs):
if iname in linames: if iname in linames:
...@@ -1054,11 +1066,13 @@ class gemm(omega_op, inplace): ...@@ -1054,11 +1066,13 @@ class gemm(omega_op, inplace):
def alloc(self, except_list): def alloc(self, except_list):
self.outputs[0].data = self.inputs[0].data self.outputs[0].data = self.inputs[0].data
def c_headers(self): def c_headers(self):
return ["<gsl/gsl_cblas.h>"] return _blas_headers
def c_libs(self): def c_libs(self):
return ["cblas", "atlas", "g2c"] return _blas_libs
def c_impl((_zin, _a, _x, _y, _b), (_z,)): def c_impl((_zin, _a, _x, _y, _b), (_z,)):
return blas_code.gemm_xyz(str(_a), str(_b)) return blas_code.gemm_xyz(
'((_a->descr->type_num == PyArray_FLOAT) ? (float*)_a->data : (double*)_a->data)[0]',
'((_b->descr->type_num == PyArray_FLOAT) ? (float*)_b->data : (double*)_b->data)[0]')
## Transposition ## ## Transposition ##
......
...@@ -16,7 +16,6 @@ def optimizer(lst): ...@@ -16,7 +16,6 @@ def optimizer(lst):
seq_opt = gof.SeqOptimizer(begin + lst + end) seq_opt = gof.SeqOptimizer(begin + lst + end)
return gof.PythonOpt(gof.MergeOptMerge(seq_opt)) return gof.PythonOpt(gof.MergeOptMerge(seq_opt))
if 0: if 0:
optimizer_begin = gof.SeqOptimizer([opt for name, opt in [ optimizer_begin = gof.SeqOptimizer([opt for name, opt in [
['double_transpose_eliminator', pattern_opt((transpose, (transpose, 'x')), 'x')], ['double_transpose_eliminator', pattern_opt((transpose, (transpose, 'x')), 'x')],
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论