提交 21218d77 authored 作者: PabloRoque's avatar PabloRoque 提交者: Ricardo Vieira

Clean test_Elemwise

上级 7efd1c59
...@@ -31,55 +31,47 @@ rng = np.random.default_rng(42849) ...@@ -31,55 +31,47 @@ rng = np.random.default_rng(42849)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"inputs, input_vals, output_fn, exc", "inputs, input_vals, output_fn",
[ [
( (
[pt.vector()], [pt.vector()],
[rng.uniform(size=100).astype(config.floatX)], [rng.uniform(size=100).astype(config.floatX)],
lambda x: pt.gammaln(x), lambda x: pt.gammaln(x),
None,
), ),
( (
[pt.vector()], [pt.vector()],
[rng.standard_normal(100).astype(config.floatX)], [rng.standard_normal(100).astype(config.floatX)],
lambda x: pt.sigmoid(x), lambda x: pt.sigmoid(x),
None,
), ),
( (
[pt.vector()], [pt.vector()],
[rng.standard_normal(100).astype(config.floatX)], [rng.standard_normal(100).astype(config.floatX)],
lambda x: pt.log1mexp(x), lambda x: pt.log1mexp(x),
None,
), ),
( (
[pt.vector()], [pt.vector()],
[rng.standard_normal(100).astype(config.floatX)], [rng.standard_normal(100).astype(config.floatX)],
lambda x: pt.erf(x), lambda x: pt.erf(x),
None,
), ),
( (
[pt.vector()], [pt.vector()],
[rng.standard_normal(100).astype(config.floatX)], [rng.standard_normal(100).astype(config.floatX)],
lambda x: pt.erfc(x), lambda x: pt.erfc(x),
None,
), ),
( (
[pt.vector()], [pt.vector()],
[rng.standard_normal(100).astype(config.floatX)], [rng.standard_normal(100).astype(config.floatX)],
lambda x: pt.erfcx(x), lambda x: pt.erfcx(x),
None,
), ),
( (
[pt.vector() for i in range(4)], [pt.vector() for i in range(4)],
[rng.standard_normal(100).astype(config.floatX) for i in range(4)], [rng.standard_normal(100).astype(config.floatX) for i in range(4)],
lambda x, y, x1, y1: (x + y) * (x1 + y1) * y, lambda x, y, x1, y1: (x + y) * (x1 + y1) * y,
None,
), ),
( (
[pt.matrix(), pt.scalar()], [pt.matrix(), pt.scalar()],
[rng.normal(size=(2, 2)).astype(config.floatX), 0.0], [rng.normal(size=(2, 2)).astype(config.floatX), 0.0],
lambda a, b: pt.switch(a, b, a), lambda a, b: pt.switch(a, b, a),
None,
), ),
( (
[pt.scalar(), pt.scalar()], [pt.scalar(), pt.scalar()],
...@@ -88,7 +80,6 @@ rng = np.random.default_rng(42849) ...@@ -88,7 +80,6 @@ rng = np.random.default_rng(42849)
np.array(1.0, dtype=config.floatX), np.array(1.0, dtype=config.floatX),
], ],
lambda x, y: pti.add_inplace(deep_copy_op(x), deep_copy_op(y)), lambda x, y: pti.add_inplace(deep_copy_op(x), deep_copy_op(y)),
None,
), ),
( (
[pt.vector(), pt.vector()], [pt.vector(), pt.vector()],
...@@ -97,7 +88,6 @@ rng = np.random.default_rng(42849) ...@@ -97,7 +88,6 @@ rng = np.random.default_rng(42849)
rng.standard_normal(100).astype(config.floatX), rng.standard_normal(100).astype(config.floatX),
], ],
lambda x, y: pti.add_inplace(deep_copy_op(x), deep_copy_op(y)), lambda x, y: pti.add_inplace(deep_copy_op(x), deep_copy_op(y)),
None,
), ),
( (
[pt.vector(), pt.vector()], [pt.vector(), pt.vector()],
...@@ -106,15 +96,25 @@ rng = np.random.default_rng(42849) ...@@ -106,15 +96,25 @@ rng = np.random.default_rng(42849)
rng.standard_normal(100).astype(config.floatX), rng.standard_normal(100).astype(config.floatX),
], ],
lambda x, y: scalar_my_multi_out(x, y), lambda x, y: scalar_my_multi_out(x, y),
None,
), ),
], ],
ids=[
"gammaln",
"sigmoid",
"log1mexp",
"erf",
"erfc",
"erfcx",
"complex_arithmetic",
"switch",
"add_inplace_scalar",
"add_inplace_vector",
"scalar_multi_out",
],
) )
def test_Elemwise(inputs, input_vals, output_fn, exc): def test_Elemwise(inputs, input_vals, output_fn):
outputs = output_fn(*inputs) outputs = output_fn(*inputs)
cm = contextlib.suppress() if exc is None else pytest.raises(exc)
with cm:
compare_numba_and_py( compare_numba_and_py(
inputs, inputs,
outputs, outputs,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论