提交 25162ed0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add Numba conversions for Ops in aesara.tensor.extra_ops

上级 f05a185b
......@@ -12,6 +12,7 @@ from numba import types
from numba.core.errors import TypingError
from numba.cpython.unsafe.tuple import tuple_setitem
from numba.extending import box
from numpy.core.multiarray import normalize_axis_index
from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.graph.basic import Apply
......@@ -45,6 +46,19 @@ from aesara.tensor.basic import (
TensorFromScalar,
)
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.extra_ops import (
Bartlett,
BroadcastTo,
CumOp,
DiffOp,
FillDiagonal,
FillDiagonalOffset,
RavelMultiIndex,
Repeat,
SearchsortedOp,
Unique,
UnravelIndex,
)
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
......@@ -968,3 +982,355 @@ def numba_funcify_Eye(op, **kwargs):
return np.eye(to_scalar(N), to_scalar(M), to_scalar(k), dtype=dtype)
return eye
@numba_funcify.register(Bartlett)
def numba_funcify_Bartlett(op, **kwargs):
    """Create a Numba-compiled implementation of the `Bartlett` window `Op`."""

    @numba.njit
    def bartlett(window_size):
        # `window_size` arrives as a 0-d array; `to_scalar` extracts the value.
        return np.bartlett(to_scalar(window_size))

    return bartlett
@numba_funcify.register(CumOp)
def numba_funcify_CumOp(op, node, **kwargs):
    """Create a Numba implementation of `CumOp` (cumulative sum/product).

    The accumulation is performed along `op.axis` by moving that axis to the
    front, accumulating slice by slice, then restoring the original axis
    order on the result.
    """
    axis = op.axis
    mode = op.mode
    ndim = node.outputs[0].ndim

    # Permutation that moves `axis` to the front of the array...
    reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
    # ...and its inverse, needed to undo that permutation on the result.
    # BUG FIX: the original code transposed the result with `reaxis_first`
    # itself, which is only correct when the permutation is self-inverse
    # (e.g. `axis` is 0 or 1); for `axis >= 2` it scrambled the output axes.
    reaxis_first_inv = tuple(reaxis_first.index(i) for i in range(ndim))

    if mode == "add":
        np_func = np.add
        identity = 0
    else:
        np_func = np.multiply
        identity = 1

    @numba.njit(boundscheck=False)
    def cumop(x):
        out_dtype = x.dtype
        if x.shape[axis] < 2:
            # Nothing to accumulate; return a copy with the output dtype.
            return x.astype(out_dtype)

        x_axis_first = x.transpose(reaxis_first)
        res = np.empty(x_axis_first.shape, dtype=out_dtype)

        for m in range(x.shape[axis]):
            if m == 0:
                # Seed the accumulation with the operation's identity element.
                np_func(identity, x_axis_first[m], res[m])
            else:
                np_func(res[m - 1], x_axis_first[m], res[m])

        return res.transpose(reaxis_first_inv)

    return cumop
@numba_funcify.register(DiffOp)
def numba_funcify_DiffOp(op, node, **kwargs):
    """Create a Numba implementation of `DiffOp` (n-th discrete difference)."""
    n = op.n
    ndim = node.inputs[0].ndim
    dtype = node.outputs[0].dtype
    axis = normalize_axis_index(op.axis, ndim)

    # Static slice tuples selecting `x[1:]` and `x[:-1]` along `axis`;
    # Numba closes over these as compile-time constants.
    slice1 = tuple(
        slice(1, None) if i == axis else slice(None) for i in range(ndim)
    )
    slice2 = tuple(
        slice(None, -1) if i == axis else slice(None) for i in range(ndim)
    )

    # Boolean inputs are differenced via inequality, mirroring `numpy.diff`.
    diff_fn = np.not_equal if dtype == "bool" else np.subtract

    @numba.njit(boundscheck=False)
    def diffop(x):
        res = x.copy()
        for _ in range(n):
            res = diff_fn(res[slice1], res[slice2])
        return res

    return diffop
