提交 1ee6f62c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Implement a __str__ for _CThunk class

上级 4a99b673
......@@ -1771,6 +1771,9 @@ class _CThunk:
raise
raise exc_value.with_traceback(exc_trace)
def __str__(self):
return f"{type(self).__name__}({self.module})"
class OpWiseCLinker(LocalLinker):
"""
......
......@@ -3,8 +3,8 @@ import pytest
import aesara
from aesara.compile.mode import Mode
from aesara.graph import fg
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import COp
from aesara.graph.type import CType
from aesara.link.basic import PerformLinker
......@@ -180,27 +180,30 @@ def inputs():
return x, y, z
def Env(inputs, outputs):
e = fg.FunctionGraph(inputs, outputs)
return e
################
# Test CLinker #
################
@pytest.mark.skipif(
not aesara.config.cxx, reason="G++ not available, so we need to skip this test."
)
def test_clinker_straightforward():
x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), bad_sub(bad_sub(x, y), z))
lnk = CLinker().accept(Env([x, y, z], [e]))
lnk = CLinker().accept(FunctionGraph([x, y, z], [e]))
fn = lnk.make_function()
assert fn(2.0, 2.0, 2.0) == 2.0
@pytest.mark.skipif(
not aesara.config.cxx, reason="G++ not available, so we need to skip this test."
)
def test_cthunk_str():
x = double("x")
y = double("y")
e = add(x, y)
lnk = CLinker().accept(FunctionGraph([x, y], [e]))
cthunk, input_storage, output_storage = lnk.make_thunk()
assert str(cthunk).startswith("_CThunk")
assert "module" in str(cthunk)
@pytest.mark.skipif(
not aesara.config.cxx, reason="G++ not available, so we need to skip this test."
)
......@@ -208,7 +211,7 @@ def test_clinker_literal_inlining():
x, y, z = inputs()
z = Constant(tdouble, 4.12345678)
e = add(mul(add(x, y), div(x, y)), bad_sub(bad_sub(x, y), z))
lnk = CLinker().accept(Env([x, y], [e]))
lnk = CLinker().accept(FunctionGraph([x, y], [e]))
fn = lnk.make_function()
assert abs(fn(2.0, 2.0) + 0.12345678) < 1e-9
code = lnk.code_gen()
......@@ -253,7 +256,7 @@ def test_clinker_literal_cache():
def test_clinker_single_node():
x, y, z = inputs()
node = add.make_node(x, y)
lnk = CLinker().accept(Env(node.inputs, node.outputs))
lnk = CLinker().accept(FunctionGraph(node.inputs, node.outputs))
fn = lnk.make_function()
assert fn(2.0, 7.0) == 9
......@@ -265,7 +268,7 @@ def test_clinker_dups():
# Testing that duplicate inputs are allowed.
x, y, z = inputs()
e = add(x, x)
lnk = CLinker().accept(Env([x, x], [e]))
lnk = CLinker().accept(FunctionGraph([x, x], [e]))
fn = lnk.make_function()
assert fn(2.0, 2.0) == 4
# note: for now the behavior of fn(2.0, 7.0) is undefined
......@@ -278,7 +281,7 @@ def test_clinker_not_used_inputs():
# Testing that unused inputs are allowed.
x, y, z = inputs()
e = add(x, y)
lnk = CLinker().accept(Env([x, y, z], [e]))
lnk = CLinker().accept(FunctionGraph([x, y, z], [e]))
fn = lnk.make_function()
assert fn(2.0, 1.5, 1.0) == 3.5
......@@ -290,20 +293,16 @@ def test_clinker_dups_inner():
# Testing that duplicates are allowed inside the graph
x, y, z = inputs()
e = add(mul(y, y), add(x, z))
lnk = CLinker().accept(Env([x, y, z], [e]))
lnk = CLinker().accept(FunctionGraph([x, y, z], [e]))
fn = lnk.make_function()
assert fn(1.0, 2.0, 3.0) == 8.0
######################
# Test OpWiseCLinker #
######################
# slow on linux, but near sole test and very central
def test_opwiseclinker_straightforward():
x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), bad_sub(bad_sub(x, y), z))
lnk = OpWiseCLinker().accept(Env([x, y, z], [e]))
lnk = OpWiseCLinker().accept(FunctionGraph([x, y, z], [e]))
fn = lnk.make_function()
if aesara.config.cxx:
assert fn(2.0, 2.0, 2.0) == 2.0
......@@ -316,7 +315,7 @@ def test_opwiseclinker_constant():
x, y, z = inputs()
x = Constant(tdouble, 7.2, name="x")
e = add(mul(x, y), mul(y, z))
lnk = OpWiseCLinker().accept(Env([y, z], [e]))
lnk = OpWiseCLinker().accept(FunctionGraph([y, z], [e]))
fn = lnk.make_function()
res = fn(1.5, 3.0)
assert res == 15.3
......@@ -331,15 +330,10 @@ def _my_checker(x, y):
raise MyExc("Output mismatch.", {"performlinker": x[0], "clinker": y[0]})
###################
# Test DualLinker #
###################
def test_duallinker_straightforward():
x, y, z = inputs()
e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python
lnk = DualLinker(checker=_my_checker).accept(Env([x, y, z], [e]))
lnk = DualLinker(checker=_my_checker).accept(FunctionGraph([x, y, z], [e]))
fn = lnk.make_function()
res = fn(7.2, 1.5, 3.0)
assert res == 15.3
......@@ -352,7 +346,7 @@ def test_duallinker_mismatch():
x, y, z = inputs()
# bad_sub is correct in C but erroneous in Python
e = bad_sub(mul(x, y), mul(y, z))
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
lnk = DualLinker(checker=_my_checker).accept(g)
fn = lnk.make_function()
......@@ -371,11 +365,6 @@ def test_duallinker_mismatch():
fn(1.0, 2.0, 3.0)
################################
# Test that failure code works #
################################
class AddFail(Binary):
def c_code(self, node, name, inp, out, sub):
x, y = inp
......@@ -399,7 +388,7 @@ def test_c_fail_error():
x, y, z = inputs()
x = Constant(tdouble, 7.2, name="x")
e = add_fail(mul(x, y), mul(y, z))
lnk = OpWiseCLinker().accept(Env([y, z], [e]))
lnk = OpWiseCLinker().accept(FunctionGraph([y, z], [e]))
fn = lnk.make_function()
with pytest.raises(RuntimeError):
fn(1.5, 3.0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论