提交 4783fe0b authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Try to detect and work around bug in MacOS parallel blas

上级 d9a522a5
...@@ -689,7 +689,7 @@ class UsmmCscDense(gof.Op): ...@@ -689,7 +689,7 @@ class UsmmCscDense(gof.Op):
return rval return rval
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1, blas.blas_header_version())
usmm_csc_dense = UsmmCscDense(inplace=False) usmm_csc_dense = UsmmCscDense(inplace=False)
usmm_csc_dense_inplace = UsmmCscDense(inplace=True) usmm_csc_dense_inplace = UsmmCscDense(inplace=True)
...@@ -1519,7 +1519,7 @@ class SamplingDotCSR(gof.Op): ...@@ -1519,7 +1519,7 @@ class SamplingDotCSR(gof.Op):
]) ])
def c_code_cache_version(self): def c_code_cache_version(self):
return (2, ) return (2, blas.blas_header_version())
def c_support_code(self): def c_support_code(self):
return blas.blas_header_text() return blas.blas_header_text()
......
...@@ -145,6 +145,7 @@ from theano.gof.python25 import all, any ...@@ -145,6 +145,7 @@ from theano.gof.python25 import all, any
import theano.scalar import theano.scalar
from theano.tensor import basic as T from theano.tensor import basic as T
from theano.tensor.blas_headers import blas_header_text from theano.tensor.blas_headers import blas_header_text
from theano.tensor.blas_headers import blas_header_version
from theano.tensor.opt import local_dimshuffle_lift from theano.tensor.opt import local_dimshuffle_lift
_logger = logging.getLogger('theano.tensor.blas') _logger = logging.getLogger('theano.tensor.blas')
...@@ -768,7 +769,7 @@ class GemmRelated(Op): ...@@ -768,7 +769,7 @@ class GemmRelated(Op):
self.end_switch_typenum), '') self.end_switch_typenum), '')
def build_gemm_version(self): def build_gemm_version(self):
return (12,) return (12, blas_header_version())
class Gemm(GemmRelated): class Gemm(GemmRelated):
......
from theano import config from theano import config
from theano.tensor.blas import ldflags, blas_header_text from theano.tensor.blas import ldflags, blas_header_text, blas_header_version
from theano.tensor.blas import blas_optdb, optdb, local_optimizer, EquilibriumOptimizer from theano.tensor.blas import blas_optdb, optdb, local_optimizer, EquilibriumOptimizer
from theano.tensor.blas import Ger, ger, ger_destructive from theano.tensor.blas import Ger, ger, ger_destructive
from theano.tensor.blas import Gemv, gemv_inplace, gemv_no_inplace from theano.tensor.blas import Gemv, gemv_inplace, gemv_no_inplace
...@@ -251,7 +251,7 @@ class CGer(BaseBLAS, Ger): ...@@ -251,7 +251,7 @@ class CGer(BaseBLAS, Ger):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (8,) return (8, blas_header_version())
@local_optimizer([ger, ger_destructive]) @local_optimizer([ger, ger_destructive])
...@@ -578,7 +578,7 @@ class CGemv(BaseBLAS, Gemv): ...@@ -578,7 +578,7 @@ class CGemv(BaseBLAS, Gemv):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (10,) return (10, blas_header_version())
@local_optimizer([gemv_inplace, gemv_no_inplace]) @local_optimizer([gemv_inplace, gemv_no_inplace])
......
...@@ -2,6 +2,199 @@ ...@@ -2,6 +2,199 @@
There is no standard name or location for this header, so we just insert it ourselves into the C code There is no standard name or location for this header, so we just insert it ourselves into the C code
""" """
import os
import tempfile
import textwrap
import subprocess
import sys
from theano import config
from theano.misc.windows import call_subprocess_Popen
def detect_macos_sdot_bug():
"""
Try to detect a bug in the default BLAS in MacOS.
The problem in Theano has been reported in gh-1240,
the underlying bug has been confirmed in
http://www.macresearch.org/lapackblas-fortran-106#comment-17227.
This function tries to compile code triggering that bug,
and, if necessary, an attempted fix.
Three attributes of this function will be set:
- detect_macos_sdot_bug.tested will be set to True
when this function is called.
- detect_macos_sdot_bug.present will be set to True if the bug is
detected. Its value is returned by the function
- detect_macos_sdot_bug.fix_works will be set to True if the fix was
attempted, and succeeded.
"""
if detect_macos_sdot_bug.tested:
return detect_macos_sdot_bug.present
if sys.platform != 'darwin' or not config.blas.ldflags:
detect_macos_sdot_bug.tested = True
return False
# This code will return -1 if the dot product did not return
# the right value (30.).
flags = config.blas.ldflags.split()
for f in flags:
# Library directories should also be added as rpath,
# so that they can be loaded even if the environment
# variable LD_LIBRARY_PATH does not contain them
if f.startswith('-L'):
flags.append('-Wl,-rpath,' + f[2:])
test_code = textwrap.dedent("""\
extern "C" float sdot_(int*, float*, int*, float*, int*);
int main(int argc, char** argv)
{
int Nx = 5;
int Sx = 1;
float x[5] = {0, 1, 2, 3, 4};
float r = sdot_(&Nx, x, &Sx, x, &Sx);
if ((r - 30.f) > 1e-6 || (r - 30.f) < -1e-6)
{
return -1;
}
return 0;
}
""")
try:
fd, path = tempfile.mkstemp(suffix='.c',
prefix='detect_macos_sdot_bug_')
exe_path = path[:-2]
dummy_stdin = open(os.devnull)
compilation_failed = False
try:
os.write(fd, test_code)
os.close(fd)
fd = None
proc = call_subprocess_Popen(
['g++', path, '-o', exe_path] + flags,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdin=dummy_stdin.fileno())
proc.wait()
if proc.returncode != 0:
# We were not able to compile
compilation_failed = True
else:
# Try to execute the program
try:
proc = call_subprocess_Popen([exe_path],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdin=dummy_stdin.fileno())
proc.wait()
detect_macos_sdot_bug.tested = True
if proc.returncode != 0:
# The bug is present
detect_macos_sdot_bug.present = True
else:
detect_macos_sdot_bug.present = False
finally:
os.remove(exe_path)
finally:
del dummy_stdin
try:
if fd is not None:
os.close(fd)
finally:
os.remove(path)
except OSError, e:
compilation_failed = True
# If compilation failed, we consider there is a bug,
# and the fix does not work
if compilation_failed:
detect_macos_sdot_bug.tested = True
detect_macos_sdot_bug.present = True
return True
if not detect_macos_sdot_bug.present:
return False
# Else, try a simple fix
test_fix_code = textwrap.dedent("""\
extern "C" float sdot_(int*, float*, int*, float*, int*);
float sdot_(int* Nx, float* x, int* Sx, float* y, int* Sy)
{
return cblas_sdot(*Nx, x, *Sx, y, *Sy);
}
int main(int argc, char** argv)
{
int Nx = 5;
int Sx = 1;
float x[5] = {0, 1, 2, 3, 4};
float r = sdot_(&Nx, x, &Sx, x, &Sx);
if (fabs(r - 30.f) > 1e-6)
{
return -1;
}
return 0;
}
""")
try:
fd, path = tempfile.mkstemp(suffix='.c',
prefix='detect_macos_sdot_bug_testfix_')
exe_path = path[:-2]
dummy_stdin = open(os.devnull)
compilation_failed = False
try:
os.write(fd, test_fix_code)
os.close(fd)
fd = None
proc = call_subprocess_Popen(
['g++', path, '-o', exe_path] + flags,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdin=dummy_stdin.fileno())
proc.wait()
if proc.returncode != 0:
# We were not able to compile
compilation_failed = True
else:
# Try to execute the program
try:
proc = call_subprocess_Popen([exe_path],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdin=dummy_stdin.fileno())
proc.wait()
if proc.returncode == 0:
# The fix is working
detect_macos_sdot_bug.fix_works = True
finally:
os.remove(exe_path)
finally:
del dummy_stdin
try:
if fd is not None:
os.close(fd)
finally:
os.remove(path)
except OSError, e:
compilation_failed = True
return detect_macos_sdot_bug.present
detect_macos_sdot_bug.tested = False
detect_macos_sdot_bug.present = False
detect_macos_sdot_bug.fix_works = False
def cblas_header_text(): def cblas_header_text():
"""C header for the cblas interface.""" """C header for the cblas interface."""
...@@ -585,7 +778,7 @@ def cblas_header_text(): ...@@ -585,7 +778,7 @@ def cblas_header_text():
def blas_header_text(): def blas_header_text():
"""C header for the fortran blas interface""" """C header for the fortran blas interface"""
return """ header = """
extern "C" extern "C"
{ {
...@@ -789,6 +982,47 @@ def blas_header_text(): ...@@ -789,6 +982,47 @@ def blas_header_text():
} }
""" """
if detect_macos_sdot_bug():
if detect_macos_sdot_bug.fix_works:
header += textwrap.dedent("""\
float sdot_(int* Nx, float* x, int* Sx, float* y, int* Sy)
{
return cblas_sdot(*Nx, x, *Sx, y, *Sy);
}
""")
else:
# Make sure the buggy version of sdot_ is never used
header += textwrap.dedent("""\
float sdot_(int* Nx, float* x, int* Sx, float* y, int* Sy)
{
fprintf(stderr,
"FATAL: The implementation of BLAS SDOT "
"routine in your system has a bug that "
"makes it return wrong results.\\n"
"Please contact theano-dev@groups.google.com.\\n"
"You can work around this bug by using a "
"different BLAS library, or disabling BLAS\\n");
assert(0);
}
""")
return header
def blas_header_version():
# Version for the base header
version = (1,)
if detect_macos_sdot_bug():
if detect_macos_sdot_bug.fix_works:
# Version with fix
version += (1,)
else:
# Version with error
version += (2,)
return version
def ____gemm_code(check_ab, a_init, b_init): def ____gemm_code(check_ab, a_init, b_init):
mod = '%' mod = '%'
return """ return """
......
...@@ -51,7 +51,7 @@ class Conv3D(theano.Op): ...@@ -51,7 +51,7 @@ class Conv3D(theano.Op):
return "Conv3D" return "Conv3D"
def c_code_cache_version(self): def c_code_cache_version(self):
return (3,) return (3, blas_header_text.version)
def make_node(self, V, W, b, d): def make_node(self, V, W, b, d):
......
...@@ -965,7 +965,7 @@ class ConvOp(OpenMPOp): ...@@ -965,7 +965,7 @@ class ConvOp(OpenMPOp):
return ['<numpy/noprefix.h>', '<iostream>', '<sstream>'] return ['<numpy/noprefix.h>', '<iostream>', '<sstream>']
def c_code_cache_version(self): def c_code_cache_version(self):
return (10, self.openmp) return (10, self.openmp, blas_header_version())
def c_support_code(self): def c_support_code(self):
return """ return """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论