@numba_funcify.register(FillDiagonal)
def numba_funcify_FillDiagonal(op, **kwargs):
    # Numba supports `np.fill_diagonal` natively; it mutates `a` in place and
    # the mutated array is returned as the `Op`'s output.
    @numba.njit
    def filldiagonal(a, val):
        np.fill_diagonal(a, val)
        return a

    return filldiagonal
@numba_funcify.register(FillDiagonalOffset)
def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
    # Fill the `offset` diagonal of a 2-D array `a` with `val` by computing
    # the corresponding strided slice into the flattened array.
    @numba.njit
    def filldiagonaloffset(a, val, offset):
        height, width = a.shape

        if offset >= 0:
            # Non-negative offsets start `offset` columns into the first row.
            start = to_scalar(offset)
            num_of_step = min(min(width, height), width - offset)
        else:
            # Negative offsets start `-offset` rows down, i.e. that many full
            # rows into the flattened array.
            start = -to_scalar(offset) * a.shape[1]
            num_of_step = min(min(width, height), height + offset)

        # A stride of `width + 1` in the flattened array moves one step along
        # the diagonal.
        step = a.shape[1] + 1
        end = start + step * num_of_step

        b = a.ravel()
        b[start:end:step] = val
        # TODO: This isn't implemented in Numba
        # a.flat[start:end:step] = val
        # return a
        return b.reshape(a.shape)

    return filldiagonaloffset
@numba_funcify.register(RavelMultiIndex)
def numba_funcify_RavelMultiIndex(op, node, **kwargs):
    # Numba implementation of `np.ravel_multi_index`, C order only.
    mode = op.mode
    order = op.order

    if order != "C":
        raise NotImplementedError(
            "Numba does not implement `order` in `numpy.ravel_multi_index`"
        )

    # `mode_fn` handles an out-of-bounds coordinate `v` for a dimension of
    # size `d` at position `(i, j)` of the coordinate matrix.  The "raise"
    # variant takes `*args` so it can be called with the same arguments.
    if mode == "raise":

        @numba.njit
        def mode_fn(*args):
            raise ValueError("invalid entry in coordinates array")

    elif mode == "wrap":

        @numba.njit(inline="always")
        def mode_fn(new_arr, i, j, v, d):
            # Wrap the coordinate modulo the dimension size.
            new_arr[i, j] = v % d

    elif mode == "clip":

        @numba.njit(inline="always")
        def mode_fn(new_arr, i, j, v, d):
            # Clamp the coordinate into `[0, d - 1]`.
            new_arr[i, j] = min(max(v, 0), d - 1)

    if node.inputs[0].ndim == 0:
        # Scalar-coordinate case: each input is a 0-d index for one dimension.

        @numba.njit
        def ravelmultiindex(*inp):
            shape = inp[-1]
            arr = np.stack(inp[:-1])

            new_arr = arr.T.astype(np.float64).copy()
            for i, b in enumerate(new_arr):
                if b < 0 or b >= shape[i]:
                    mode_fn(new_arr, i, 0, b, shape[i])

            # Row-major strides: cumulative products of the trailing dims.
            a = np.ones(len(shape), dtype=np.float64)
            a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]
            return np.array(a.dot(new_arr.T), dtype=np.int64)

    else:

        @numba.njit
        def ravelmultiindex(*inp):
            shape = inp[-1]
            arr = np.stack(inp[:-1])

            new_arr = arr.T.astype(np.float64).copy()
            for i, b in enumerate(new_arr):
                for j, (d, v) in enumerate(zip(shape, b)):
                    if v < 0 or v >= d:
                        mode_fn(new_arr, i, j, v, d)

            # Row-major strides: cumulative products of the trailing dims.
            a = np.ones(len(shape), dtype=np.float64)
            a[: len(shape) - 1] = np.cumprod(shape[-1:0:-1])[::-1]
            return a.dot(new_arr.T).astype(np.int64)

    return ravelmultiindex
