Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
807a39ce
提交
807a39ce
authored
2月 27, 2013
作者:
nouiz
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1250 from lamblin/fix_blas_macos
Try to detect bug in Mac OS X parallel BLAS
上级
a788e179
f1b10057
显示空白字符变更
内嵌
并排
正在显示
9 个修改的文件
包含
255 行增加
和
89 行删除
+255
-89
cmodule.py
theano/gof/cmodule.py
+58
-19
op.py
theano/gof/op.py
+9
-32
opt.py
theano/sparse/opt.py
+2
-2
blas.py
theano/tensor/blas.py
+2
-1
blas_c.py
theano/tensor/blas_c.py
+3
-3
blas_headers.py
theano/tensor/blas_headers.py
+175
-30
Conv3D.py
theano/tensor/nnet/Conv3D.py
+1
-1
conv.py
theano/tensor/nnet/conv.py
+1
-1
test_blas_c.py
theano/tensor/tests/test_blas_c.py
+4
-0
没有找到文件。
theano/gof/cmodule.py
浏览文件 @
807a39ce
...
...
@@ -1501,37 +1501,51 @@ class GCC_compiler(object):
return
cxxflags
@staticmethod
def
try_flags
(
flag_list
):
'''
Try to compile a dummy file with these flags.
def
try_compile_tmp
(
src_code
,
tmp_prefix
=
''
,
flags
=
(),
try_run
=
False
):
"""Try to compile (and run) a test program.
Returns True if compilation was successful, False if there
were errors.
'''
This is useful in various occasions, to check if libraries
or compilers are behaving as expected.
If try_run is True, the src_code is assumed to be executable,
and will be run.
If try_run is False, returns the compilation status.
If try_run is True, returns a (compile_status, run_status) pair.
"""
if
not
theano
.
config
.
cxx
:
return
False
rval
=
True
flags
=
list
(
flags
)
compilation_ok
=
True
try
:
code
=
"""
int main(int argc, char** argv)
{
return 0;
}
"""
fd
,
path
=
tempfile
.
mkstemp
(
suffix
=
'.c'
,
prefix
=
'try_flags_'
)
fd
,
path
=
tempfile
.
mkstemp
(
suffix
=
'.c'
,
prefix
=
tmp_prefix
)
exe_path
=
path
[:
-
2
]
dummy_stdin
=
open
(
os
.
devnull
)
try
:
os
.
write
(
fd
,
code
)
os
.
write
(
fd
,
src_
code
)
os
.
close
(
fd
)
fd
=
None
proc
=
call_subprocess_Popen
([
'g++'
,
path
]
+
flag_list
,
proc
=
call_subprocess_Popen
(
[
'g++'
,
path
,
'-o'
,
exe_path
]
+
flags
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
stdin
=
dummy_stdin
.
fileno
())
proc
.
wait
()
if
proc
.
returncode
!=
0
:
rval
=
False
compilation_ok
=
False
elif
try_run
:
# Try to execute the program
run_ok
=
False
try
:
proc
=
call_subprocess_Popen
([
exe_path
],
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
stdin
=
dummy_stdin
.
fileno
())
proc
.
wait
()
run_ok
=
(
proc
.
returncode
==
0
)
finally
:
os
.
remove
(
exe_path
)
finally
:
del
dummy_stdin
try
:
...
...
@@ -1539,9 +1553,34 @@ class GCC_compiler(object):
os
.
close
(
fd
)
finally
:
os
.
remove
(
path
)
except
OSError
,
e
:
rval
=
False
return
rval
compilation_ok
=
False
if
not
try_run
:
return
compilation_ok
else
:
return
(
compilation_ok
,
run_ok
)
@staticmethod
def
try_flags
(
flag_list
):
'''
Try to compile a dummy file with these flags.
Returns True if compilation was successful, False if there
were errors.
'''
if
not
theano
.
config
.
cxx
:
return
False
code
=
"""
int main(int argc, char** argv)
{
return 0;
}
"""
return
GCC_compiler
.
try_compile_tmp
(
code
,
tmp_prefix
=
'try_flags_'
,
flags
=
flag_list
,
try_run
=
False
)
@staticmethod
def
compile_str
(
module_name
,
src_code
,
location
=
None
,
...
...
theano/gof/op.py
浏览文件 @
807a39ce
...
...
@@ -12,18 +12,15 @@ __contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__
=
"restructuredtext en"
import
logging
import
os
import
subprocess
import
tempfile
import
warnings
import
theano
from
theano
import
config
from
theano.misc.windows
import
call_subprocess_Popen
import
theano.gof.cc
from
theano.gof
import
graph
from
theano.gof
import
utils
from
theano.gof.cmodule
import
GCC_compiler
from
theano.gof.fg
import
FunctionGraph
...
...
@@ -770,42 +767,22 @@ class OpenMPOp(Op):
@staticmethod
def
test_gxx_support
():
default_openmp
=
True
try
:
code
=
"""
#include <omp.h>
int main( int argc, const char* argv[] )
{
int main( int argc, const char* argv[] )
{
int res[10];
for(int i=0; i < 10; i++){
res[i] = i;
}
}
}
"""
fd
,
path
=
tempfile
.
mkstemp
(
suffix
=
'.c'
,
prefix
=
'test_omp_'
)
dummy_stdin
=
open
(
os
.
devnull
)
try
:
os
.
write
(
fd
,
code
)
os
.
close
(
fd
)
fd
=
None
proc
=
call_subprocess_Popen
([
'g++'
,
'-fopenmp'
,
path
],
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
stdin
=
dummy_stdin
.
fileno
())
proc
.
wait
()
if
proc
.
returncode
!=
0
:
default_openmp
=
False
finally
:
del
dummy_stdin
# Ensure `fd` is closed before we remove the temporary file.
try
:
if
fd
is
not
None
:
os
.
close
(
fd
)
finally
:
os
.
remove
(
path
)
except
OSError
,
e
:
return
False
default_openmp
=
GCC_compiler
.
try_compile_tmp
(
code
=
code
,
tmp_prefix
=
'test_omp_'
,
flags
=
[
'-fopenmp'
],
try_run
=
False
)
return
default_openmp
def
update_self_openmp
(
self
):
...
...
theano/sparse/opt.py
浏览文件 @
807a39ce
...
...
@@ -689,7 +689,7 @@ class UsmmCscDense(gof.Op):
return
rval
def
c_code_cache_version
(
self
):
return
(
1
,)
return
(
1
,
blas
.
blas_header_version
()
)
usmm_csc_dense
=
UsmmCscDense
(
inplace
=
False
)
usmm_csc_dense_inplace
=
UsmmCscDense
(
inplace
=
True
)
...
...
@@ -1519,7 +1519,7 @@ class SamplingDotCSR(gof.Op):
])
def
c_code_cache_version
(
self
):
return
(
2
,
)
return
(
2
,
blas
.
blas_header_version
()
)
def
c_support_code
(
self
):
return
blas
.
blas_header_text
()
...
...
theano/tensor/blas.py
浏览文件 @
807a39ce
...
...
@@ -145,6 +145,7 @@ from theano.gof.python25 import all, any
import
theano.scalar
from
theano.tensor
import
basic
as
T
from
theano.tensor.blas_headers
import
blas_header_text
from
theano.tensor.blas_headers
import
blas_header_version
from
theano.tensor.opt
import
local_dimshuffle_lift
_logger
=
logging
.
getLogger
(
'theano.tensor.blas'
)
...
...
@@ -768,7 +769,7 @@ class GemmRelated(Op):
self
.
end_switch_typenum
),
''
)
def
build_gemm_version
(
self
):
return
(
12
,)
return
(
12
,
blas_header_version
()
)
class
Gemm
(
GemmRelated
):
...
...
theano/tensor/blas_c.py
浏览文件 @
807a39ce
from
theano
import
config
from
theano.tensor.blas
import
ldflags
,
blas_header_text
from
theano.tensor.blas
import
ldflags
,
blas_header_text
,
blas_header_version
from
theano.tensor.blas
import
blas_optdb
,
optdb
,
local_optimizer
,
EquilibriumOptimizer
from
theano.tensor.blas
import
Ger
,
ger
,
ger_destructive
from
theano.tensor.blas
import
Gemv
,
gemv_inplace
,
gemv_no_inplace
...
...
@@ -251,7 +251,7 @@ class CGer(BaseBLAS, Ger):
return
code
def
c_code_cache_version
(
self
):
return
(
8
,)
return
(
8
,
blas_header_version
()
)
@local_optimizer
([
ger
,
ger_destructive
])
...
...
@@ -578,7 +578,7 @@ class CGemv(BaseBLAS, Gemv):
return
code
def
c_code_cache_version
(
self
):
return
(
10
,)
return
(
10
,
blas_header_version
()
)
@local_optimizer
([
gemv_inplace
,
gemv_no_inplace
])
...
...
theano/tensor/blas_headers.py
浏览文件 @
807a39ce
""" Header text for the C and Fortran BLAS interfaces.
There is no standard name or location for this header, so we just insert it ourselves into the C code
There is no standard name or location for this header, so we just insert it
ourselves into the C code
"""
import
logging
import
textwrap
import
sys
from
theano
import
config
from
theano.gof.cmodule
import
GCC_compiler
_logger
=
logging
.
getLogger
(
'theano.tensor.blas'
)
#_logger.setLevel(logging.INFO)
def
detect_macos_sdot_bug
():
"""
Try to detect a bug in the default BLAS in MacOS.
The problem in Theano has been reported in gh-1240,
the underlying bug has been confirmed in
http://www.macresearch.org/lapackblas-fortran-106#comment-17227.
This function tries to compile code triggering that bug,
and, if necessary, an attempted fix.
Three attributes of this function will be set:
- detect_macos_sdot_bug.tested will be set to True
when this function is called.
- detect_macos_sdot_bug.present will be set to True if the bug is
detected. Its value is returned by the function
- detect_macos_sdot_bug.fix_works will be set to True if the fix was
attempted, and succeeded.
"""
_logger
.
debug
(
'Starting detection of bug in Mac OS BLAS sdot_ routine'
)
if
detect_macos_sdot_bug
.
tested
:
return
detect_macos_sdot_bug
.
present
if
sys
.
platform
!=
'darwin'
or
not
config
.
blas
.
ldflags
:
_logger
.
info
(
'Not Mac OS, no sdot_ bug'
)
detect_macos_sdot_bug
.
tested
=
True
return
False
# This code will return -1 if the dot product did not return
# the right value (30.).
flags
=
config
.
blas
.
ldflags
.
split
()
for
f
in
flags
:
# Library directories should also be added as rpath,
# so that they can be loaded even if the environment
# variable LD_LIBRARY_PATH does not contain them
if
f
.
startswith
(
'-L'
):
flags
.
append
(
'-Wl,-rpath,'
+
f
[
2
:])
test_code
=
textwrap
.
dedent
(
"""
\
extern "C" float sdot_(int*, float*, int*, float*, int*);
int main(int argc, char** argv)
{
int Nx = 5;
int Sx = 1;
float x[5] = {0, 1, 2, 3, 4};
float r = sdot_(&Nx, x, &Sx, x, &Sx);
if ((r - 30.f) > 1e-6 || (r - 30.f) < -1e-6)
{
return -1;
}
return 0;
}
"""
)
_logger
.
debug
(
'Trying to compile and run test case.'
)
compilation_ok
,
run_ok
=
GCC_compiler
.
try_compile_tmp
(
test_code
,
tmp_prefix
=
'detect_macos_sdot_bug_'
,
flags
=
flags
,
try_run
=
True
)
detect_macos_sdot_bug
.
tested
=
True
# If compilation failed, we consider there is a bug,
# and the fix does not work
if
not
compilation_ok
:
_logger
.
info
(
'Could not compile test case for sdot_.'
)
detect_macos_sdot_bug
.
present
=
True
return
True
if
run_ok
:
_logger
.
info
(
'The sdot_ bug is not present on this system.'
)
detect_macos_sdot_bug
.
present
=
False
return
False
# Else, the bug is detected.
_logger
.
info
(
'The sdot_ bug is present on this system.'
)
detect_macos_sdot_bug
.
present
=
True
# Then, try a simple fix
test_fix_code
=
textwrap
.
dedent
(
"""
\
extern "C" float sdot_(int*, float*, int*, float*, int*);
extern "C" float cblas_sdot(int, float*, int, float*, int);
float sdot_(int* Nx, float* x, int* Sx, float* y, int* Sy)
{
return cblas_sdot(*Nx, x, *Sx, y, *Sy);
}
int main(int argc, char** argv)
{
int Nx = 5;
int Sx = 1;
float x[5] = {0, 1, 2, 3, 4};
float r = sdot_(&Nx, x, &Sx, x, &Sx);
if ((r - 30.f) > 1e-6 || (r - 30.f) < -1e-6)
{
return -1;
}
return 0;
}
"""
)
_logger
.
debug
(
'Trying to compile and run tentative workaround.'
)
compilation_fix_ok
,
run_fix_ok
=
GCC_compiler
.
try_compile_tmp
(
test_fix_code
,
tmp_prefix
=
'detect_macos_sdot_bug_testfix_'
,
flags
=
flags
,
try_run
=
True
)
_logger
.
info
(
"Status of tentative fix -- compilation OK:
%
s, works:
%
s"
,
compilation_fix_ok
,
run_fix_ok
)
detect_macos_sdot_bug
.
fix_works
=
run_fix_ok
return
detect_macos_sdot_bug
.
present
detect_macos_sdot_bug
.
tested
=
False
detect_macos_sdot_bug
.
present
=
False
detect_macos_sdot_bug
.
fix_works
=
False
def
cblas_header_text
():
"""C header for the cblas interface."""
...
...
@@ -583,9 +713,10 @@ def cblas_header_text():
__END_DECLS
"""
def
blas_header_text
():
"""C header for the fortran blas interface"""
return
"""
header
=
"""
extern "C"
{
...
...
@@ -789,6 +920,48 @@ def blas_header_text():
}
"""
if
detect_macos_sdot_bug
():
if
detect_macos_sdot_bug
.
fix_works
:
header
+=
textwrap
.
dedent
(
"""
\
extern "C" float cblas_sdot(int, float*, int, float*, int);
float sdot_(int* Nx, float* x, int* Sx, float* y, int* Sy)
{
return cblas_sdot(*Nx, x, *Sx, y, *Sy);
}
"""
)
else
:
# Make sure the buggy version of sdot_ is never used
header
+=
textwrap
.
dedent
(
"""
\
float sdot_(int* Nx, float* x, int* Sx, float* y, int* Sy)
{
fprintf(stderr,
"FATAL: The implementation of BLAS SDOT "
"routine in your system has a bug that "
"makes it return wrong results.
\\
n"
"Please contact theano-dev@groups.google.com.
\\
n"
"You can work around this bug by using a "
"different BLAS library, or disabling BLAS
\\
n");
assert(0);
}
"""
)
return
header
def
blas_header_version
():
# Version for the base header
version
=
(
1
,)
if
detect_macos_sdot_bug
():
if
detect_macos_sdot_bug
.
fix_works
:
# Version with fix
version
+=
(
1
,)
else
:
# Version with error
version
+=
(
2
,)
return
version
def
____gemm_code
(
check_ab
,
a_init
,
b_init
):
mod
=
'
%
'
return
"""
...
...
@@ -929,31 +1102,3 @@ def ____gemm_code(check_ab, a_init, b_init):
/* v 1 */
"""
%
locals
()
# currently unused, preferring the fallback method (throwing
# NotImplementedError) for when gemm won't work.
_templated_memaligned_gemm
=
"""
template <typename Ta, typename Tx, typename Ty, typename Tb, typename Tz>
int general_gemm(int zM, int zN, int xN,.
Ta a,
Tx * x, int xm, int xn,
Tx * y, int ym, int yn,
Tb b,
Tz * z, int zm, int zn)
{
for (int i = 0; i < zM; ++i)
{
for (int j = 0; j < zN; ++j)
{
Tz zij = 0.0;
for (int k = 0; k < xN; ++k)
{
zij += x[i*xm+k*xn] * y[k*ym+j*yn];
}
z[i * zm + j * zn] *= b;
z[i * zm + j * zn] += a * zij;
}
}
}
"""
theano/tensor/nnet/Conv3D.py
浏览文件 @
807a39ce
...
...
@@ -51,7 +51,7 @@ class Conv3D(theano.Op):
return
"Conv3D"
def
c_code_cache_version
(
self
):
return
(
3
,)
return
(
3
,
blas_header_text
.
version
)
def
make_node
(
self
,
V
,
W
,
b
,
d
):
...
...
theano/tensor/nnet/conv.py
浏览文件 @
807a39ce
...
...
@@ -965,7 +965,7 @@ class ConvOp(OpenMPOp):
return
[
'<numpy/noprefix.h>'
,
'<iostream>'
,
'<sstream>'
]
def
c_code_cache_version
(
self
):
return
(
10
,
self
.
openmp
)
return
(
10
,
self
.
openmp
,
blas
.
blas_header_version
()
)
def
c_support_code
(
self
):
return
"""
...
...
theano/tensor/tests/test_blas_c.py
浏览文件 @
807a39ce
...
...
@@ -208,8 +208,12 @@ class TestCGemv(TestCase, TestOptimizationMixin):
def
test_gemv1
(
self
):
self
.
t_gemv1
((
3
,
2
))
self
.
t_gemv1
((
1
,
2
))
self
.
t_gemv1
((
0
,
2
))
self
.
t_gemv1
((
3
,
1
))
self
.
t_gemv1
((
3
,
0
))
self
.
t_gemv1
((
1
,
0
))
self
.
t_gemv1
((
0
,
1
))
self
.
t_gemv1
((
0
,
0
))
def
test_gemv_dimensions
(
self
,
dtype
=
'float32'
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论