提交 cb05c161 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Rename theano.link.c.vm to theano.link.vm

The `CVM` class definition was also split off into `theano.link.c.cvm`.
上级 8049c3ed
...@@ -124,7 +124,7 @@ def test_ifelse(): ...@@ -124,7 +124,7 @@ def test_ifelse():
for cloop in cloops: for cloop in cloops:
for lazy in lazys: for lazy in lazys:
linker = theano.link.c.vm.VMLinker(use_cloop=cloop, lazy=lazy) linker = theano.link.vm.VMLinker(use_cloop=cloop, lazy=lazy)
f = function( f = function(
[a, b, c], [a, b, c],
ifelse(a, notimpl(b), c), ifelse(a, notimpl(b), c),
...@@ -154,11 +154,11 @@ def test_nested(): ...@@ -154,11 +154,11 @@ def test_nested():
t4 = ifelseifelseif(tt.eq(x1, x2), x1, tt.eq(x1, 5), x2, c2, t3, t3 + 0.5) t4 = ifelseifelseif(tt.eq(x1, x2), x1, tt.eq(x1, 5), x2, c2, t3, t3 + 0.5)
t4.name = "t4" t4.name = "t4"
linker = theano.link.c.vm.VMLinker(lazy=False) linker = theano.link.vm.VMLinker(lazy=False)
f = function([c1, c2, x1, x2], t4, mode=Mode(linker=linker, optimizer="fast_run")) f = function([c1, c2, x1, x2], t4, mode=Mode(linker=linker, optimizer="fast_run"))
with pytest.raises(NotImplementedOpException): with pytest.raises(NotImplementedOpException):
f(1, 0, np.array(10, dtype=x1.dtype), 0) f(1, 0, np.array(10, dtype=x1.dtype), 0)
linker = theano.link.c.vm.VMLinker(lazy=True) linker = theano.link.vm.VMLinker(lazy=True)
f = function([c1, c2, x1, x2], t4, mode=Mode(linker=linker, optimizer="fast_run")) f = function([c1, c2, x1, x2], t4, mode=Mode(linker=linker, optimizer="fast_run"))
assert f(1, 0, np.array(10, dtype=x1.dtype), 0) == 20.5 assert f(1, 0, np.array(10, dtype=x1.dtype), 0) == 20.5
...@@ -6,11 +6,11 @@ import numpy as np ...@@ -6,11 +6,11 @@ import numpy as np
import pytest import pytest
import theano import theano
from theano import function, tensor from theano import config, function, tensor
from theano.compile import Mode from theano.compile import Mode
from theano.ifelse import ifelse from theano.ifelse import ifelse
from theano.link.c.cc import OpWiseCLinker from theano.link.c.exceptions import MissingGXX
from theano.link.c.vm import VMLinker from theano.link.vm import VMLinker
class TestCallbacks: class TestCallbacks:
...@@ -128,6 +128,8 @@ def test_speed(): ...@@ -128,6 +128,8 @@ def test_speed():
print(f"{name} takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop") print(f"{name} takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop")
from theano.link.c.cc import OpWiseCLinker
time_linker("c|py", OpWiseCLinker) time_linker("c|py", OpWiseCLinker)
time_linker("vmLinker", VMLinker) time_linker("vmLinker", VMLinker)
time_linker("vmLinker_nogc", lambda: VMLinker(allow_gc=False)) time_linker("vmLinker_nogc", lambda: VMLinker(allow_gc=False))
...@@ -443,3 +445,21 @@ def test_no_recycling(): ...@@ -443,3 +445,21 @@ def test_no_recycling():
m1 = f.fn.thunks[0].thunk.module m1 = f.fn.thunks[0].thunk.module
m2 = f2.fn.thunks[0].thunk.module m2 = f2.fn.thunks[0].thunk.module
assert m1 is m2 assert m1 is m2
def test_VMLinker_no_cxx():
from importlib import reload
from unittest.mock import patch
with config.change_flags(cxx=""):
with pytest.raises(MissingGXX):
import theano.link.c.cvm
reload(theano.link.c.cvm)
with patch.dict("sys.modules", {"theano.link.c.cvm": None}):
linker = VMLinker(allow_gc=False, use_cloop=True)
a = tensor.scalar()
with pytest.raises(ModuleNotFoundError):
_ = function([a], a, mode=Mode(optimizer=None, linker=linker))
from theano.compile import Mode from theano.compile import Mode
from theano.configdefaults import config from theano.configdefaults import config
from theano.link.basic import WrapLinkerMany from theano.link.basic import WrapLinkerMany
from theano.link.c.vm import VMLinker from theano.link.vm import VMLinker
from theano.printing import hex_digest, min_informative_str from theano.printing import hex_digest, min_informative_str
......
...@@ -607,7 +607,7 @@ class TestConv2D(utt.InferShapeTester): ...@@ -607,7 +607,7 @@ class TestConv2D(utt.InferShapeTester):
openmp=openmp, openmp=openmp,
) )
mode = theano.Mode( mode = theano.Mode(
linker=theano.link.c.vm.VMLinker( linker=theano.link.vm.VMLinker(
allow_gc=False, use_cloop=True allow_gc=False, use_cloop=True
) )
) )
......
...@@ -11,8 +11,8 @@ from theano import config, gof ...@@ -11,8 +11,8 @@ from theano import config, gof
from theano.compile.function.types import Supervisor from theano.compile.function.types import Supervisor
from theano.link.basic import PerformLinker from theano.link.basic import PerformLinker
from theano.link.c.cc import CLinker, OpWiseCLinker from theano.link.c.cc import CLinker, OpWiseCLinker
from theano.link.c.vm import VMLinker
from theano.link.jax import JAXLinker from theano.link.jax import JAXLinker
from theano.link.vm import VMLinker
_logger = logging.getLogger("theano.compile.mode") _logger = logging.getLogger("theano.compile.mode")
...@@ -409,7 +409,7 @@ class Mode: ...@@ -409,7 +409,7 @@ class Mode:
# string as the key # string as the key
# Use VM_linker to allow lazy evaluation by default. # Use VM_linker to allow lazy evaluation by default.
FAST_COMPILE = Mode( FAST_COMPILE = Mode(
theano.link.c.vm.VMLinker(use_cloop=False, c_thunks=False), "fast_compile" theano.link.vm.VMLinker(use_cloop=False, c_thunks=False), "fast_compile"
) )
if theano.config.cxx: if theano.config.cxx:
FAST_RUN = Mode("cvm", "fast_run") FAST_RUN = Mode("cvm", "fast_run")
......
...@@ -292,7 +292,7 @@ class NanGuardMode(Mode): ...@@ -292,7 +292,7 @@ class NanGuardMode(Mode):
if getattr(var.tag, "nan_guard_mode_check", True): if getattr(var.tag, "nan_guard_mode_check", True):
do_check_on(value, None, var=var) do_check_on(value, None, var=var)
wrap_linker = theano.link.c.vm.VMLinker( wrap_linker = theano.link.vm.VMLinker(
callback=nan_check, callback_input=nan_check_input callback=nan_check, callback_input=nan_check_input
) )
super().__init__(wrap_linker, optimizer=self.provided_optimizer) super().__init__(wrap_linker, optimizer=self.provided_optimizer)
from theano import config
from theano.link.c.exceptions import MissingGXX
from theano.link.vm import VM
try:
# If cxx is explicitly set to an empty string, we do not want to import
# either lazy-linker C code or lazy-linker compiled C code from the cache.
if not config.cxx:
raise MissingGXX(
"lazylinker will not be imported if theano.config.cxx is not set."
)
from theano.link.c import lazylinker_c
class CVM(lazylinker_c.CLazyLinker, VM):
def __init__(self, fgraph, *args, **kwargs):
self.fgraph = fgraph
lazylinker_c.CLazyLinker.__init__(self, *args, **kwargs)
# skip VM.__init__
except ImportError:
pass
except (OSError, MissingGXX):
# OSError happens when g++ is not installed. In that case, we
# already changed the default linker to something else then CVM.
# Currently this is the py linker.
# Here we assert that the default linker is not cvm.
if config._config_var_dict["linker"].default.startswith("cvm"):
raise
...@@ -15,7 +15,6 @@ from collections import defaultdict ...@@ -15,7 +15,6 @@ from collections import defaultdict
from theano import config from theano import config
from theano.gof import Constant, Variable from theano.gof import Constant, Variable
from theano.link.basic import Container, LocalLinker from theano.link.basic import Container, LocalLinker
from theano.link.c.exceptions import MissingGXX
from theano.link.utils import gc_helper, map_storage, raise_with_op from theano.link.utils import gc_helper, map_storage, raise_with_op
...@@ -693,32 +692,6 @@ class Stack(VM): ...@@ -693,32 +692,6 @@ class Stack(VM):
self.node_cleared_order.append(final_index) self.node_cleared_order.append(final_index)
try:
# If cxx is explicitely set to an empty string, we do not want to import neither lazylinker C code
# nor lazylinker compiled C code from cache.
if not config.cxx:
raise MissingGXX(
"lazylinker will not be imported if theano.config.cxx is not set."
)
from theano.link.c import lazylinker_c
class CVM(lazylinker_c.CLazyLinker, VM):
def __init__(self, fgraph, *args, **kwargs):
self.fgraph = fgraph
lazylinker_c.CLazyLinker.__init__(self, *args, **kwargs)
# skip VM.__init__
except ImportError:
pass
except (OSError, MissingGXX) as e:
# OSError happens when g++ is not installed. In that case, we
# already changed the default linker to something else then CVM.
# Currently this is the py linker.
# Here we assert that the default linker is not cvm.
assert not config._config_var_dict["linker"].default.startswith("cvm"), e
class VMLinker(LocalLinker): class VMLinker(LocalLinker):
""" """
Class that satisfies the Linker interface by acting as a VM factory. Class that satisfies the Linker interface by acting as a VM factory.
...@@ -948,6 +921,9 @@ class VMLinker(LocalLinker): ...@@ -948,6 +921,9 @@ class VMLinker(LocalLinker):
callback_input=self.callback_input, callback_input=self.callback_input,
) )
elif self.use_cloop: elif self.use_cloop:
from theano.link.c.cvm import CVM
# create a map from nodes to ints and vars to ints # create a map from nodes to ints and vars to ints
nodes_idx = {} nodes_idx = {}
vars_idx = {} vars_idx = {}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论