提交 9b3f6351 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6384 from notoraptor/fix-version

Fix pygpu version detection.
...@@ -20,7 +20,6 @@ pygpu_activated = False ...@@ -20,7 +20,6 @@ pygpu_activated = False
try: try:
import pygpu import pygpu
import pygpu.gpuarray import pygpu.gpuarray
import pygpu.version
except ImportError: except ImportError:
pygpu = None pygpu = None
...@@ -42,23 +41,36 @@ def transfer(x, target): ...@@ -42,23 +41,36 @@ def transfer(x, target):
register_transfer(transfer) register_transfer(transfer)
def pygpu_parse_version(version_string):
from collections import namedtuple
version_type = namedtuple('version_type', ('major', 'minor', 'patch', 'fullversion'))
pieces = version_string.split('.', 2)
assert len(pieces) == 3
major = int(pieces[0])
minor = int(pieces[1])
patch = int(pieces[2].split('+', 1)[0])
fullversion = '%d.%d.%s' % (major, minor, pieces[2])
return version_type(major=major, minor=minor, patch=patch, fullversion=fullversion)
def init_dev(dev, name=None, preallocate=None): def init_dev(dev, name=None, preallocate=None):
global pygpu_activated global pygpu_activated
if not config.cxx: if not config.cxx:
raise RuntimeError("The new gpu-backend need a c++ compiler.") raise RuntimeError("The new gpu-backend need a c++ compiler.")
if (pygpu.version.major != 0 or pygpu.version.minor != 7 or pygpu_version = pygpu_parse_version(pygpu.__version__)
pygpu.version.patch < 0): if (pygpu_version.major != 0 or pygpu_version.minor != 7 or
pygpu_version.patch < 0):
raise ValueError( raise ValueError(
"Your installed version of pygpu(%s) is too old, please upgrade to 0.7.0 or later" % "Your installed version of pygpu(%s) is too old, please upgrade to 0.7.0 or later" %
pygpu.version.fullversion) pygpu_version.fullversion)
# This is for the C headers API, we need to match the exact version. # This is for the C headers API, we need to match the exact version.
gpuarray_version_major_supported = 2 gpuarray_version_major_supported = 2
gpuarray_version_major_detected = pygpu.gpuarray.api_version()[0] gpuarray_version_major_detected = pygpu.gpuarray.api_version()[0]
if gpuarray_version_major_detected != gpuarray_version_major_supported: if gpuarray_version_major_detected != gpuarray_version_major_supported:
raise ValueError( raise ValueError(
"Your installed version oflibgpuarray is not in sync with the current Theano" "Your installed version of libgpuarray is not in sync with the current Theano"
" version. The installed libgpuarray version support API version %d," " version. The installed libgpuarray version supports API version %d,"
" while current Theano support API version %d. Change the version of" " while current Theano supports API version %d. Change the version of"
" libgpuarray or Theano to fix this problem.", " libgpuarray or Theano to fix this problem.",
gpuarray_version_major_detected, gpuarray_version_major_detected,
gpuarray_version_major_supported) gpuarray_version_major_supported)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论