Unverified 提交 77daa7eb authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #6574 from GSam/pypy-hack

Make Theano work under pypy (WIP)
......@@ -269,16 +269,23 @@ def dlimport(fullpath, suffix=None):
if not os.path.isabs(fullpath):
raise ValueError('`fullpath` must be an absolute path', fullpath)
if suffix is None:
if fullpath.endswith('.so'):
suffix = '.so'
elif fullpath.endswith('.pyd'):
suffix = '.pyd'
elif fullpath.endswith('.dll'):
suffix = '.dll'
elif fullpath.endswith('.py'):
suffix = '.py'
else:
suffix = ''
suffix = ''
dist_suffix = distutils.sysconfig.get_config_var("SO")
if dist_suffix is not None and dist_suffix != '':
if fullpath.endswith(dist_suffix):
suffix = dist_suffix
if suffix == '':
if fullpath.endswith('.so'):
suffix = '.so'
elif fullpath.endswith('.pyd'):
suffix = '.pyd'
elif fullpath.endswith('.dll'):
suffix = '.dll'
elif fullpath.endswith('.py'):
suffix = '.py'
rval = None
if fullpath.endswith(suffix):
module_name = '.'.join(fullpath.split(os.path.sep)[-2:])[:-len(suffix)]
......@@ -1675,22 +1682,33 @@ def std_lib_dirs_and_libs():
elif sys.platform == 'darwin':
std_lib_dirs_and_libs.data = [], []
else:
# assume Linux
# Typical include directory: /usr/include/python2.6
if platform.python_implementation() == 'PyPy':
# Assume Linux (note: Ubuntu doesn't ship this .so)
if sys.version_info < (3,):
libname = "pypy-c"
else:
libname = "pypy3-c"
# Unfortunately the only convention of this .so is that it appears
# next to the location of the interpreter binary.
libdir = os.path.dirname(os.path.realpath(sys.executable))
else:
# Assume Linux
# Typical include directory: /usr/include/python2.6
# get the name of the python library (shared object)
# get the name of the python library (shared object)
libname = distutils.sysconfig.get_config_var("LDLIBRARY")
libname = distutils.sysconfig.get_config_var("LDLIBRARY")
if libname.startswith("lib"):
libname = libname[3:]
if libname.startswith("lib"):
libname = libname[3:]
# remove extension if present
if libname.endswith(".so"):
libname = libname[:-3]
elif libname.endswith(".a"):
libname = libname[:-2]
# remove extension if present
if libname.endswith(".so"):
libname = libname[:-3]
elif libname.endswith(".a"):
libname = libname[:-2]
libdir = distutils.sysconfig.get_config_var("LIBDIR")
libdir = distutils.sysconfig.get_config_var("LIBDIR")
std_lib_dirs_and_libs.data = [libname], [libdir]
......@@ -2277,9 +2295,18 @@ class GCC_compiler(Compiler):
if not src_code.endswith('\n'):
cppfile.write('\n')
lib_filename = os.path.join(
location,
'%s.%s' % (module_name, get_lib_extension()))
if platform.python_implementation() == 'PyPy':
suffix = '.' + get_lib_extension()
dist_suffix = distutils.sysconfig.get_config_var("SO")
if dist_suffix is not None and dist_suffix != '':
suffix = dist_suffix
filepath = '%s%s' % (module_name, suffix)
else:
filepath = '%s.%s' % (module_name, get_lib_extension())
lib_filename = os.path.join(location, filepath)
_logger.debug('Generating shared lib %s', lib_filename)
cmd = [theano.config.cxx, get_gcc_shared_library_arg(), '-g']
......
......@@ -7,6 +7,7 @@ Defines the `Type` class.
from __future__ import absolute_import, print_function, division
import ctypes
import platform
from six import string_types
......@@ -606,8 +607,11 @@ class Generic(SingletonType):
generic = Generic()
_cdata_type = ctypes.py_object.from_address(
ctypes.addressof(ctypes.pythonapi.PyCapsule_Type)).value
_cdata_type = None
if platform.python_implementation() != 'PyPy':
_cdata_type = ctypes.py_object.from_address(
ctypes.addressof(ctypes.pythonapi.PyCapsule_Type)).value
class _make_cdata(Op):
......@@ -679,8 +683,12 @@ class CDataType(Type):
self.version = version
def filter(self, data, strict=False, allow_downcast=None):
if data is not None and not isinstance(data, _cdata_type):
raise TypeError("expected None or a PyCapsule")
# We ignore this type-check (_cdata_type is None) in PyPy
# because this type is not exposed to us.
if data is not None and _cdata_type is not None:
if not isinstance(data, _cdata_type):
raise TypeError("expected None or a PyCapsule")
return data
def _get_func(self):
......
......@@ -13,6 +13,7 @@ import logging
import sys
import time
import warnings
import platform
from theano.configparser import (config, _config_var_list)
......@@ -983,7 +984,11 @@ class VM_Linker(link.LocalLinker):
if oidx in update_in_from_out:
update_storage.append(update_in_from_out[oidx])
c0 = sys.getrefcount(node_n_inputs)
# PyPy has no sys.getrefcount, so ignore this check if not running
# under CPython.
if platform.python_implementation() == 'CPython':
c0 = sys.getrefcount(node_n_inputs)
vm = CVM(
nodes,
thunks,
......@@ -1006,7 +1011,9 @@ class VM_Linker(link.LocalLinker):
update_storage=update_storage,
dependencies=dependency_map_list,
)
assert c0 == sys.getrefcount(node_n_inputs)
if platform.python_implementation() == 'CPython':
assert c0 == sys.getrefcount(node_n_inputs)
else:
lazy = self.lazy
if lazy is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论