提交 807a39ce authored 作者: nouiz's avatar nouiz

Merge pull request #1250 from lamblin/fix_blas_macos

Try to detect bug in Mac OS X parallel BLAS
......@@ -1501,37 +1501,51 @@ class GCC_compiler(object):
return cxxflags
@staticmethod
def try_compile_tmp(src_code, tmp_prefix='', flags=(), try_run=False):
    """Try to compile (and run) a test program.

    This is useful in various occasions, to check if libraries
    or compilers are behaving as expected.

    :param src_code: source text, written to a temporary ``.c`` file.
    :param tmp_prefix: prefix used for the temporary file name.
    :param flags: extra command-line arguments appended to the g++ call.
    :param try_run: if True, the src_code is assumed to be executable,
        and the compiled program will be run.

    If try_run is False, returns the compilation status.
    If try_run is True, returns a (compile_status, run_status) pair.

    An OSError raised while creating, compiling, running or removing
    the files is reported as a failed compilation, not propagated.
    """
    if not theano.config.cxx:
        # No compiler configured: report failure in the shape the caller
        # asked for, so that unpacking the pair does not blow up.
        if try_run:
            return (False, False)
        return False

    flags = list(flags)
    compilation_ok = True
    # Bound up front: the final return needs it even when the compilation
    # fails or an OSError interrupts the run.
    run_ok = False
    try:
        fd, path = tempfile.mkstemp(suffix='.c', prefix=tmp_prefix)
        # Strip the '.c' suffix to get the name of the compiled program.
        exe_path = path[:-2]
        dummy_stdin = open(os.devnull)
        try:
            os.write(fd, src_code)
            os.close(fd)
            fd = None
            # NOTE(review): wait() with both pipes attached can block if
            # the child fills a pipe; consider communicate() -- confirm
            # that call_subprocess_Popen returns a subprocess.Popen.
            proc = call_subprocess_Popen(
                ['g++', path, '-o', exe_path] + flags,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                stdin=dummy_stdin.fileno())
            proc.wait()
            if proc.returncode != 0:
                compilation_ok = False
            elif try_run:
                # Try to execute the program
                proc = call_subprocess_Popen(
                    [exe_path],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                    stdin=dummy_stdin.fileno())
                proc.wait()
                run_ok = (proc.returncode == 0)
        finally:
            dummy_stdin.close()
            # Ensure `fd` is closed before we remove the temporary file.
            try:
                if fd is not None:
                    os.close(fd)
            finally:
                try:
                    os.remove(path)
                finally:
                    # The compiled program is a temporary too; remove it
                    # whether or not it was run.
                    if os.path.exists(exe_path):
                        os.remove(exe_path)
    except OSError:
        compilation_ok = False

    if not try_run:
        return compilation_ok
    return (compilation_ok, run_ok)
@staticmethod
def try_flags(flag_list):
    """
    Check whether a dummy program compiles with the given flags.

    :param flag_list: list of extra command-line arguments for the
        compiler.

    Returns True if compilation was successful, False if there
    were errors or if theano.config.cxx is empty.
    """
    if theano.config.cxx:
        # Smallest program that does nothing; only the flags are under test.
        dummy_program = ("\n"
                         "int main(int argc, char** argv)\n"
                         "{\n"
                         "return 0;\n"
                         "}\n")
        return GCC_compiler.try_compile_tmp(dummy_program,
                                            tmp_prefix='try_flags_',
                                            flags=flag_list,
                                            try_run=False)
    return False
