提交 34736410 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Towards Windows compatibility: .pyd instead of .so, and Python library linking fixed

上级 af7f0cd8
...@@ -124,7 +124,7 @@ class DynamicModule(object): ...@@ -124,7 +124,7 @@ class DynamicModule(object):
#TODO: add_type #TODO: add_type
def dlimport(fullpath, suffix=None): def dlimport(fullpath, suffix=None):
"""Dynamically load a .so, .dll, or .py file """Dynamically load a .so, .pyd, .dll, or .py file
:type fullpath: string :type fullpath: string
:param fullpath: a fully-qualified path do a compiled python module :param fullpath: a fully-qualified path do a compiled python module
...@@ -137,6 +137,8 @@ def dlimport(fullpath, suffix=None): ...@@ -137,6 +137,8 @@ def dlimport(fullpath, suffix=None):
if suffix is None: if suffix is None:
if fullpath.endswith('.so'): if fullpath.endswith('.so'):
suffix = '.so' suffix = '.so'
elif fullpath.endswith('.pyd'):
suffix = '.pyd'
elif fullpath.endswith('.dll'): elif fullpath.endswith('.dll'):
suffix = '.dll' suffix = '.dll'
elif fullpath.endswith('.py'): elif fullpath.endswith('.py'):
...@@ -172,9 +174,10 @@ def module_name_from_dir(dirname): ...@@ -172,9 +174,10 @@ def module_name_from_dir(dirname):
"""Scan the contents of a cache directory and return full path of the dynamic lib in it. """Scan the contents of a cache directory and return full path of the dynamic lib in it.
""" """
files = os.listdir(dirname) files = os.listdir(dirname)
names = [file for file in files if file.endswith('.so')] names = [file for file in files if file.endswith('.so') or
file.endswith('.pyd')]
if len(names) != 1: if len(names) != 1:
raise Exception('Failed to load .so from dir', dirname) raise Exception('Failed to load dynamic libraries from dir', dirname)
return os.path.join(dirname, names[0]) return os.path.join(dirname, names[0])
class ModuleCache(object): class ModuleCache(object):
...@@ -210,7 +213,7 @@ class ModuleCache(object): ...@@ -210,7 +213,7 @@ class ModuleCache(object):
"""maps module names to loaded module objects""" """maps module names to loaded module objects"""
entry_from_key = {} entry_from_key = {}
"""Maps keys to the filename of a .so """Maps keys to the filename of a .so/.pyd.
""" """
stats = [] stats = []
...@@ -345,6 +348,12 @@ class ModuleCache(object): ...@@ -345,6 +348,12 @@ class ModuleCache(object):
key_pkl = os.path.join(location, 'key.pkl') key_pkl = os.path.join(location, 'key.pkl')
key_file = file(key_pkl, 'w') key_file = file(key_pkl, 'w')
try: try:
if sys.platform == 'win32':
# Looks like there is a bug under Windows, where using the
# highest protocol will result in a pickle file that cannot
# be loaded afterwards.
cPickle.dump(key, key_file)
else:
cPickle.dump(key, key_file, cPickle.HIGHEST_PROTOCOL) cPickle.dump(key, key_file, cPickle.HIGHEST_PROTOCOL)
key_file.close() key_file.close()
key_broken = False key_broken = False
...@@ -441,6 +450,13 @@ def get_module_cache(dirname, force_fresh=None): ...@@ -441,6 +450,13 @@ def get_module_cache(dirname, force_fresh=None):
atexit.register(_module_cache._on_atexit) atexit.register(_module_cache._on_atexit)
return _module_cache return _module_cache
def get_lib_extension():
"""Return the platform-dependent extension for compiled modules."""
if sys.platform == 'win32':
return 'pyd'
else:
return 'so'
def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[], lib_dirs=[], libs=[], def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[], lib_dirs=[], libs=[],
preargs=[], tmpdir=None): preargs=[], tmpdir=None):
#TODO: don't to the dlimport in this function #TODO: don't to the dlimport in this function
...@@ -451,7 +467,18 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[] ...@@ -451,7 +467,18 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
include_dirs = [distutils.sysconfig.get_python_inc()] + \ include_dirs = [distutils.sysconfig.get_python_inc()] + \
numpy.distutils.misc_util.get_numpy_include_dirs()\ numpy.distutils.misc_util.get_numpy_include_dirs()\
+ include_dirs + include_dirs
libs = [os.path.split(distutils.sysconfig.get_python_inc())[-1]] + libs python_inc = distutils.sysconfig.get_python_inc()
if sys.platform == 'win32':
# Typical include directory: C:\Python26\include
libname = os.path.basename(os.path.dirname(python_inc)).lower()
# Also add directory containing the Python library to the library
# directories.
python_lib_dir = os.path.join(os.path.dirname(python_inc), 'libs')
lib_dirs = [python_lib_dir] + lib_dirs
else:
# Typical include directory: /usr/include/python2.6
libname = os.path.basename(python_inc)
libs = [libname] + libs
workdir = location workdir = location
...@@ -465,7 +492,8 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[] ...@@ -465,7 +492,8 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
cppfile.write(src_code) cppfile.write(src_code)
cppfile.close() cppfile.close()
lib_filename = os.path.join(workdir, '%s.so'% module_name) lib_filename = os.path.join(workdir, '%s.%s' %
(module_name, get_lib_extension()))
debug('Generating shared lib', lib_filename) debug('Generating shared lib', lib_filename)
cmd = ['g++', '-shared', '-g'] cmd = ['g++', '-shared', '-g']
...@@ -490,7 +518,6 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[] ...@@ -490,7 +518,6 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
file(os.path.join(workdir, "__init__.py"),'w').close() file(os.path.join(workdir, "__init__.py"),'w').close()
rval = dlimport(lib_filename) rval = dlimport(lib_filename)
return rval return rval
...@@ -522,7 +549,8 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[ ...@@ -522,7 +549,8 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[
try: try:
cppfile.write(src_code) cppfile.write(src_code)
cppfile.close() cppfile.close()
lib_filename = os.path.join(workdir, '%s.so'% module_name) lib_filename = os.path.join(workdir, '%s.%s' %
(module_name, get_lib_extension()))
debug('Generating shared lib', lib_filename) debug('Generating shared lib', lib_filename)
cmd = ['nvcc', '-shared', '-g'] cmd = ['nvcc', '-shared', '-g']
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论