提交 d33fa296 authored 作者: Frederic Bastien's avatar Frederic Bastien

added code to use the amdlibm librairy from amd when the theano flags set lib.amdlibm

Not enabled by default as for some case it is 1.17x the original time(slowdown) but in other case it is .07x the original time(13 times faster!)
上级 0dcc62f0
...@@ -374,7 +374,7 @@ default_={ ...@@ -374,7 +374,7 @@ default_={
'ProfileMode.n_apply_to_print':15, 'ProfileMode.n_apply_to_print':15,
'ProfileMode.n_ops_to_print':20, 'ProfileMode.n_ops_to_print':20,
'tensor_opt.local_elemwise_fusion':False, 'tensor_opt.local_elemwise_fusion':False,
'scalar_basic.amdlibm':False, 'lib.amdlibm':False,
} }
......
...@@ -72,6 +72,23 @@ class Scalar(Type): ...@@ -72,6 +72,23 @@ class Scalar(Type):
def values_eq_approx(self, a, b, tolerance = 1e-4): def values_eq_approx(self, a, b, tolerance = 1e-4):
return abs(a - b) / (a+b) < tolerance return abs(a - b) / (a+b) < tolerance
def c_headers(self):
l=['<math.h>']
if utils.config.getboolean('lib.amdlibm'):
l+=['<amdlibm.h>']
return l
def c_libraries(self):
l=[]
if utils.config.getboolean('lib.amdlibm'):
l+=['amdlibm']
return l
def c_compile_args(self):
if utils.config.getboolean('lib.amdlibm'):
return ['-DREPLACE_WITH_AMDLIBM']
else: return []
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and other.dtype == self.dtype return type(self) == type(other) and other.dtype == self.dtype
...@@ -220,7 +237,7 @@ class Scalar(Type): ...@@ -220,7 +237,7 @@ class Scalar(Type):
def c_code_cache_version(self): def c_code_cache_version(self):
#return () #return ()
return (3,) #explicit T given in specialization of operator= lines. This makes it compile with open64 return (4,) #explicit T given in specialization of operator= lines. This makes it compile with open64
#2, #2,
......
...@@ -545,10 +545,13 @@ class TensorType(Type): ...@@ -545,10 +545,13 @@ class TensorType(Type):
def c_headers(self): def c_headers(self):
"""Override `CLinkerOp.c_headers` """ """Override `CLinkerOp.c_headers` """
return [] return scal.Scalar(self.dtype).c_headers()
def c_libraries(self): def c_libraries(self):
return [] return scal.Scalar(self.dtype).c_libraries()
def c_compile_args(self):
return scal.Scalar(self.dtype).c_compile_args()
def c_support_code(self): def c_support_code(self):
"""Override `CLinkerOp.c_support_code` """ """Override `CLinkerOp.c_support_code` """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论