@numba_funcify.register(Repeat)
def numba_funcify_Repeat(op, node, **kwargs):
    """Create a Numba implementation of `Repeat`.

    When `axis` is given, Numba's `np.repeat` cannot be used, so the call is
    made through `numba.objmode` instead (with a warning).
    """
    axis = op.axis

    if axis is not None:
        warnings.warn(
            (
                "Numba will use object mode to allow the "
                "`axis` argument to `numpy.repeat`."
            ),
            UserWarning,
        )

        ret_sig = get_numba_type(node.outputs[0].type)

        @numba.njit
        def repeatop(x, repeats):
            with numba.objmode(ret=ret_sig):
                ret = np.repeat(x, repeats, axis)
            return ret

    elif node.inputs[1].ndim == 0:
        # Scalar `repeats`: extract the value so Numba's `np.repeat` accepts it.

        @numba.njit
        def repeatop(x, repeats):
            return np.repeat(x, repeats.item())

    else:

        @numba.njit
        def repeatop(x, repeats):
            return np.repeat(x, repeats)

    return repeatop
@numba_funcify.register(Unique)
def numba_funcify_Unique(op, node, **kwargs):
    """Create a Numba implementation of `Unique`.

    The plain single-output form runs natively; `axis` and the `return_*`
    variants fall back to `numba.objmode` (with a warning).
    """
    axis = op.axis
    return_index = op.return_index
    return_inverse = op.return_inverse
    return_counts = op.return_counts

    returns_multi = return_index or return_inverse or return_counts
    use_python = returns_multi or axis is not None

    if not use_python:

        @numba.njit
        def unique(x):
            return np.unique(x)

    else:
        warnings.warn(
            (
                "Numba will use object mode to allow the "
                "`axis` and/or `return_*` arguments to `numpy.unique`."
            ),
            UserWarning,
        )

        # Multiple outputs require a tuple return signature for `objmode`.
        if returns_multi:
            ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
        else:
            ret_sig = get_numba_type(node.outputs[0].type)

        @numba.njit
        def unique(x):
            with numba.objmode(ret=ret_sig):
                ret = np.unique(x, return_index, return_inverse, return_counts, axis)
            return ret

    return unique
