提交 9023e2b3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix an issue with non-homogeneous shape tuples in Numba's BroadcastTo

上级 081967d3
......@@ -2,6 +2,7 @@ import warnings
import numba
import numpy as np
from numba.misc.special import literal_unroll
from numpy.core.multiarray import normalize_axis_index
from aesara import config
......@@ -367,10 +368,12 @@ def numba_funcify_BroadcastTo(op, node, **kwargs):
def broadcast_to(x, *shape):
scalars_shape = create_zeros_tuple()
for i in range(len(shape)):
i = 0
for s_i in literal_unroll(shape):
scalars_shape = numba_basic.tuple_setitem(
scalars_shape, i, numba_basic.to_scalar(shape[i])
scalars_shape, i, numba_basic.to_scalar(s_i)
)
i += 1
return np.broadcast_to(x, scalars_shape)
......
......@@ -1924,6 +1924,10 @@ def test_Searchsorted(a, v, side, sorter, exc):
set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
at.as_tensor([set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]]),
),
(
set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
[at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)],
),
],
)
def test_BroadcastTo(x, shape):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论