提交 adc9ce96 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Added test for Composite in link.test_numba

上级 0b3e6296
...@@ -3,6 +3,7 @@ from functools import partial ...@@ -3,6 +3,7 @@ from functools import partial
import numpy as np import numpy as np
import pytest import pytest
import aesara.scalar as aes
import aesara.tensor as aet import aesara.tensor as aet
from aesara import config from aesara import config
from aesara.compile.function import function from aesara.compile.function import function
...@@ -11,7 +12,10 @@ from aesara.compile.sharedvalue import SharedVariable ...@@ -11,7 +12,10 @@ from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.optdb import Query from aesara.graph.optdb import Query
from aesara.link.numba.linker import NumbaLinker from aesara.link.numba.linker import NumbaLinker
from aesara.scalar.basic import Composite
from aesara.tensor import subtensor as aet_subtensor from aesara.tensor import subtensor as aet_subtensor
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.type import scalar
opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"]) opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
...@@ -79,6 +83,23 @@ def test_Elemwise(inputs, input_vals, output_fn): ...@@ -79,6 +83,23 @@ def test_Elemwise(inputs, input_vals, output_fn):
compare_numba_and_py(out_fg, input_vals) compare_numba_and_py(out_fg, input_vals)
@pytest.mark.parametrize(
"inputs, input_values",
[
(
[scalar("x"), scalar("y")],
[np.array(10).astype(config.floatX), np.array(20).astype(config.floatX)],
),
],
)
def test_numba_Composite(inputs, input_values):
x_s = aes.float64("x")
y_s = aes.float64("y")
comp_op = Elemwise(Composite([x_s, y_s], [x_s + y_s * 2 + aes.exp(x_s - y_s)]))
out_fg = FunctionGraph(inputs, [comp_op(*inputs)])
compare_numba_and_py(out_fg, input_values)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, indices", "x, indices",
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论