Unverified 提交 77daa7eb authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #6574 from GSam/pypy-hack

Make Theano work under pypy (WIP)
...@@ -269,16 +269,23 @@ def dlimport(fullpath, suffix=None): ...@@ -269,16 +269,23 @@ def dlimport(fullpath, suffix=None):
if not os.path.isabs(fullpath): if not os.path.isabs(fullpath):
raise ValueError('`fullpath` must be an absolute path', fullpath) raise ValueError('`fullpath` must be an absolute path', fullpath)
if suffix is None: if suffix is None:
if fullpath.endswith('.so'): suffix = ''
suffix = '.so'
elif fullpath.endswith('.pyd'): dist_suffix = distutils.sysconfig.get_config_var("SO")
suffix = '.pyd' if dist_suffix is not None and dist_suffix != '':
elif fullpath.endswith('.dll'): if fullpath.endswith(dist_suffix):
suffix = '.dll' suffix = dist_suffix
elif fullpath.endswith('.py'):
suffix = '.py' if suffix == '':
else: if fullpath.endswith('.so'):
suffix = '' suffix = '.so'
elif fullpath.endswith('.pyd'):
suffix = '.pyd'
elif fullpath.endswith('.dll'):
suffix = '.dll'
elif fullpath.endswith('.py'):
suffix = '.py'
rval = None rval = None
if fullpath.endswith(suffix): if fullpath.endswith(suffix):
module_name = '.'.join(fullpath.split(os.path.sep)[-2:])[:-len(suffix)] module_name = '.'.join(fullpath.split(os.path.sep)[-2:])[:-len(suffix)]
...@@ -1675,22 +1682,33 @@ def std_lib_dirs_and_libs(): ...@@ -1675,22 +1682,33 @@ def std_lib_dirs_and_libs():
elif sys.platform == 'darwin': elif sys.platform == 'darwin':
std_lib_dirs_and_libs.data = [], [] std_lib_dirs_and_libs.data = [], []
else: else:
# assume Linux if platform.python_implementation() == 'PyPy':
# Typical include directory: /usr/include/python2.6 # Assume Linux (note: Ubuntu doesn't ship this .so)
if sys.version_info < (3,):
libname = "pypy-c"
else:
libname = "pypy3-c"
# Unfortunately the only convention of this .so is that it appears
# next to the location of the interpreter binary.
libdir = os.path.dirname(os.path.realpath(sys.executable))
else:
# Assume Linux
# Typical include directory: /usr/include/python2.6
# get the name of the python library (shared object)
# get the name of the python library (shared object) libname = distutils.sysconfig.get_config_var("LDLIBRARY")
libname = distutils.sysconfig.get_config_var("LDLIBRARY")
if libname.startswith("lib"): if libname.startswith("lib"):
libname = libname[3:] libname = libname[3:]
# remove extension if present # remove extension if present
if libname.endswith(".so"): if libname.endswith(".so"):
libname = libname[:-3] libname = libname[:-3]
elif libname.endswith(".a"): elif libname.endswith(".a"):
libname = libname[:-2] libname = libname[:-2]
libdir = distutils.sysconfig.get_config_var("LIBDIR") libdir = distutils.sysconfig.get_config_var("LIBDIR")
std_lib_dirs_and_libs.data = [libname], [libdir] std_lib_dirs_and_libs.data = [libname], [libdir]
...@@ -2277,9 +2295,18 @@ class GCC_compiler(Compiler): ...@@ -2277,9 +2295,18 @@ class GCC_compiler(Compiler):
if not src_code.endswith('\n'): if not src_code.endswith('\n'):
cppfile.write('\n') cppfile.write('\n')
lib_filename = os.path.join( if platform.python_implementation() == 'PyPy':
location, suffix = '.' + get_lib_extension()
'%s.%s' % (module_name, get_lib_extension()))
dist_suffix = distutils.sysconfig.get_config_var("SO")
if dist_suffix is not None and dist_suffix != '':
suffix = dist_suffix
filepath = '%s%s' % (module_name, suffix)
else:
filepath = '%s.%s' % (module_name, get_lib_extension())
lib_filename = os.path.join(location, filepath)
_logger.debug('Generating shared lib %s', lib_filename) _logger.debug('Generating shared lib %s', lib_filename)
cmd = [theano.config.cxx, get_gcc_shared_library_arg(), '-g'] cmd = [theano.config.cxx, get_gcc_shared_library_arg(), '-g']
......
...@@ -7,6 +7,7 @@ Defines the `Type` class. ...@@ -7,6 +7,7 @@ Defines the `Type` class.
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import ctypes import ctypes
import platform
from six import string_types from six import string_types
...@@ -606,8 +607,11 @@ class Generic(SingletonType): ...@@ -606,8 +607,11 @@ class Generic(SingletonType):
generic = Generic() generic = Generic()
_cdata_type = ctypes.py_object.from_address( _cdata_type = None
ctypes.addressof(ctypes.pythonapi.PyCapsule_Type)).value
if platform.python_implementation() != 'PyPy':
_cdata_type = ctypes.py_object.from_address(
ctypes.addressof(ctypes.pythonapi.PyCapsule_Type)).value
class _make_cdata(Op): class _make_cdata(Op):
...@@ -679,8 +683,12 @@ class CDataType(Type): ...@@ -679,8 +683,12 @@ class CDataType(Type):
self.version = version self.version = version
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
if data is not None and not isinstance(data, _cdata_type): # We ignore this type-check (_cdata_type is None) in PyPy
raise TypeError("expected None or a PyCapsule") # because this type is not exposed to us.
if data is not None and _cdata_type is not None:
if not isinstance(data, _cdata_type):
raise TypeError("expected None or a PyCapsule")
return data return data
def _get_func(self): def _get_func(self):
......
...@@ -13,6 +13,7 @@ import logging ...@@ -13,6 +13,7 @@ import logging
import sys import sys
import time import time
import warnings import warnings
import platform
from theano.configparser import (config, _config_var_list) from theano.configparser import (config, _config_var_list)
...@@ -983,7 +984,11 @@ class VM_Linker(link.LocalLinker): ...@@ -983,7 +984,11 @@ class VM_Linker(link.LocalLinker):
if oidx in update_in_from_out: if oidx in update_in_from_out:
update_storage.append(update_in_from_out[oidx]) update_storage.append(update_in_from_out[oidx])
c0 = sys.getrefcount(node_n_inputs) # PyPy has no sys.getrefcount, so ignore this check if not running
# under CPython.
if platform.python_implementation() == 'CPython':
c0 = sys.getrefcount(node_n_inputs)
vm = CVM( vm = CVM(
nodes, nodes,
thunks, thunks,
...@@ -1006,7 +1011,9 @@ class VM_Linker(link.LocalLinker): ...@@ -1006,7 +1011,9 @@ class VM_Linker(link.LocalLinker):
update_storage=update_storage, update_storage=update_storage,
dependencies=dependency_map_list, dependencies=dependency_map_list,
) )
assert c0 == sys.getrefcount(node_n_inputs)
if platform.python_implementation() == 'CPython':
assert c0 == sys.getrefcount(node_n_inputs)
else: else:
lazy = self.lazy lazy = self.lazy
if lazy is None: if lazy is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论