提交 487ce550 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Simplify Elemwise and Composite tests in tests.link.test_numba

上级 68a8e224
......@@ -258,23 +258,27 @@ def test_create_numba_signature(v, expected, force_scalar):
],
)
def test_Elemwise(inputs, input_vals, output_fn):
out_fg = FunctionGraph(inputs, [output_fn(*inputs)])
out_fg = FunctionGraph(outputs=[output_fn(*inputs)])
compare_numba_and_py(out_fg, input_vals)
@pytest.mark.parametrize(
"inputs, input_values",
"inputs, input_values, scalar_fn",
[
(
[aet.scalar("x"), aet.scalar("y")],
[np.array(10).astype(config.floatX), np.array(20).astype(config.floatX)],
[
np.array(10, dtype=config.floatX),
np.array(20, dtype=config.floatX),
],
lambda x, y: x + y * 2 + aes.exp(x - y),
),
],
)
def test_numba_Composite(inputs, input_values):
def test_numba_Composite(inputs, input_values, scalar_fn):
x_s = aes.float64("x")
y_s = aes.float64("y")
comp_op = Elemwise(Composite([x_s, y_s], [x_s + y_s * 2 + aes.exp(x_s - y_s)]))
comp_op = Elemwise(Composite([x_s, y_s], [scalar_fn(x_s, y_s)]))
out_fg = FunctionGraph(inputs, [comp_op(*inputs)])
compare_numba_and_py(out_fg, input_values)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论