提交 fe9f2580 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Remove Composite test and create a parameterized Elemwise test

上级 7b311b65
...@@ -3,9 +3,8 @@ from functools import partial ...@@ -3,9 +3,8 @@ from functools import partial
import numpy as np import numpy as np
import pytest import pytest
import aesara
import aesara.scalar.basic as aes
import aesara.tensor as aet import aesara.tensor as aet
from aesara import config
from aesara.compile.function import function from aesara.compile.function import function
from aesara.compile.mode import Mode from aesara.compile.mode import Mode
from aesara.compile.sharedvalue import SharedVariable from aesara.compile.sharedvalue import SharedVariable
...@@ -46,7 +45,11 @@ def compare_numba_and_py( ...@@ -46,7 +45,11 @@ def compare_numba_and_py(
assert_fn = partial(np.testing.assert_allclose, rtol=1e-4) assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)
fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)]
aesara_numba_fn = function(fn_inputs, fgraph.outputs, mode=numba_mode) aesara_numba_fn = function(
fn_inputs,
fgraph.outputs,
mode=numba_mode,
)
numba_res = aesara_numba_fn(*inputs) numba_res = aesara_numba_fn(*inputs)
aesara_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode) aesara_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
...@@ -61,53 +64,19 @@ def compare_numba_and_py( ...@@ -61,53 +64,19 @@ def compare_numba_and_py(
return numba_res return numba_res
def test_Composite(): @pytest.mark.parametrize(
opts = Query(include=["fusion"], exclude=["cxx_only", "BlasOpt"]) "inputs, input_vals, output_fn",
numba_mode = Mode(NumbaLinker(), opts) [
py_mode = Mode("py", opts) (
[aet.vector() for i in range(4)],
y = aet.vector("y") [np.random.randn(100).astype(config.floatX) for i in range(4)],
x = aet.vector("x") lambda x, y, x1, y1: (x + y) * (x1 + y1) * y,
)
z = (x + y) * (x + y) * y ],
)
func = aesara.function([x, y], [z], mode=py_mode) def test_Elemwise(inputs, input_vals, output_fn):
numba_fn = aesara.function([x, y], [z], mode=numba_mode) out_fg = FunctionGraph(inputs, [output_fn(*inputs)])
compare_numba_and_py(out_fg, input_vals)
# Make sure the graph had a `Composite` `Op` in it
composite_op = numba_fn.maker.fgraph.outputs[0].owner.op.scalar_op
assert isinstance(composite_op, aes.Composite)
x_val = np.random.randn(1000)
y_val = np.random.randn(1000)
res = func(x_val, y_val) # Answer from python mode compilation of FunctionGraph
numba_res = numba_fn(x_val, y_val) # Answer from Numba converted FunctionGraph
assert np.array_equal(res, numba_res)
y1 = aet.vector("y1")
x1 = aet.vector("x1")
z = (x + y) * (x1 + y1) * y
x1_val = np.random.randn(1000)
y1_val = np.random.randn(1000)
func = aesara.function([x, y, x1, y1], [z], mode=py_mode)
numba_fn = aesara.function([x, y, x1, y1], [z], mode=numba_mode)
composite_op = numba_fn.maker.fgraph.outputs[0].owner.op.scalar_op
assert isinstance(composite_op, aes.Composite)
res = func(
x_val, y_val, x1_val, y1_val
) # Answer from python mode compilation of FunctionGraph
numba_res = numba_fn(
x_val, y_val, x1_val, y1_val
) # Answer from Numba converted FunctionGraph
assert np.array_equal(res, numba_res)
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论