提交 d0cb1374 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Convert shapes to scalar tuples in Numba BroadcastTo implementation

上级 a6b2b0cb
...@@ -358,8 +358,20 @@ def numba_funcify_Searchsorted(op, node, **kwargs): ...@@ -358,8 +358,20 @@ def numba_funcify_Searchsorted(op, node, **kwargs):
@numba_funcify.register(BroadcastTo) @numba_funcify.register(BroadcastTo)
def numba_funcify_BroadcastTo(op, node, **kwargs): def numba_funcify_BroadcastTo(op, node, **kwargs):
@numba_basic.numba_njit()
create_zeros_tuple = numba_basic.create_tuple_creator(
lambda _: 0, len(node.inputs) - 1
)
@numba_basic.numba_njit
def broadcast_to(x, *shape): def broadcast_to(x, *shape):
return np.broadcast_to(x, shape) scalars_shape = create_zeros_tuple()
for i in range(len(shape)):
scalars_shape = numba_basic.tuple_setitem(
scalars_shape, i, numba_basic.to_scalar(shape[i])
)
return np.broadcast_to(x, scalars_shape)
return broadcast_to return broadcast_to
...@@ -1857,6 +1857,14 @@ def test_Searchsorted(a, v, side, sorter, exc): ...@@ -1857,6 +1857,14 @@ def test_Searchsorted(a, v, side, sorter, exc):
set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
[set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]], [set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]],
), ),
(
set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
[at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)],
),
(
set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
at.as_tensor([set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]]),
),
], ],
) )
def test_BroadcastTo(x, shape): def test_BroadcastTo(x, shape):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论