提交 980c4c2c 作者: Brandon T. Willard 提交者: Brandon T. Willard

Use params in Numba's default Op.perform implementation

上级 9a1ab3e8
...@@ -16,7 +16,7 @@ from numba.extending import box ...@@ -16,7 +16,7 @@ from numba.extending import box
from aesara import config from aesara import config
from aesara.compile.ops import DeepCopyOp from aesara.compile.ops import DeepCopyOp
from aesara.graph.basic import Apply from aesara.graph.basic import Apply, NoParams
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.ifelse import IfElse from aesara.ifelse import IfElse
...@@ -330,16 +330,41 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs): ...@@ -330,16 +330,41 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
else: else:
ret_sig = get_numba_type(node.outputs[0].type) ret_sig = get_numba_type(node.outputs[0].type)
@numba_njit output_types = tuple(out.type for out in node.outputs)
def perform(*inputs): params = node.run_params()
with numba.objmode(ret=ret_sig):
if params is not NoParams:
params_val = dict(node.params_type.filter(params))
def py_perform(inputs):
outputs = [[None] for i in range(n_outputs)]
op.perform(node, inputs, outputs, params_val)
return outputs
else:
def py_perform(inputs):
outputs = [[None] for i in range(n_outputs)] outputs = [[None] for i in range(n_outputs)]
op.perform(node, inputs, outputs) op.perform(node, inputs, outputs)
outputs = tuple([o[0] for o in outputs]) return outputs
if n_outputs == 1: if n_outputs == 1:
ret = outputs[0]
def py_perform_return(inputs):
return output_types[0].filter(py_perform(inputs)[0][0])
else: else:
ret = outputs
def py_perform_return(inputs):
return tuple(
out_type.filter(out[0])
for out_type, out in zip(output_types, py_perform(inputs))
)
@numba_njit
def perform(*inputs):
with numba.objmode(ret=ret_sig):
ret = py_perform_return(inputs)
return ret return ret
return perform return perform
......
...@@ -30,6 +30,7 @@ from aesara.ifelse import ifelse ...@@ -30,6 +30,7 @@ from aesara.ifelse import ifelse
from aesara.link.numba.dispatch import basic as numba_basic from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch import numba_typify from aesara.link.numba.dispatch import numba_typify
from aesara.link.numba.linker import NumbaLinker from aesara.link.numba.linker import NumbaLinker
from aesara.raise_op import assert_op
from aesara.scalar.basic import Composite from aesara.scalar.basic import Composite
from aesara.scan.basic import scan from aesara.scan.basic import scan
from aesara.scan.utils import until from aesara.scan.utils import until
...@@ -1396,6 +1397,44 @@ def test_perform(inputs, op, exc): ...@@ -1396,6 +1397,44 @@ def test_perform(inputs, op, exc):
) )
def test_perform_params():
    """Exercise an `Op.perform` implementation that requires the `params` argument.

    A vector variable is passed through `assert_op` together with
    ``np.array(True)``, and the resulting graph is handed to
    `compare_numba_and_py`.  The call is expected to emit a `UserWarning`
    whose message matches ``".*object mode.*"``.
    """
    vec = at.vector()
    vec.tag.test_value = np.array([1.0, 2.0], dtype=config.floatX)

    asserted = assert_op(vec, np.array(True))
    # Normalize to a sequence of outputs before building the graph.
    graph_outputs = asserted if isinstance(asserted, (list, tuple)) else [asserted]
    fgraph = FunctionGraph([vec], graph_outputs)

    with pytest.warns(UserWarning, match=".*object mode.*"):
        compare_numba_and_py(
            fgraph,
            [get_test_value(inp) for inp in fgraph.inputs],
        )
def test_perform_type_convert():
    """Exercise the use of `Type.filter` inside `objmode`.

    The `Op.perform` under test receives a single input and returns it
    unchanged; here that input is a native scalar, while the expected output
    is an `np.ndarray`.  The graph is built from ``assert_op(x.sum(), ...)``
    and handed to `compare_numba_and_py`, which is expected to emit a
    `UserWarning` whose message matches ``".*object mode.*"``.
    """
    vec = at.vector()
    vec.tag.test_value = np.array([1.0, 2.0], dtype=config.floatX)

    asserted = assert_op(vec.sum(), np.array(True))
    # Normalize to a sequence of outputs before building the graph.
    graph_outputs = asserted if isinstance(asserted, (list, tuple)) else [asserted]
    fgraph = FunctionGraph([vec], graph_outputs)

    with pytest.warns(UserWarning, match=".*object mode.*"):
        compare_numba_and_py(
            fgraph,
            [get_test_value(inp) for inp in fgraph.inputs],
        )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"val", "val",
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论