@staticmethod
def compile_str(module_name, src_code, location=None,
......
......@@ -12,18 +12,15 @@ __contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__ = "restructuredtext en"
import logging
import os
import subprocess
import tempfile
import warnings
import theano
from theano import config
from theano.misc.windows import call_subprocess_Popen
import theano.gof.cc
from theano.gof import graph
from theano.gof import utils
from theano.gof.cmodule import GCC_compiler
from theano.gof.fg import FunctionGraph
......@@ -770,42 +767,22 @@ class OpenMPOp(Op):
@staticmethod
def test_gxx_support():
    """
    Check whether a program including <omp.h> compiles with -fopenmp.

    The compilation is delegated to GCC_compiler.try_compile_tmp; the
    program is only compiled, not run.

    Returns the compilation status reported by try_compile_tmp.
    """
    code = """
#include <omp.h>
int main( int argc, const char* argv[] )
{
int res[10];
for(int i=0; i < 10; i++){
res[i] = i;
}
}
"""
    # try_compile_tmp names its first parameter `src_code`; passing the
    # program as `code=` would raise TypeError.
    default_openmp = GCC_compiler.try_compile_tmp(
        src_code=code,
        tmp_prefix='test_omp_',
        flags=['-fopenmp'],
        try_run=False)
    return default_openmp
def update_self_openmp(self):
......
......@@ -689,7 +689,7 @@ class UsmmCscDense(gof.Op):
return rval
def c_code_cache_version(self):
    # The second entry follows the BLAS header variant reported by
    # blas.blas_header_version(), so the tuple changes when it does.
    return (1, blas.blas_header_version())
usmm_csc_dense = UsmmCscDense(inplace=False)
usmm_csc_dense_inplace = UsmmCscDense(inplace=True)
......@@ -1519,7 +1519,7 @@ class SamplingDotCSR(gof.Op):
])
def c_code_cache_version(self):
    # This op's support code is blas.blas_header_text(), so the version
    # tuple tracks the header variant via blas.blas_header_version().
    return (2, blas.blas_header_version())
def c_support_code(self):
return blas.blas_header_text()
......
......@@ -145,6 +145,7 @@ from theano.gof.python25 import all, any
import theano.scalar
from theano.tensor import basic as T
from theano.tensor.blas_headers import blas_header_text
from theano.tensor.blas_headers import blas_header_version
from theano.tensor.opt import local_dimshuffle_lift
_logger = logging.getLogger('theano.tensor.blas')
......@@ -768,7 +769,7 @@ class GemmRelated(Op):
self.end_switch_typenum), '')
def build_gemm_version(self):
    # The second entry follows the BLAS header variant reported by
    # blas_header_version(), so the tuple changes when it does.
    return (12, blas_header_version())
class Gemm(GemmRelated):
......
from theano import config
from theano.tensor.blas import ldflags, blas_header_text
from theano.tensor.blas import ldflags, blas_header_text, blas_header_version
from theano.tensor.blas import blas_optdb, optdb, local_optimizer, EquilibriumOptimizer
from theano.tensor.blas import Ger, ger, ger_destructive
from theano.tensor.blas import Gemv, gemv_inplace, gemv_no_inplace
......@@ -251,7 +251,7 @@ class CGer(BaseBLAS, Ger):
return code
def c_code_cache_version(self):
    # The second entry follows the BLAS header variant reported by
    # blas_header_version(), so the tuple changes when it does.
    return (8, blas_header_version())
@local_optimizer([ger, ger_destructive])
......@@ -578,7 +578,7 @@ class CGemv(BaseBLAS, Gemv):
return code
def c_code_cache_version(self):
    # The second entry follows the BLAS header variant reported by
    # blas_header_version(), so the tuple changes when it does.
    return (10, blas_header_version())
