Unverified 提交 236a3dfe authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Fix failing CI in Python 3.8 and install numba/jax on specific runs (#326)

* Add failed assertion message in CI * Pin numpy upper bound in numba install numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.8, but not numpy, even though scipy 1.7 requires numpy<1.23. When installing PyTensor next, pip installs a lower version of numpy via the PyPI. * Run numba and jax tests in separate jobs --------- Co-authored-by: 's avatarBen Mares <services-git-throwaway1@tensorial.com>
上级 a8e0adcb
......@@ -73,7 +73,8 @@ jobs:
python-version: ["3.8", "3.11"]
fast-compile: [0,1]
float32: [0,1]
install-numba: [1]
install-numba: [0]
install-jax: [0]
part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
- "tests/scan"
......@@ -93,6 +94,27 @@ jobs:
part: "tests/tensor/test_math.py"
- fast-compile: 1
float32: 1
include:
- install-numba: 1
python-version: "3.8"
fast-compile: 0
float32: 0
part: "tests/link/numba"
- install-numba: 1
python-version: "3.11"
fast-compile: 0
float32: 0
part: "tests/link/numba"
- install-jax: 1
python-version: "3.8"
fast-compile: 0
float32: 0
part: "tests/link/jax"
- install-jax: 1
python-version: "3.11"
fast-compile: 0
float32: 0
part: "tests/link/jax"
steps:
- uses: actions/checkout@v3
with:
......@@ -118,15 +140,20 @@ jobs:
shell: bash -l {0}
run: |
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy
if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro
# numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.8, but
# not numpy, even though scipy 1.7 requires numpy<1.23. When installing
# PyTensor next, pip installs a lower version of numpy via the PyPI.
if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION == "3.8" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numpy<1.23" "numba>=0.57" numba-scipy; fi
if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION != "3.8" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi
if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro; fi
pip install -e ./
mamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
python -c 'import pytensor; assert(pytensor.config.blas__ldflags != "")'
python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"'
env:
PYTHON_VERSION: ${{ matrix.python-version }}
INSTALL_NUMBA: ${{ matrix.install-numba }}
INSTALL_JAX: ${{ matrix.install-jax }}
- name: Run tests
shell: bash -l {0}
......@@ -175,7 +202,7 @@ jobs:
pip install -e ./
mamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
python -c 'import pytensor; assert(pytensor.config.blas__ldflags != "")'
python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"'
env:
PYTHON_VERSION: 3.9
- name: Download previous benchmark data
......
......@@ -9,6 +9,8 @@ per-file-ignores =
pytensor/link/jax/jax_dispatch.py:E402,F403,F401
pytensor/link/jax/jax_linker.py:E402,F403,F401
pytensor/sparse/sandbox/sp2.py:F401
tests/link/jax/*.py:E402
tests/link/numba/*.py:E402
tests/tensor/test_math_scipy.py:E402
tests/sparse/test_basic.py:E402
tests/sparse/test_opt.py:E402
......
import jax.errors
import numpy as np
import pytest
jax = pytest.importorskip("jax")
import jax.errors
import pytensor
import pytensor.tensor.basic as at
from pytensor.configdefaults import config
......
......@@ -3,10 +3,12 @@ import inspect
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Tuple, Union
from unittest import mock
import numba
import numpy as np
import pytest
numba = pytest.importorskip("numba")
import pytensor.scalar as aes
import pytensor.scalar.math as aesm
import pytensor.tensor as at
......
import numpy as np
import pytest
import scipy.special.cython_special
numba = pytest.importorskip("numba")
from numba.types import float32, float64, int32, int64
from pytensor.link.numba.dispatch.cython_support import Signature, wrap_cython_function
......
......@@ -3,6 +3,9 @@ import timeit
import numpy as np
import pytest
pytest.importorskip("numba")
import pytensor.tensor as aet
from pytensor import config
from pytensor.compile.function import function
......@@ -70,4 +73,5 @@ def test_careduce_performance(careduce_fn, numpy_fn, axis, inputs, input_vals):
mean_numpy_time = np.mean(numpy_times)
# mean_c_time = np.mean(c_times)
# FIXME: Why are we asserting >=? Numba could be doing worse than numpy!
assert mean_numba_time / mean_numpy_time >= 0.75
import numba
import numpy as np
import pytest
import scipy as sp
numba = pytest.importorskip("numba")
# Make sure the Numba customizations are loaded
import pytensor.link.numba.dispatch.sparse # noqa: F401
from pytensor import config
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论