提交 0dd73dc4 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clean up some tests in tests.link.test_vm

上级 866ab4bd
...@@ -1197,3 +1197,12 @@ class VMLinker(LocalLinker): ...@@ -1197,3 +1197,12 @@ class VMLinker(LocalLinker):
self.allow_partial_eval = None self.allow_partial_eval = None
if not hasattr(self, "callback_input"): if not hasattr(self, "callback_input"):
self.callback_input = None self.callback_input = None
def __repr__(self):
args_str = ", ".join(
[
f"{name}={getattr(self, name)}"
for name in ("use_cloop", "lazy", "allow_partial_eval", "allow_gc")
]
)
return f"{type(self).__name__}({args_str})"
...@@ -15,10 +15,11 @@ from aesara.ifelse import ifelse ...@@ -15,10 +15,11 @@ from aesara.ifelse import ifelse
from aesara.link.c.basic import OpWiseCLinker from aesara.link.c.basic import OpWiseCLinker
from aesara.link.c.exceptions import MissingGXX from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import map_storage from aesara.link.utils import map_storage
from aesara.link.vm import VM, Loop, LoopGC, VMLinker from aesara.link.vm import VM, Loop, LoopGC, Stack, VMLinker
from aesara.tensor.math import cosh, tanh from aesara.tensor.math import cosh, tanh
from aesara.tensor.type import lscalar, scalar, scalars, vector, vectors from aesara.tensor.type import lscalar, scalar, scalars, vector, vectors
from aesara.tensor.var import TensorConstant from aesara.tensor.var import TensorConstant
from tests import unittest_tools as utt
class SomeOp(Op): class SomeOp(Op):
...@@ -157,7 +158,15 @@ def test_speed(): ...@@ -157,7 +158,15 @@ def test_speed():
time_numpy() time_numpy()
def test_speed_lazy(): @pytest.mark.parametrize(
"linker",
[
VMLinker(),
VMLinker(allow_gc=False),
VMLinker(allow_gc=False, use_cloop=True),
],
)
def test_speed_lazy(linker):
# TODO FIXME: This isn't a real test. # TODO FIXME: This isn't a real test.
def build_graph(x, depth=5): def build_graph(x, depth=5):
...@@ -166,105 +175,96 @@ def test_speed_lazy(): ...@@ -166,105 +175,96 @@ def test_speed_lazy():
z = ifelse(z[0] > 0, -z, z) z = ifelse(z[0] > 0, -z, z)
return z return z
def time_linker(name, linker): steps_a = 10
steps_a = 10 steps_b = 100
steps_b = 100 x = vector()
x = vector() a = build_graph(x, steps_a)
a = build_graph(x, steps_a) b = build_graph(x, steps_b)
b = build_graph(x, steps_b)
f_a = function([x], a, mode=Mode(optimizer=None, linker=linker()))
f_b = function([x], b, mode=Mode(optimizer=None, linker=linker()))
f_a([2.0]) f_a = function([x], a, mode=Mode(optimizer=None, linker=linker))
t0 = time.time() f_b = function([x], b, mode=Mode(optimizer=None, linker=linker))
f_a([2.0])
t1 = time.time()
f_b([2.0]) f_a([2.0])
t0 = time.time()
f_a([2.0])
t1 = time.time()
t2 = time.time() f_b([2.0])
f_b([2.0])
t3 = time.time()
t_a = t1 - t0 t2 = time.time()
t_b = t3 - t2 f_b([2.0])
t3 = time.time()
print(f"{name} takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop") t_a = t1 - t0
t_b = t3 - t2
time_linker("vmLinker", VMLinker) print(f"{linker} takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop")
time_linker("vmLinker_nogc", lambda: VMLinker(allow_gc=False))
if config.cxx:
time_linker("vmLinker_C", lambda: VMLinker(allow_gc=False, use_cloop=True))
def test_partial_function(): @pytest.mark.parametrize(
from tests import unittest_tools as utt "linker", [VMLinker(allow_partial_eval=True, use_cloop=False), "cvm"]
)
def test_partial_function(linker):
def check_partial_function(linker_name): x = scalar("input")
x = scalar("input") y = x**2
y = x**2 f = function(
f = function( [x], [y + 7, y - 9, y / 14.0], mode=Mode(optimizer=None, linker=linker)
[x], [y + 7, y - 9, y / 14.0], mode=Mode(optimizer=None, linker=linker_name) )
)
assert f(3, output_subset=[0, 1, 2]) == f(3) if linker == "cvm":
assert f(4, output_subset=[0, 2]) == [f(4)[0], f(4)[2]] from aesara.link.c.cvm import CVM
utt.assert_allclose(f(5), np.array([32.0, 16.0, 1.7857142857142858]))
check_partial_function(VMLinker(allow_partial_eval=True, use_cloop=False)) assert isinstance(f.fn, CVM)
if not config.cxx: else:
pytest.skip("Need cxx for this test") assert isinstance(f.fn, Stack)
check_partial_function("cvm")
assert f(3, output_subset=[0, 1, 2]) == f(3)
assert f(4, output_subset=[0, 2]) == [f(4)[0], f(4)[2]]
@pytest.mark.skipif( utt.assert_allclose(f(5), np.array([32.0, 16.0, 1.7857142857142858]))
not config.cxx, reason="G++ not available, so we need to skip this test."
)
def test_partial_function_with_output_keys():
def check_partial_function_output_keys(linker_name):
x = scalar("input")
y = 3 * x
f = function(
[x], {"a": y * 5, "b": y - 7}, mode=Mode(optimizer=None, linker=linker_name)
)
assert f(5, output_subset=["a"])["a"] == f(5)["a"]
check_partial_function_output_keys( @pytest.mark.parametrize(
VMLinker(allow_partial_eval=True, use_cloop=False) "linker", [VMLinker(allow_partial_eval=True, use_cloop=False), "cvm"]
)
def test_partial_function_with_output_keys(linker):
x = scalar("input")
y = 3 * x
f = function(
[x], {"a": y * 5, "b": y - 7}, mode=Mode(optimizer=None, linker=linker)
) )
check_partial_function_output_keys("cvm")
assert f(5, output_subset=["a"])["a"] == f(5)["a"]
@pytest.mark.skipif(
not config.cxx, reason="G++ not available, so we need to skip this test." @pytest.mark.parametrize(
"linker", [VMLinker(allow_partial_eval=True, use_cloop=False), "cvm"]
) )
def test_partial_function_with_updates(): def test_partial_function_with_updates(linker):
def check_updates(linker_name): x = lscalar("input")
x = lscalar("input") y = shared(np.asarray(1, "int64"), name="global")
y = shared(np.asarray(1, "int64"), name="global")
f = function( mode = Mode(optimizer=None, linker=linker)
[x],
[x, x + 34],
updates=[(y, x + 1)],
mode=Mode(optimizer=None, linker=linker_name),
)
g = function(
[x],
[x - 6],
updates=[(y, y + 3)],
mode=Mode(optimizer=None, linker=linker_name),
)
assert f(3, output_subset=[]) == [] f = function(
assert y.get_value() == 4 [x],
assert g(30, output_subset=[0]) == [24] [x, x + 34],
assert g(40, output_subset=[]) == [] updates=[(y, x + 1)],
assert y.get_value() == 10 mode=mode,
)
g = function(
[x],
[x - 6],
updates=[(y, y + 3)],
mode=mode,
)
check_updates(VMLinker(allow_partial_eval=True, use_cloop=False)) assert f(3, output_subset=[]) == []
check_updates("cvm") assert y.get_value() == 4
assert g(30, output_subset=[0]) == [24]
assert g(40, output_subset=[]) == []
assert y.get_value() == 10
def test_allow_gc_cvm(): def test_allow_gc_cvm():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论