提交 2f0b424b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Test on `numba>=0.57`

上级 f0fda414
......@@ -139,7 +139,7 @@ jobs:
shell: bash -l {0}
run: |
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy
if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.55" numba-scipy; fi
if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro
pip install -e ./
mamba list && pip freeze
......@@ -192,7 +192,7 @@ jobs:
- name: Install dependencies
shell: bash -l {0}
run: |
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.55" numba-scipy jax jaxlib pytest-benchmark
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" numba-scipy jax jaxlib pytest-benchmark
pip install -e ./
mamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
......
......@@ -22,7 +22,7 @@ dependencies:
- mkl-service
- libblas=*=*mkl
# numba backend
- numba>=0.55
- numba>=0.57
- numba-scipy
# For testing
- coveralls
......
......@@ -312,6 +312,7 @@ def numba_funcify_BernoulliRV(op, node, **kwargs):
def numba_funcify_CategoricalRV(op, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype
size_len = int(get_vector_length(node.inputs[1]))
p_ndim = node.inputs[-1].ndim
@numba_basic.numba_njit
def categorical_rv(rng, size, dtype, p):
......@@ -321,7 +322,11 @@ def numba_funcify_CategoricalRV(op, node, **kwargs):
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
p = np.broadcast_to(p, size_tpl + p.shape[-1:])
unif_samples = np.random.uniform(0, 1, size_tpl)
# Workaround https://github.com/numba/numba/issues/8975
if not size_len and p_ndim == 1:
unif_samples = np.asarray(np.random.uniform(0, 1))
else:
unif_samples = np.random.uniform(0, 1, size_tpl)
res = np.empty(size_tpl, dtype=out_dtype)
for idx in np.ndindex(*size_tpl):
......
......@@ -530,9 +530,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
at.as_tensor(rng.poisson(size=(2, 5))),
([1, 1], [2, 2]),
marks=pytest.mark.xfail(
reason="Duplicate index handling hasn't been implemented, yet."
),
),
],
)
......
......@@ -459,9 +459,6 @@ def test_UnravelIndex(arr, shape, order, exc):
"left",
None,
None,
marks=pytest.mark.xfail(
reason="This won't work until https://github.com/numba/numba/pull/7005 is merged"
),
),
(
set_test_value(at.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论