提交 1ee6f62c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Implement a __str__ for _CThunk class

上级 4a99b673
...@@ -1771,6 +1771,9 @@ class _CThunk: ...@@ -1771,6 +1771,9 @@ class _CThunk:
raise raise
raise exc_value.with_traceback(exc_trace) raise exc_value.with_traceback(exc_trace)
def __str__(self):
return f"{type(self).__name__}({self.module})"
class OpWiseCLinker(LocalLinker): class OpWiseCLinker(LocalLinker):
""" """
......
...@@ -3,8 +3,8 @@ import pytest ...@@ -3,8 +3,8 @@ import pytest
import aesara import aesara
from aesara.compile.mode import Mode from aesara.compile.mode import Mode
from aesara.graph import fg
from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import COp from aesara.graph.op import COp
from aesara.graph.type import CType from aesara.graph.type import CType
from aesara.link.basic import PerformLinker from aesara.link.basic import PerformLinker
...@@ -180,27 +180,30 @@ def inputs(): ...@@ -180,27 +180,30 @@ def inputs():
return x, y, z return x, y, z
def Env(inputs, outputs):
e = fg.FunctionGraph(inputs, outputs)
return e
################
# Test CLinker #
################
@pytest.mark.skipif( @pytest.mark.skipif(
not aesara.config.cxx, reason="G++ not available, so we need to skip this test." not aesara.config.cxx, reason="G++ not available, so we need to skip this test."
) )
def test_clinker_straightforward(): def test_clinker_straightforward():
x, y, z = inputs() x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), bad_sub(bad_sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), bad_sub(bad_sub(x, y), z))
lnk = CLinker().accept(Env([x, y, z], [e])) lnk = CLinker().accept(FunctionGraph([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
assert fn(2.0, 2.0, 2.0) == 2.0 assert fn(2.0, 2.0, 2.0) == 2.0
@pytest.mark.skipif(
not aesara.config.cxx, reason="G++ not available, so we need to skip this test."
)
def test_cthunk_str():
x = double("x")
y = double("y")
e = add(x, y)
lnk = CLinker().accept(FunctionGraph([x, y], [e]))
cthunk, input_storage, output_storage = lnk.make_thunk()
assert str(cthunk).startswith("_CThunk")
assert "module" in str(cthunk)
@pytest.mark.skipif( @pytest.mark.skipif(
not aesara.config.cxx, reason="G++ not available, so we need to skip this test." not aesara.config.cxx, reason="G++ not available, so we need to skip this test."
) )
...@@ -208,7 +211,7 @@ def test_clinker_literal_inlining(): ...@@ -208,7 +211,7 @@ def test_clinker_literal_inlining():
x, y, z = inputs() x, y, z = inputs()
z = Constant(tdouble, 4.12345678) z = Constant(tdouble, 4.12345678)
e = add(mul(add(x, y), div(x, y)), bad_sub(bad_sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), bad_sub(bad_sub(x, y), z))
lnk = CLinker().accept(Env([x, y], [e])) lnk = CLinker().accept(FunctionGraph([x, y], [e]))
fn = lnk.make_function() fn = lnk.make_function()
assert abs(fn(2.0, 2.0) + 0.12345678) < 1e-9 assert abs(fn(2.0, 2.0) + 0.12345678) < 1e-9
code = lnk.code_gen() code = lnk.code_gen()
...@@ -253,7 +256,7 @@ def test_clinker_literal_cache(): ...@@ -253,7 +256,7 @@ def test_clinker_literal_cache():
def test_clinker_single_node(): def test_clinker_single_node():
x, y, z = inputs() x, y, z = inputs()
node = add.make_node(x, y) node = add.make_node(x, y)
lnk = CLinker().accept(Env(node.inputs, node.outputs)) lnk = CLinker().accept(FunctionGraph(node.inputs, node.outputs))
fn = lnk.make_function() fn = lnk.make_function()
assert fn(2.0, 7.0) == 9 assert fn(2.0, 7.0) == 9
...@@ -265,7 +268,7 @@ def test_clinker_dups(): ...@@ -265,7 +268,7 @@ def test_clinker_dups():
# Testing that duplicate inputs are allowed. # Testing that duplicate inputs are allowed.
x, y, z = inputs() x, y, z = inputs()
e = add(x, x) e = add(x, x)
lnk = CLinker().accept(Env([x, x], [e])) lnk = CLinker().accept(FunctionGraph([x, x], [e]))
fn = lnk.make_function() fn = lnk.make_function()
assert fn(2.0, 2.0) == 4 assert fn(2.0, 2.0) == 4
# note: for now the behavior of fn(2.0, 7.0) is undefined # note: for now the behavior of fn(2.0, 7.0) is undefined
...@@ -278,7 +281,7 @@ def test_clinker_not_used_inputs(): ...@@ -278,7 +281,7 @@ def test_clinker_not_used_inputs():
# Testing that unused inputs are allowed. # Testing that unused inputs are allowed.
x, y, z = inputs() x, y, z = inputs()
e = add(x, y) e = add(x, y)
lnk = CLinker().accept(Env([x, y, z], [e])) lnk = CLinker().accept(FunctionGraph([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
assert fn(2.0, 1.5, 1.0) == 3.5 assert fn(2.0, 1.5, 1.0) == 3.5
...@@ -290,20 +293,16 @@ def test_clinker_dups_inner(): ...@@ -290,20 +293,16 @@ def test_clinker_dups_inner():
# Testing that duplicates are allowed inside the graph # Testing that duplicates are allowed inside the graph
x, y, z = inputs() x, y, z = inputs()
e = add(mul(y, y), add(x, z)) e = add(mul(y, y), add(x, z))
lnk = CLinker().accept(Env([x, y, z], [e])) lnk = CLinker().accept(FunctionGraph([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
assert fn(1.0, 2.0, 3.0) == 8.0 assert fn(1.0, 2.0, 3.0) == 8.0
######################
# Test OpWiseCLinker #
######################
# slow on linux, but near sole test and very central # slow on linux, but near sole test and very central
def test_opwiseclinker_straightforward(): def test_opwiseclinker_straightforward():
x, y, z = inputs() x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), bad_sub(bad_sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), bad_sub(bad_sub(x, y), z))
lnk = OpWiseCLinker().accept(Env([x, y, z], [e])) lnk = OpWiseCLinker().accept(FunctionGraph([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
if aesara.config.cxx: if aesara.config.cxx:
assert fn(2.0, 2.0, 2.0) == 2.0 assert fn(2.0, 2.0, 2.0) == 2.0
...@@ -316,7 +315,7 @@ def test_opwiseclinker_constant(): ...@@ -316,7 +315,7 @@ def test_opwiseclinker_constant():
x, y, z = inputs() x, y, z = inputs()
x = Constant(tdouble, 7.2, name="x") x = Constant(tdouble, 7.2, name="x")
e = add(mul(x, y), mul(y, z)) e = add(mul(x, y), mul(y, z))
lnk = OpWiseCLinker().accept(Env([y, z], [e])) lnk = OpWiseCLinker().accept(FunctionGraph([y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
res = fn(1.5, 3.0) res = fn(1.5, 3.0)
assert res == 15.3 assert res == 15.3
...@@ -331,15 +330,10 @@ def _my_checker(x, y): ...@@ -331,15 +330,10 @@ def _my_checker(x, y):
raise MyExc("Output mismatch.", {"performlinker": x[0], "clinker": y[0]}) raise MyExc("Output mismatch.", {"performlinker": x[0], "clinker": y[0]})
###################
# Test DualLinker #
###################
def test_duallinker_straightforward(): def test_duallinker_straightforward():
x, y, z = inputs() x, y, z = inputs()
e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python
lnk = DualLinker(checker=_my_checker).accept(Env([x, y, z], [e])) lnk = DualLinker(checker=_my_checker).accept(FunctionGraph([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
res = fn(7.2, 1.5, 3.0) res = fn(7.2, 1.5, 3.0)
assert res == 15.3 assert res == 15.3
...@@ -352,7 +346,7 @@ def test_duallinker_mismatch(): ...@@ -352,7 +346,7 @@ def test_duallinker_mismatch():
x, y, z = inputs() x, y, z = inputs()
# bad_sub is correct in C but erroneous in Python # bad_sub is correct in C but erroneous in Python
e = bad_sub(mul(x, y), mul(y, z)) e = bad_sub(mul(x, y), mul(y, z))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
lnk = DualLinker(checker=_my_checker).accept(g) lnk = DualLinker(checker=_my_checker).accept(g)
fn = lnk.make_function() fn = lnk.make_function()
...@@ -371,11 +365,6 @@ def test_duallinker_mismatch(): ...@@ -371,11 +365,6 @@ def test_duallinker_mismatch():
fn(1.0, 2.0, 3.0) fn(1.0, 2.0, 3.0)
################################
# Test that failure code works #
################################
class AddFail(Binary): class AddFail(Binary):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, y = inp x, y = inp
...@@ -399,7 +388,7 @@ def test_c_fail_error(): ...@@ -399,7 +388,7 @@ def test_c_fail_error():
x, y, z = inputs() x, y, z = inputs()
x = Constant(tdouble, 7.2, name="x") x = Constant(tdouble, 7.2, name="x")
e = add_fail(mul(x, y), mul(y, z)) e = add_fail(mul(x, y), mul(y, z))
lnk = OpWiseCLinker().accept(Env([y, z], [e])) lnk = OpWiseCLinker().accept(FunctionGraph([y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
fn(1.5, 3.0) fn(1.5, 3.0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论