提交 d7a8c825 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add Numba conversions for Clip, AllocDiag, ARange, MakeVector

上级 27f1ede5
...@@ -16,10 +16,13 @@ from aesara.compile.ops import DeepCopyOp, ViewOp ...@@ -16,10 +16,13 @@ from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.link.utils import compile_function_src, fgraph_to_python from aesara.link.utils import compile_function_src, fgraph_to_python
from aesara.scalar.basic import Cast, Composite, Identity, ScalarOp, Second from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from aesara.tensor.basic import ( from aesara.tensor.basic import (
Alloc, Alloc,
AllocDiag,
AllocEmpty, AllocEmpty,
ARange,
MakeVector,
Rebroadcast, Rebroadcast,
ScalarFromTensor, ScalarFromTensor,
TensorFromScalar, TensorFromScalar,
...@@ -346,6 +349,17 @@ def numba_funcify_MakeSlice(op, **kwargs): ...@@ -346,6 +349,17 @@ def numba_funcify_MakeSlice(op, **kwargs):
return makeslice return makeslice
@numba_funcify.register(MakeVector)
def numba_funcify_MakeVector(op, **kwargs):
    """Generate a Numba implementation for `MakeVector`.

    The returned jitted function packs its scalar arguments into a 1-D
    array with the dtype declared on the op.
    """
    out_dtype = np.dtype(op.dtype)

    @numba.njit
    def makevector(*args):
        # Each argument is a 0-d array; ``item()`` unwraps it to a scalar
        # so the result is a flat 1-D vector of ``out_dtype``.
        return np.array([a.item() for a in args], dtype=out_dtype)

    return makevector
@numba_funcify.register(Shape) @numba_funcify.register(Shape)
def numba_funcify_Shape(op, **kwargs): def numba_funcify_Shape(op, **kwargs):
@numba.njit @numba.njit
...@@ -445,6 +459,17 @@ def alloc(val, {", ".join(shape_var_names)}): ...@@ -445,6 +459,17 @@ def alloc(val, {", ".join(shape_var_names)}):
return numba.njit(alloc_fn) return numba.njit(alloc_fn)
@numba_funcify.register(AllocDiag)
def numba_funcify_AllocDiag(op, **kwargs):
    """Generate a Numba implementation for `AllocDiag`."""
    diag_offset = op.offset

    @numba.njit
    def allocdiag(v):
        # ``np.diag`` with a 1-D input builds a 2-D array holding ``v``
        # on the ``diag_offset``-th diagonal.
        return np.diag(v, k=diag_offset)

    return allocdiag
@numba_funcify.register(Second) @numba_funcify.register(Second)
def numba_funcify_Second(op, node, **kwargs): def numba_funcify_Second(op, node, **kwargs):
@numba.njit @numba.njit
...@@ -539,13 +564,28 @@ def numba_funcify_Rebroadcast(op, **kwargs): ...@@ -539,13 +564,28 @@ def numba_funcify_Rebroadcast(op, **kwargs):
return rebroadcast return rebroadcast
@numba.extending.intrinsic
def direct_cast(typingctx, val, typ):
    # Numba intrinsic used in place of a real conversion: the codegen
    # below returns ``val`` unchanged (plus an NRT incref), so at the
    # LLVM level this is a pass-through typed as ``typ``'s instance type.
    # NOTE(review): ``typ`` is presumably a type-reference argument
    # (e.g. the value produced by ``from_dtype``) — confirm at call sites.
    casted = typ.instance_type
    # Signature: (casted, typ) -> casted; the input is already typed as
    # the target type, so no cast instructions are emitted.
    sig = casted(casted, typ)

    def codegen(context, builder, signature, args):
        val, _ = args
        # Keep the runtime refcount consistent since we alias the input
        # as the return value.
        context.nrt.incref(builder, signature.return_type, val)
        return val

    return sig, codegen
@numba_funcify.register(Cast) @numba_funcify.register(Cast)
def numba_funcify_Cast(op, **kwargs): def numba_funcify_Cast(op, **kwargs):
dtype = op.o_type.dtype
dtype = np.dtype(op.o_type.dtype)
dtype = numba.np.numpy_support.from_dtype(dtype)
@numba.njit @numba.njit
def cast(x): def cast(x):
return np.array(x, dtype=dtype) return direct_cast(x, dtype)
return cast return cast
...@@ -589,3 +629,29 @@ def numba_funcify_ViewOp(op, **kwargs): ...@@ -589,3 +629,29 @@ def numba_funcify_ViewOp(op, **kwargs):
return x return x
return viewop return viewop
@numba_funcify.register(Clip)
def numba_funcify_Clip(op, **kwargs):
    """Generate a Numba implementation for the scalar `Clip` op."""

    @numba.njit
    def clip(_x, _min, _max):
        # Reduce the 0-d inputs to plain scalars before comparing.
        x = to_scalar(_x)
        lo = to_scalar(_min)
        hi = to_scalar(_max)
        # Clamp ``x`` into [lo, hi]; the inner ``np.where`` result is a
        # 0-d array, so it is unwrapped again before the outer select.
        return np.where(x < lo, lo, to_scalar(np.where(x > hi, hi, x)))

    return clip
@numba_funcify.register(ARange)
def numba_funcify_ARange(op, **kwargs):
    """Generate a Numba implementation for `ARange`."""
    # Resolve the op's dtype string to the corresponding Numba type once,
    # outside the jitted closure.
    np_dtype = np.dtype(op.dtype)
    nb_dtype = numba.np.numpy_support.from_dtype(np_dtype)

    @numba.njit
    def arange(start, stop, step):
        # Inputs arrive as 0-d arrays; unwrap them for ``np.arange``.
        return np.arange(
            to_scalar(start), to_scalar(stop), to_scalar(step), dtype=nb_dtype
        )

    return arange
...@@ -91,9 +91,24 @@ def compare_numba_and_py( ...@@ -91,9 +91,24 @@ def compare_numba_and_py(
l[i] = v l[i] = v
return tuple(l) return tuple(l)
def py_to_scalar(x):
    # Pure-Python stand-in for the jitted ``to_scalar`` helper: unwrap a
    # NumPy array to its scalar value, pass anything else through as-is.
    return x.item() if isinstance(x, np.ndarray) else x
with mock.patch("aesara.link.numba.dispatch.numba.njit", lambda x: x), mock.patch( with mock.patch("aesara.link.numba.dispatch.numba.njit", lambda x: x), mock.patch(
"aesara.link.numba.dispatch.numba.vectorize", lambda x: x "aesara.link.numba.dispatch.numba.vectorize", lambda x: x
), mock.patch("aesara.link.numba.dispatch.tuple_setitem", py_tuple_setitem): ), mock.patch(
"aesara.link.numba.dispatch.tuple_setitem", py_tuple_setitem
), mock.patch(
"aesara.link.numba.dispatch.direct_cast", lambda x, dtype: x
), mock.patch(
"aesara.link.numba.dispatch.numba.np.numpy_support.from_dtype",
lambda dtype: dtype,
), mock.patch(
"aesara.link.numba.dispatch.to_scalar", py_to_scalar
):
aesara_numba_fn = function( aesara_numba_fn = function(
fn_inputs, fn_inputs,
fgraph.outputs, fgraph.outputs,
...@@ -372,6 +387,28 @@ def test_AllocEmpty(): ...@@ -372,6 +387,28 @@ def test_AllocEmpty():
compare_numba_and_py(x_fg, [], assert_fn=compare_shape_dtype) compare_numba_and_py(x_fg, [], assert_fn=compare_shape_dtype)
@pytest.mark.parametrize(
    "v, offset",
    [
        (set_test_value(aet.vector(), np.arange(10, dtype=config.floatX)), 0),
        (set_test_value(aet.vector(), np.arange(10, dtype=config.floatX)), 1),
        (set_test_value(aet.vector(), np.arange(10, dtype=config.floatX)), -1),
    ],
)
def test_AllocDiag(v, offset):
    """Check the Numba `AllocDiag` conversion against the Python backend."""
    out = aetb.AllocDiag(offset=offset)(v)
    fgraph = FunctionGraph(outputs=[out])
    # Feed only the non-shared, non-constant inputs' test values.
    test_inputs = [
        inp.tag.test_value
        for inp in fgraph.inputs
        if not isinstance(inp, (SharedVariable, Constant))
    ]
    compare_numba_and_py(fgraph, test_inputs)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"v, new_order, inplace", "v, new_order, inplace",
[ [
...@@ -633,3 +670,77 @@ def test_Second(x, y): ...@@ -633,3 +670,77 @@ def test_Second(x, y):
if not isinstance(i, (SharedVariable, Constant)) if not isinstance(i, (SharedVariable, Constant))
], ],
) )
@pytest.mark.parametrize(
    "v, min, max",
    [
        (set_test_value(aet.scalar(), np.array(10, dtype=config.floatX)), 3.0, 7.0),
        (set_test_value(aet.scalar(), np.array(1, dtype=config.floatX)), 3.0, 7.0),
        (set_test_value(aet.scalar(), np.array(10, dtype=config.floatX)), 7.0, 3.0),
    ],
)
def test_Clip(v, min, max):
    """Check the Numba `Clip` conversion against the Python backend."""
    out = aes.clip(v, min, max)
    fgraph = FunctionGraph(outputs=[out])
    # Feed only the non-shared, non-constant inputs' test values.
    test_inputs = [
        inp.tag.test_value
        for inp in fgraph.inputs
        if not isinstance(inp, (SharedVariable, Constant))
    ]
    compare_numba_and_py(fgraph, test_inputs)
@pytest.mark.parametrize(
    "vals, dtype",
    [
        (
            (
                set_test_value(aet.scalar(), np.array(1, dtype=config.floatX)),
                set_test_value(aet.scalar(), np.array(2, dtype=config.floatX)),
                set_test_value(aet.scalar(), np.array(3, dtype=config.floatX)),
            ),
            config.floatX,
        ),
    ],
)
def test_MakeVector(vals, dtype):
    """Check the Numba `MakeVector` conversion against the Python backend."""
    out = aetb.MakeVector(dtype)(*vals)
    fgraph = FunctionGraph(outputs=[out])
    # Feed only the non-shared, non-constant inputs' test values.
    test_inputs = [
        inp.tag.test_value
        for inp in fgraph.inputs
        if not isinstance(inp, (SharedVariable, Constant))
    ]
    compare_numba_and_py(fgraph, test_inputs)
@pytest.mark.parametrize(
    "start, stop, step, dtype",
    [
        (
            set_test_value(aet.lscalar(), np.array(1)),
            set_test_value(aet.lscalar(), np.array(10)),
            set_test_value(aet.lscalar(), np.array(3)),
            config.floatX,
        ),
    ],
)
def test_ARange(start, stop, step, dtype):
    """Check the Numba `ARange` conversion against the Python backend."""
    out = aetb.ARange(dtype)(start, stop, step)
    fgraph = FunctionGraph(outputs=[out])
    # Feed only the non-shared, non-constant inputs' test values.
    test_inputs = [
        inp.tag.test_value
        for inp in fgraph.inputs
        if not isinstance(inp, (SharedVariable, Constant))
    ]
    compare_numba_and_py(fgraph, test_inputs)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论