@numba_funcify.register(UnravelIndex)
def numba_funcify_UnravelIndex(op, node, **kwargs):
    """Create a Numba implementation of `UnravelIndex` (C order only)."""
    if op.order != "C":
        raise NotImplementedError(
            "Numba does not support the `order` argument in `numpy.unravel_index`"
        )

    # With several outputs, a broadcast dimension is inserted so each row of
    # the result corresponds to one output dimension.
    if len(node.outputs) == 1:

        @numba.njit(inline="always")
        def maybe_expand_dim(arr):
            return arr

    else:

        @numba.njit(inline="always")
        def maybe_expand_dim(arr):
            return np.expand_dims(arr, 1)

    @numba.njit
    def unravelindex(arr, shape):
        # Row-major strides: cumulative products of the reversed trailing dims.
        strides = np.ones(len(shape), dtype=np.int64)
        strides[1:] = shape[:0:-1]
        strides = np.cumprod(strides)[::-1]

        # Aesara actually returns a `tuple` of these values, instead of an
        # `ndarray`; however, this `ndarray` result should be able to be
        # unpacked into a `tuple`, so this discrepancy shouldn't really matter
        return ((maybe_expand_dim(arr) // strides) % shape).T

    return unravelindex
@numba_funcify.register(SearchsortedOp)
def numba_funcify_Searchsorted(op, node, **kwargs):
    """Create a Numba implementation of `SearchsortedOp`.

    A `sorter` input (third argument) is not supported by Numba's
    `np.searchsorted`, so that case runs through `numba.objmode`.
    """
    side = op.side

    if len(node.inputs) == 3:
        warnings.warn(
            (
                "Numba will use object mode to allow the "
                "`sorter` argument to `numpy.searchsorted`."
            ),
            UserWarning,
        )

        ret_sig = get_numba_type(node.outputs[0].type)

        @numba.njit
        def searchsorted(a, v, sorter):
            with numba.objmode(ret=ret_sig):
                ret = np.searchsorted(a, v, side, sorter)
            return ret

    else:

        @numba.njit
        def searchsorted(a, v):
            return np.searchsorted(a, v, side)

    return searchsorted
@numba_funcify.register(BroadcastTo)
def numba_funcify_BroadcastTo(op, node, **kwargs):
    """Create an object-mode Numba implementation of `BroadcastTo`."""
    # `np.broadcast_to` is not available in nopython mode, so always warn and
    # dispatch through `numba.objmode`.
    warnings.warn(
        "Numba will use object mode to allow the use of `numpy.broadcast_to`.",
        UserWarning,
    )

    ret_sig = get_numba_type(node.outputs[0].type)

    @numba.njit
    def broadcastto(x, *shape):
        with numba.objmode(ret=ret_sig):
            ret = np.broadcast_to(x, shape)
        return ret

    return broadcastto
......@@ -24,6 +24,7 @@ from aesara.link.numba.dispatch import create_numba_signature, get_numba_type
from aesara.link.numba.linker import NumbaLinker
from aesara.scalar.basic import Composite
from aesara.tensor import elemwise as aet_elemwise
from aesara.tensor import extra_ops
from aesara.tensor import subtensor as aet_subtensor
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
......@@ -1148,3 +1149,521 @@ def test_perform(inputs, op, exc):
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
    "val",
    [
        set_test_value(aet.lscalar(), np.array(6, dtype="int64")),
    ],
)
def test_Bartlett(val):
    """Check the Numba `bartlett` against the Python reference backend."""
    out = extra_ops.bartlett(val)
    fgraph = FunctionGraph(outputs=[out])
    test_inputs = [
        inp.tag.test_value
        for inp in fgraph.inputs
        if not isinstance(inp, (SharedVariable, Constant))
    ]
    compare_numba_and_py(fgraph, test_inputs)
@pytest.mark.parametrize(
    "val, axis, mode",
    [
        (
            set_test_value(
                aet.matrix(), np.arange(3, dtype=config.floatX).reshape((3, 1))
            ),
            1,
            "add",
        ),
        (
            set_test_value(
                aet.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
            ),
            0,
            "add",
        ),
        (
            set_test_value(
                aet.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
            ),
            1,
            "add",
        ),
        (
            set_test_value(
                aet.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
            ),
            0,
            "mul",
        ),
        (
            set_test_value(
                aet.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
            ),
            1,
            "mul",
        ),
    ],
)
def test_CumOp(val, axis, mode):
    """Compare the Numba `CumOp` (cumsum/cumprod) against the Python backend."""
    g = extra_ops.CumOp(axis=axis, mode=mode)(val)
    g_fg = FunctionGraph(outputs=[g])

    compare_numba_and_py(
        g_fg,
        [
            i.tag.test_value
            for i in g_fg.inputs
            if not isinstance(i, (SharedVariable, Constant))
        ],
    )
@pytest.mark.parametrize(
    "val, n, axis",
    [
        (
            set_test_value(
                aet.matrix(), np.random.normal(size=(3, 2)).astype(config.floatX)
            ),
            0,
            0,
        ),
        (
            set_test_value(
                aet.matrix(), np.random.normal(size=(3, 2)).astype(config.floatX)
            ),
            0,
            1,
        ),
        (
            set_test_value(
                aet.matrix(), np.random.normal(size=(3, 2)).astype(config.floatX)
            ),
            1,
            0,
        ),
        (
            set_test_value(
                aet.matrix(), np.random.normal(size=(3, 2)).astype(config.floatX)
            ),
            1,
            1,
        ),
        (
            set_test_value(aet.lmatrix(), np.random.poisson(size=(3, 2))),
            0,
            0,
        ),
    ],
)
def test_DiffOp(val, axis, n):
    """Compare the Numba `DiffOp` against the Python backend.

    NOTE(review): the signature order (`val, axis, n`) differs from the
    `parametrize` string (`val, n, axis`), but pytest maps parametrized
    values to parameters by name, so this is harmless.
    """
    g = extra_ops.DiffOp(n=n, axis=axis)(val)
    g_fg = FunctionGraph(outputs=[g])

    # The single-element unpacking also asserts exactly one result is
    # returned; `res` is otherwise unused.
    (res,) = compare_numba_and_py(
        g_fg,
        [
            i.tag.test_value
            for i in g_fg.inputs
            if not isinstance(i, (SharedVariable, Constant))
        ],
    )
@pytest.mark.parametrize(
    "a, val",
    [
        (
            set_test_value(aet.lmatrix(), np.zeros((10, 2), dtype="int64")),
            set_test_value(aet.lscalar(), np.array(1, dtype="int64")),
        )
    ],
)
def test_FillDiagonal(a, val):
    """Compare the Numba `FillDiagonal` against the Python backend."""
    out = extra_ops.FillDiagonal()(a, val)
    fgraph = FunctionGraph(outputs=[out])
    test_inputs = [
        inp.tag.test_value
        for inp in fgraph.inputs
        if not isinstance(inp, (SharedVariable, Constant))
    ]
    compare_numba_and_py(fgraph, test_inputs)
@pytest.mark.parametrize(
    "a, val, offset",
    [
        (
            set_test_value(aet.lmatrix(), np.zeros((10, 2), dtype="int64")),
            set_test_value(aet.lscalar(), np.array(1, dtype="int64")),
            set_test_value(aet.lscalar(), np.array(-1, dtype="int64")),
        ),
        (
            set_test_value(aet.lmatrix(), np.zeros((10, 2), dtype="int64")),
            set_test_value(aet.lscalar(), np.array(1, dtype="int64")),
            set_test_value(aet.lscalar(), np.array(0, dtype="int64")),
        ),
        (
            set_test_value(aet.lmatrix(), np.zeros((10, 3), dtype="int64")),
            set_test_value(aet.lscalar(), np.array(1, dtype="int64")),
            set_test_value(aet.lscalar(), np.array(1, dtype="int64")),
        ),
    ],
)
def test_FillDiagonalOffset(a, val, offset):
    """Compare the Numba `FillDiagonalOffset` (negative, zero, and positive
    offsets) against the Python backend."""
    g = extra_ops.FillDiagonalOffset()(a, val, offset)
    g_fg = FunctionGraph(outputs=[g])

    compare_numba_and_py(
        g_fg,
        [
            i.tag.test_value
            for i in g_fg.inputs
            if not isinstance(i, (SharedVariable, Constant))
        ],
    )
@pytest.mark.parametrize(
    "arr, shape, mode, order, exc",
    [
        (
            tuple(set_test_value(aet.lscalar(), v) for v in np.array([0])),
            set_test_value(aet.lvector(), np.array([2])),
            "raise",
            "C",
            None,
        ),
        (
            tuple(set_test_value(aet.lscalar(), v) for v in np.array([0, 0, 3])),
            set_test_value(aet.lvector(), np.array([2, 3, 4])),
            "raise",
            "C",
            None,
        ),
        (
            tuple(
                set_test_value(aet.lvector(), v)
                for v in np.array([[0, 1], [2, 0], [1, 3]])
            ),
            set_test_value(aet.lvector(), np.array([2, 3, 4])),
            "raise",
            "C",
            None,
        ),
        (
            tuple(
                set_test_value(aet.lvector(), v)
                for v in np.array([[0, 1], [2, 0], [1, 3]])
            ),
            set_test_value(aet.lvector(), np.array([2, 3, 4])),
            "raise",
            "F",
            NotImplementedError,
        ),
        (
            tuple(
                set_test_value(aet.lvector(), v)
                for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
            ),
            set_test_value(aet.lvector(), np.array([2, 3, 4])),
            "raise",
            "C",
            ValueError,
        ),
        (
            tuple(
                set_test_value(aet.lvector(), v)
                for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
            ),
            set_test_value(aet.lvector(), np.array([2, 3, 4])),
            "wrap",
            "C",
            None,
        ),
        (
            tuple(
                set_test_value(aet.lvector(), v)
                for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
            ),
            set_test_value(aet.lvector(), np.array([2, 3, 4])),
            "clip",
            "C",
            None,
        ),
    ],
)
def test_RavelMultiIndex(arr, shape, mode, order, exc):
    """Compare the Numba `RavelMultiIndex` against the Python backend.

    Covers scalar and vector coordinates, all three `mode` values, and the
    error cases: `order="F"` (unsupported) and out-of-bounds coordinates
    with `mode="raise"`.
    """
    g = extra_ops.RavelMultiIndex(mode, order)(*(arr + (shape,)))
    g_fg = FunctionGraph(outputs=[g])

    # `exc is None` means the case is expected to succeed.
    cm = contextlib.suppress() if exc is None else pytest.raises(exc)
    with cm:
        compare_numba_and_py(
            g_fg,
            [
                i.tag.test_value
                for i in g_fg.inputs
                if not isinstance(i, (SharedVariable, Constant))
            ],
        )
@pytest.mark.parametrize(
    "x, repeats, axis, exc",
    [
        (
            set_test_value(aet.lscalar(), np.array(1, dtype="int64")),
            set_test_value(aet.lscalar(), np.array(0, dtype="int64")),
            None,
            None,
        ),
        (
            set_test_value(aet.lmatrix(), np.zeros((2, 2), dtype="int64")),
            set_test_value(aet.lscalar(), np.array(1, dtype="int64")),
            None,
            None,
        ),
        (
            set_test_value(aet.lvector(), np.arange(2, dtype="int64")),
            set_test_value(aet.lvector(), np.array([1, 1], dtype="int64")),
            None,
            None,
        ),
        (
            set_test_value(aet.lmatrix(), np.zeros((2, 2), dtype="int64")),
            set_test_value(aet.lscalar(), np.array(1, dtype="int64")),
            0,
            UserWarning,
        ),
    ],
)
def test_Repeat(x, repeats, axis, exc):
    """Compare the Numba `Repeat` against the Python backend.

    When `axis` is not `None`, a `UserWarning` about object-mode fallback is
    expected (`exc` is then the warning class, checked via `pytest.warns`).
    """
    g = extra_ops.Repeat(axis)(x, repeats)
    g_fg = FunctionGraph(outputs=[g])

    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
    with cm:
        compare_numba_and_py(
            g_fg,
            [
                i.tag.test_value
                for i in g_fg.inputs
                if not isinstance(i, (SharedVariable, Constant))
            ],
        )
@pytest.mark.parametrize(
    "x, axis, return_index, return_inverse, return_counts, exc",
    [
        (
            set_test_value(aet.lscalar(), np.array(1, dtype="int64")),
            None,
            False,
            False,
            False,
            None,
        ),
        (
            set_test_value(aet.lvector(), np.array([1, 1, 2], dtype="int64")),
            None,
            False,
            False,
            False,
            None,
        ),
        (
            set_test_value(aet.lmatrix(), np.array([[1, 1], [2, 2]], dtype="int64")),
            None,
            False,
            False,
            False,
            None,
        ),
        (
            set_test_value(
                aet.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64")
            ),
            0,
            False,
            False,
            False,
            UserWarning,
        ),
        (
            set_test_value(
                aet.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64")
            ),
            0,
            True,
            True,
            True,
            UserWarning,
        ),
    ],
)
def test_Unique(x, axis, return_index, return_inverse, return_counts, exc):
    """Compare the Numba `Unique` against the Python backend.

    `axis` and/or `return_*` options trigger the object-mode fallback, for
    which a `UserWarning` is expected (`exc`).
    """
    g = extra_ops.Unique(return_index, return_inverse, return_counts, axis)(x)

    # With `return_*` options the `Op` returns several outputs.
    if isinstance(g, list):
        g_fg = FunctionGraph(outputs=g)
    else:
        g_fg = FunctionGraph(outputs=[g])

    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
    with cm:
        compare_numba_and_py(
            g_fg,
            [
                i.tag.test_value
                for i in g_fg.inputs
                if not isinstance(i, (SharedVariable, Constant))
            ],
        )
@pytest.mark.parametrize(
    "arr, shape, order, exc",
    [
        (
            set_test_value(aet.lvector(), np.array([9, 15, 1], dtype="int64")),
            aet.as_tensor([2, 3, 4]),
            "C",
            None,
        ),
        (
            set_test_value(aet.lvector(), np.array([1, 0], dtype="int64")),
            aet.as_tensor([2]),
            "C",
            None,
        ),
        (
            set_test_value(aet.lvector(), np.array([9, 15, 1], dtype="int64")),
            aet.as_tensor([2, 3, 4]),
            "F",
            NotImplementedError,
        ),
    ],
)
def test_UnravelIndex(arr, shape, order, exc):
    """Compare the Numba `UnravelIndex` against the Python backend.

    `order="F"` is expected to raise `NotImplementedError` at conversion
    time.
    """
    g = extra_ops.UnravelIndex(order)(arr, shape)

    # Multi-dimensional shapes produce one output per dimension.
    if isinstance(g, list):
        g_fg = FunctionGraph(outputs=g)
    else:
        g_fg = FunctionGraph(outputs=[g])

    cm = contextlib.suppress() if exc is None else pytest.raises(exc)
    with cm:
        compare_numba_and_py(
            g_fg,
            [
                i.tag.test_value
                for i in g_fg.inputs
                if not isinstance(i, (SharedVariable, Constant))
            ],
        )
@pytest.mark.parametrize(
    "a, v, side, sorter, exc",
    [
        (
            set_test_value(
                aet.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)
            ),
            set_test_value(
                aet.matrix(), np.random.random((3, 2)).astype(config.floatX)
            ),
            "left",
            None,
            None,
        ),
        pytest.param(
            set_test_value(
                aet.vector(),
                np.array([0.29769574, 0.71649186, 0.20475563]).astype(config.floatX),
            ),
            set_test_value(
                aet.matrix(),
                np.array(
                    [
                        [0.18847123, 0.39659508],
                        [0.56220006, 0.57428752],
                        [0.86720994, 0.44522637],
                    ]
                ).astype(config.floatX),
            ),
            "left",
            None,
            None,
            marks=pytest.mark.xfail(
                reason="This won't work until https://github.com/numba/numba/pull/7005 is merged"
            ),
        ),
        (
            set_test_value(
                aet.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)
            ),
            set_test_value(
                aet.matrix(), np.random.random((3, 2)).astype(config.floatX)
            ),
            "right",
            set_test_value(aet.lvector(), np.array([0, 2, 1])),
            UserWarning,
        ),
    ],
)
def test_Searchsorted(a, v, side, sorter, exc):
    """Compare the Numba `SearchsortedOp` against the Python backend.

    The `sorter` case expects a `UserWarning` about the object-mode
    fallback; the unsorted-`a` case is an upstream-Numba xfail.
    """
    g = extra_ops.SearchsortedOp(side)(a, v, sorter)
    g_fg = FunctionGraph(outputs=[g])

    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
    with cm:
        compare_numba_and_py(
            g_fg,
            [
                i.tag.test_value
                for i in g_fg.inputs
                if not isinstance(i, (SharedVariable, Constant))
            ],
        )
@pytest.mark.parametrize(
    "x, shape, exc",
    [
        (
            set_test_value(
                aet.vector(), np.random.random(size=(2,)).astype(config.floatX)
            ),
            [set_test_value(aet.lscalar(), np.array(v)) for v in [3, 2]],
            UserWarning,
        ),
    ],
)
def test_BroadcastTo(x, shape, exc):
    """`BroadcastTo` should warn (object-mode fallback) and still match the
    Python backend."""
    out = extra_ops.BroadcastTo()(x, shape)
    fgraph = FunctionGraph(outputs=[out])

    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
    with cm:
        test_inputs = [
            inp.tag.test_value
            for inp in fgraph.inputs
            if not isinstance(inp, (SharedVariable, Constant))
        ]
        compare_numba_and_py(fgraph, test_inputs)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论