@local_optimizer([gemv_inplace, gemv_no_inplace])
......
""" Header text for the C and Fortran BLAS interfaces.
There is no standard name or location for this header, so we just insert it ourselves into the C code
There is no standard name or location for this header, so we just insert it
ourselves into the C code
"""
import logging
import textwrap
import sys
from theano import config
from theano.gof.cmodule import GCC_compiler
_logger = logging.getLogger('theano.tensor.blas')
#_logger.setLevel(logging.INFO)
def detect_macos_sdot_bug():
    """
    Try to detect a bug in the default BLAS in MacOS.

    The problem in Theano has been reported in gh-1240,
    the underlying bug has been confirmed in
    http://www.macresearch.org/lapackblas-fortran-106#comment-17227.

    This function tries to compile code triggering that bug,
    and, if necessary, an attempted fix.

    Three attributes of this function will be set:
        - detect_macos_sdot_bug.tested will be set to True
          when this function is called.
        - detect_macos_sdot_bug.present will be set to True if the bug is
          detected. Its value is returned by the function
        - detect_macos_sdot_bug.fix_works will be set to True if the fix was
          attempted, and succeeded.

    The test programs are only built on the first call; later calls
    return the stored value of detect_macos_sdot_bug.present.
    """
    _logger.debug('Starting detection of bug in Mac OS BLAS sdot_ routine')
    if detect_macos_sdot_bug.tested:
        return detect_macos_sdot_bug.present

    if sys.platform != 'darwin' or not config.blas.ldflags:
        _logger.info('Not Mac OS, no sdot_ bug')
        detect_macos_sdot_bug.tested = True
        return False

    flags = config.blas.ldflags.split()
    # Library directories should also be added as rpath,
    # so that they can be loaded even if the environment
    # variable LD_LIBRARY_PATH does not contain them.
    # The extra flags are computed before the list is extended, so the
    # list is not modified while it is traversed; a bare '-L' carries no
    # directory and is skipped instead of producing an empty rpath.
    rpath_flags = ['-Wl,-rpath,' + f[2:]
                   for f in flags
                   if f.startswith('-L') and len(f) > 2]
    flags.extend(rpath_flags)

    # This code will return -1 if the dot product did not return
    # the right value (30.).
    test_code = textwrap.dedent("""\
        extern "C" float sdot_(int*, float*, int*, float*, int*);
        int main(int argc, char** argv)
        {
        int Nx = 5;
        int Sx = 1;
        float x[5] = {0, 1, 2, 3, 4};
        float r = sdot_(&Nx, x, &Sx, x, &Sx);
        if ((r - 30.f) > 1e-6 || (r - 30.f) < -1e-6)
        {
        return -1;
        }
        return 0;
        }
        """)

    _logger.debug('Trying to compile and run test case.')
    compilation_ok, run_ok = GCC_compiler.try_compile_tmp(
        test_code,
        tmp_prefix='detect_macos_sdot_bug_',
        flags=flags,
        try_run=True)
    detect_macos_sdot_bug.tested = True

    # If compilation failed, we consider there is a bug,
    # and the fix does not work
    if not compilation_ok:
        _logger.info('Could not compile test case for sdot_.')
        detect_macos_sdot_bug.present = True
        return True

    if run_ok:
        _logger.info('The sdot_ bug is not present on this system.')
        detect_macos_sdot_bug.present = False
        return False

    # Else, the bug is detected.
    _logger.info('The sdot_ bug is present on this system.')
    detect_macos_sdot_bug.present = True

    # Then, try a simple fix: define sdot_ as a wrapper around cblas_sdot
    # and run the same check against it.
    test_fix_code = textwrap.dedent("""\
        extern "C" float sdot_(int*, float*, int*, float*, int*);
        extern "C" float cblas_sdot(int, float*, int, float*, int);
        float sdot_(int* Nx, float* x, int* Sx, float* y, int* Sy)
        {
        return cblas_sdot(*Nx, x, *Sx, y, *Sy);
        }
        int main(int argc, char** argv)
        {
        int Nx = 5;
        int Sx = 1;
        float x[5] = {0, 1, 2, 3, 4};
        float r = sdot_(&Nx, x, &Sx, x, &Sx);
        if ((r - 30.f) > 1e-6 || (r - 30.f) < -1e-6)
        {
        return -1;
        }
        return 0;
        }
        """)

    _logger.debug('Trying to compile and run tentative workaround.')
    compilation_fix_ok, run_fix_ok = GCC_compiler.try_compile_tmp(
        test_fix_code,
        tmp_prefix='detect_macos_sdot_bug_testfix_',
        flags=flags,
        try_run=True)

    _logger.info("Status of tentative fix -- compilation OK: %s, works: %s",
                 compilation_fix_ok, run_fix_ok)
    detect_macos_sdot_bug.fix_works = run_fix_ok

    return detect_macos_sdot_bug.present

