提交 21218d77 authored 作者: PabloRoque's avatar PabloRoque 提交者: Ricardo Vieira

Clean test_Elemwise

上级 7efd1c59
......@@ -31,55 +31,47 @@ rng = np.random.default_rng(42849)
@pytest.mark.parametrize(
"inputs, input_vals, output_fn, exc",
"inputs, input_vals, output_fn",
[
(
[pt.vector()],
[rng.uniform(size=100).astype(config.floatX)],
lambda x: pt.gammaln(x),
None,
),
(
[pt.vector()],
[rng.standard_normal(100).astype(config.floatX)],
lambda x: pt.sigmoid(x),
None,
),
(
[pt.vector()],
[rng.standard_normal(100).astype(config.floatX)],
lambda x: pt.log1mexp(x),
None,
),
(
[pt.vector()],
[rng.standard_normal(100).astype(config.floatX)],
lambda x: pt.erf(x),
None,
),
(
[pt.vector()],
[rng.standard_normal(100).astype(config.floatX)],
lambda x: pt.erfc(x),
None,
),
(
[pt.vector()],
[rng.standard_normal(100).astype(config.floatX)],
lambda x: pt.erfcx(x),
None,
),
(
[pt.vector() for i in range(4)],
[rng.standard_normal(100).astype(config.floatX) for i in range(4)],
lambda x, y, x1, y1: (x + y) * (x1 + y1) * y,
None,
),
(
[pt.matrix(), pt.scalar()],
[rng.normal(size=(2, 2)).astype(config.floatX), 0.0],
lambda a, b: pt.switch(a, b, a),
None,
),
(
[pt.scalar(), pt.scalar()],
......@@ -88,7 +80,6 @@ rng = np.random.default_rng(42849)
np.array(1.0, dtype=config.floatX),
],
lambda x, y: pti.add_inplace(deep_copy_op(x), deep_copy_op(y)),
None,
),
(
[pt.vector(), pt.vector()],
......@@ -97,7 +88,6 @@ rng = np.random.default_rng(42849)
rng.standard_normal(100).astype(config.floatX),
],
lambda x, y: pti.add_inplace(deep_copy_op(x), deep_copy_op(y)),
None,
),
(
[pt.vector(), pt.vector()],
......@@ -106,15 +96,25 @@ rng = np.random.default_rng(42849)
rng.standard_normal(100).astype(config.floatX),
],
lambda x, y: scalar_my_multi_out(x, y),
None,
),
],
ids=[
"gammaln",
"sigmoid",
"log1mexp",
"erf",
"erfc",
"erfcx",
"complex_arithmetic",
"switch",
"add_inplace_scalar",
"add_inplace_vector",
"scalar_multi_out",
],
)
def test_Elemwise(inputs, input_vals, output_fn, exc):
def test_Elemwise(inputs, input_vals, output_fn):
outputs = output_fn(*inputs)
cm = contextlib.suppress() if exc is None else pytest.raises(exc)
with cm:
compare_numba_and_py(
inputs,
outputs,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论