提交 adc9ce96 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Added test for Composite in link.test_numba

上级 0b3e6296
......@@ -3,6 +3,7 @@ from functools import partial
import numpy as np
import pytest
import aesara.scalar as aes
import aesara.tensor as aet
from aesara import config
from aesara.compile.function import function
......@@ -11,7 +12,10 @@ from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.fg import FunctionGraph
from aesara.graph.optdb import Query
from aesara.link.numba.linker import NumbaLinker
from aesara.scalar.basic import Composite
from aesara.tensor import subtensor as aet_subtensor
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.type import scalar
opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
......@@ -79,6 +83,23 @@ def test_Elemwise(inputs, input_vals, output_fn):
compare_numba_and_py(out_fg, input_vals)
@pytest.mark.parametrize(
"inputs, input_values",
[
(
[scalar("x"), scalar("y")],
[np.array(10).astype(config.floatX), np.array(20).astype(config.floatX)],
),
],
)
def test_numba_Composite(inputs, input_values):
x_s = aes.float64("x")
y_s = aes.float64("y")
comp_op = Elemwise(Composite([x_s, y_s], [x_s + y_s * 2 + aes.exp(x_s - y_s)]))
out_fg = FunctionGraph(inputs, [comp_op(*inputs)])
compare_numba_and_py(out_fg, input_values)
@pytest.mark.parametrize(
"x, indices",
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论