# State stored on the function object; filled in by the first call.
detect_macos_sdot_bug.tested = False
detect_macos_sdot_bug.present = False
detect_macos_sdot_bug.fix_works = False
def cblas_header_text():
"""C header for the cblas interface."""
......@@ -583,9 +713,10 @@ def cblas_header_text():
__END_DECLS
"""
def blas_header_text():
"""C header for the fortran blas interface"""
return """
header = """
extern "C"
{
......@@ -789,6 +920,48 @@ def blas_header_text():
}
"""
if detect_macos_sdot_bug():
if detect_macos_sdot_bug.fix_works:
header += textwrap.dedent("""\
extern "C" float cblas_sdot(int, float*, int, float*, int);
float sdot_(int* Nx, float* x, int* Sx, float* y, int* Sy)
{
return cblas_sdot(*Nx, x, *Sx, y, *Sy);
}
""")
else:
# Make sure the buggy version of sdot_ is never used
header += textwrap.dedent("""\
float sdot_(int* Nx, float* x, int* Sx, float* y, int* Sy)
{
fprintf(stderr,
"FATAL: The implementation of BLAS SDOT "
"routine in your system has a bug that "
"makes it return wrong results.\\n"
"Please contact theano-dev@groups.google.com.\\n"
"You can work around this bug by using a "
"different BLAS library, or disabling BLAS\\n");
assert(0);
}
""")
return header
def blas_header_version():
    """
    Return a version tuple identifying the variant of the BLAS header.

    The tuple is (1,) when detect_macos_sdot_bug() reports no bug,
    (1, 1) when the bug is present and the workaround works, and
    (1, 2) when the bug is present and the workaround does not work.
    """
    # Version for the base header
    base_version = (1,)
    if not detect_macos_sdot_bug():
        return base_version
    if detect_macos_sdot_bug.fix_works:
        # Version with fix
        return base_version + (1,)
    # Version with error
    return base_version + (2,)
def ____gemm_code(check_ab, a_init, b_init):
mod = '%'
return """
......@@ -929,31 +1102,3 @@ def ____gemm_code(check_ab, a_init, b_init):
/* v 1 */
""" % locals()
# currently unused, preferring the fallback method (throwing
# NotImplementedError) for when gemm won't work.
_templated_memaligned_gemm = """
template <typename Ta, typename Tx, typename Ty, typename Tb, typename Tz>
int general_gemm(int zM, int zN, int xN,.
Ta a,
Tx * x, int xm, int xn,
Tx * y, int ym, int yn,
Tb b,
Tz * z, int zm, int zn)
{
for (int i = 0; i < zM; ++i)
{
for (int j = 0; j < zN; ++j)
{
Tz zij = 0.0;
for (int k = 0; k < xN; ++k)
{
zij += x[i*xm+k*xn] * y[k*ym+j*yn];
}
z[i * zm + j * zn] *= b;
z[i * zm + j * zn] += a * zij;
}
}
}
"""
......@@ -51,7 +51,7 @@ class Conv3D(theano.Op):
return "Conv3D"
def c_code_cache_version(self):
    # blas_header_text carries no `version` attribute in the code in view;
    # the header variant is reported by blas_header_version(). It is
    # imported here so the method does not rely on a module-level import.
    from theano.tensor.blas_headers import blas_header_version
    return (3, blas_header_version())
def make_node(self, V, W, b, d):
......
......@@ -965,7 +965,7 @@ class ConvOp(OpenMPOp):
return ['<numpy/noprefix.h>', '<iostream>', '<sstream>']
def c_code_cache_version(self):
    # The tuple carries this op's openmp setting and the BLAS header
    # variant reported by blas.blas_header_version().
    return (10, self.openmp, blas.blas_header_version())
def c_support_code(self):
return """
......
......@@ -208,8 +208,12 @@ class TestCGemv(TestCase, TestOptimizationMixin):
def test_gemv1(self):
    # Run t_gemv1 on each shape in turn, including shapes with a
    # dimension of size 1 or 0.
    shapes = ((3, 2), (1, 2), (0, 2), (3, 1),
              (3, 0), (1, 0), (0, 1), (0, 0))
    for shape in shapes:
        self.t_gemv1(shape)
def test_gemv_dimensions(self, dtype='float32'):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论