提交 2f0b424b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Test on `numba>=0.57`

上级 f0fda414
...@@ -139,7 +139,7 @@ jobs: ...@@ -139,7 +139,7 @@ jobs:
shell: bash -l {0} shell: bash -l {0}
run: | run: |
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy
if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.55" numba-scipy; fi if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro
pip install -e ./ pip install -e ./
mamba list && pip freeze mamba list && pip freeze
...@@ -192,7 +192,7 @@ jobs: ...@@ -192,7 +192,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
shell: bash -l {0} shell: bash -l {0}
run: | run: |
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.55" numba-scipy jax jaxlib pytest-benchmark mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" numba-scipy jax jaxlib pytest-benchmark
pip install -e ./ pip install -e ./
mamba list && pip freeze mamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))' python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
......
...@@ -22,7 +22,7 @@ dependencies: ...@@ -22,7 +22,7 @@ dependencies:
- mkl-service - mkl-service
- libblas=*=*mkl - libblas=*=*mkl
# numba backend # numba backend
- numba>=0.55 - numba>=0.57
- numba-scipy - numba-scipy
# For testing # For testing
- coveralls - coveralls
......
...@@ -312,6 +312,7 @@ def numba_funcify_BernoulliRV(op, node, **kwargs): ...@@ -312,6 +312,7 @@ def numba_funcify_BernoulliRV(op, node, **kwargs):
def numba_funcify_CategoricalRV(op, node, **kwargs): def numba_funcify_CategoricalRV(op, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype out_dtype = node.outputs[1].type.numpy_dtype
size_len = int(get_vector_length(node.inputs[1])) size_len = int(get_vector_length(node.inputs[1]))
p_ndim = node.inputs[-1].ndim
@numba_basic.numba_njit @numba_basic.numba_njit
def categorical_rv(rng, size, dtype, p): def categorical_rv(rng, size, dtype, p):
...@@ -321,6 +322,10 @@ def numba_funcify_CategoricalRV(op, node, **kwargs): ...@@ -321,6 +322,10 @@ def numba_funcify_CategoricalRV(op, node, **kwargs):
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len) size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
p = np.broadcast_to(p, size_tpl + p.shape[-1:]) p = np.broadcast_to(p, size_tpl + p.shape[-1:])
# Workaround https://github.com/numba/numba/issues/8975
if not size_len and p_ndim == 1:
unif_samples = np.asarray(np.random.uniform(0, 1))
else:
unif_samples = np.random.uniform(0, 1, size_tpl) unif_samples = np.random.uniform(0, 1, size_tpl)
res = np.empty(size_tpl, dtype=out_dtype) res = np.empty(size_tpl, dtype=out_dtype)
......
...@@ -530,9 +530,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -530,9 +530,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
at.as_tensor(rng.poisson(size=(2, 5))), at.as_tensor(rng.poisson(size=(2, 5))),
([1, 1], [2, 2]), ([1, 1], [2, 2]),
marks=pytest.mark.xfail(
reason="Duplicate index handling hasn't been implemented, yet."
),
), ),
], ],
) )
......
...@@ -459,9 +459,6 @@ def test_UnravelIndex(arr, shape, order, exc): ...@@ -459,9 +459,6 @@ def test_UnravelIndex(arr, shape, order, exc):
"left", "left",
None, None,
None, None,
marks=pytest.mark.xfail(
reason="This won't work until https://github.com/numba/numba/pull/7005 is merged"
),
), ),
( (
set_test_value(at.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)), set_test_value(at.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论