Unverified 提交 6fe9f839 authored 作者: Kaustubh's avatar Kaustubh 提交者: GitHub

Added Numba implementation for BroadcastTo Op (#769)

上级 7c1558ad
...@@ -9,6 +9,7 @@ from aesara.link.numba.dispatch import basic as numba_basic ...@@ -9,6 +9,7 @@ from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import get_numba_type, numba_funcify from aesara.link.numba.dispatch.basic import get_numba_type, numba_funcify
from aesara.tensor.extra_ops import ( from aesara.tensor.extra_ops import (
Bartlett, Bartlett,
BroadcastTo,
CumOp, CumOp,
DiffOp, DiffOp,
FillDiagonal, FillDiagonal,
...@@ -353,3 +354,12 @@ def numba_funcify_Searchsorted(op, node, **kwargs): ...@@ -353,3 +354,12 @@ def numba_funcify_Searchsorted(op, node, **kwargs):
return np.searchsorted(a, v, side) return np.searchsorted(a, v, side)
return searchsorted return searchsorted
@numba_funcify.register(BroadcastTo)
def numba_funcify_BroadcastTo(op, node, **kwargs):
@numba_basic.numba_njit()
def broadcast_to(x, *shape):
return np.broadcast_to(x, shape)
return broadcast_to
...@@ -1813,29 +1813,26 @@ def test_Searchsorted(a, v, side, sorter, exc): ...@@ -1813,29 +1813,26 @@ def test_Searchsorted(a, v, side, sorter, exc):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, shape, exc", "x, shape",
[ [
( (
set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
[set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]], [set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]],
UserWarning,
), ),
], ],
) )
def test_BroadcastTo(x, shape, exc): def test_BroadcastTo(x, shape):
g = extra_ops.BroadcastTo()(x, shape) g = extra_ops.BroadcastTo()(x, shape)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) compare_numba_and_py(
with cm: g_fg,
compare_numba_and_py( [
g_fg, i.tag.test_value
[ for i in g_fg.inputs
i.tag.test_value if not isinstance(i, (SharedVariable, Constant))
for i in g_fg.inputs ],
if not isinstance(i, (SharedVariable, Constant)) )
],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论