提交 9a32adb3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Preempt missing CVM and fall back to non-C VM in VMLinker.make_vm

上级 bdb05333
......@@ -13,7 +13,7 @@ from theano.gof.graph import Apply
from theano.gof.op import Op
from theano.ifelse import ifelse
from theano.link.c.exceptions import MissingGXX
from theano.link.vm import VMLinker
from theano.link.vm import Loop, VMLinker
class TestCallbacks:
......@@ -450,19 +450,39 @@ def test_no_recycling():
assert m1 is m2
def test_VMLinker_no_cxx():
@pytest.mark.skipif(
not theano.config.cxx, reason="G++ not available, so we need to skip this test."
)
def test_VMLinker_make_vm_cvm():
# We don't want this at module level, since CXX might not be present
from theano.link.c.cvm import CVM
a = tensor.scalar()
linker = VMLinker(allow_gc=False, use_cloop=True)
f = function([a], a, mode=Mode(optimizer=None, linker=linker))
assert isinstance(f.fn, CVM)
def test_VMLinker_make_vm_no_cvm():
from importlib import reload
from unittest.mock import patch
with config.change_flags(cxx=""):
# Make sure that GXX isn't present
with pytest.raises(MissingGXX):
import theano.link.c.cvm
reload(theano.link.c.cvm)
# Make sure that `cvm` module is missing
with patch.dict("sys.modules", {"theano.link.c.cvm": None}):
linker = VMLinker(allow_gc=False, use_cloop=True)
a = tensor.scalar()
linker = VMLinker(allow_gc=False, use_cloop=True)
with pytest.raises(ModuleNotFoundError):
_ = function([a], a, mode=Mode(optimizer=None, linker=linker))
import theano.link.c.cvm
f = function([a], a, mode=Mode(optimizer=None, linker=linker))
assert isinstance(f.fn, Loop)
......@@ -13,8 +13,9 @@ import warnings
from collections import defaultdict
from theano.configdefaults import config
from theano.gof import Constant, Variable
from theano.gof.graph import Constant, Variable
from theano.link.basic import Container, LocalLinker
from theano.link.c.exceptions import MissingGXX
from theano.link.utils import gc_helper, map_storage, raise_with_op
......@@ -888,6 +889,11 @@ class VMLinker(LocalLinker):
pre_call_clear = [storage_map[v] for v in self.no_recycling]
try:
from theano.link.c.cvm import CVM
except (MissingGXX, ImportError):
CVM = None
if (
self.callback is not None
or self.callback_input is not None
......@@ -920,9 +926,7 @@ class VMLinker(LocalLinker):
callback=self.callback,
callback_input=self.callback_input,
)
elif self.use_cloop:
from theano.link.c.cvm import CVM
elif self.use_cloop and CVM:
# create a map from nodes to ints and vars to ints
nodes_idx = {}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论