提交 9023e2b3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix an issue with non-homogeneous shape tuples in Numba's BroadcastTo

上级 081967d3
...@@ -2,6 +2,7 @@ import warnings ...@@ -2,6 +2,7 @@ import warnings
import numba import numba
import numpy as np import numpy as np
from numba.misc.special import literal_unroll
from numpy.core.multiarray import normalize_axis_index from numpy.core.multiarray import normalize_axis_index
from aesara import config from aesara import config
...@@ -367,10 +368,12 @@ def numba_funcify_BroadcastTo(op, node, **kwargs): ...@@ -367,10 +368,12 @@ def numba_funcify_BroadcastTo(op, node, **kwargs):
def broadcast_to(x, *shape): def broadcast_to(x, *shape):
scalars_shape = create_zeros_tuple() scalars_shape = create_zeros_tuple()
for i in range(len(shape)): i = 0
for s_i in literal_unroll(shape):
scalars_shape = numba_basic.tuple_setitem( scalars_shape = numba_basic.tuple_setitem(
scalars_shape, i, numba_basic.to_scalar(shape[i]) scalars_shape, i, numba_basic.to_scalar(s_i)
) )
i += 1
return np.broadcast_to(x, scalars_shape) return np.broadcast_to(x, scalars_shape)
......
...@@ -1924,6 +1924,10 @@ def test_Searchsorted(a, v, side, sorter, exc): ...@@ -1924,6 +1924,10 @@ def test_Searchsorted(a, v, side, sorter, exc):
set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
at.as_tensor([set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]]), at.as_tensor([set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]]),
), ),
(
set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
[at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)],
),
], ],
) )
def test_BroadcastTo(x, shape): def test_BroadcastTo(x, shape):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论