Unverified 提交 236a3dfe authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Fix failing CI in Python 3.8 and install numba/jax on specific runs (#326)

* Add failed assertion message in CI * Pin numpy upper bound in numba install numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.8, but not numpy, even though scipy 1.7 requires numpy<1.23. When installing PyTensor next, pip installs a lower version of numpy via the PyPI. * Run numba and jax tests in separate jobs --------- Co-authored-by: 's avatarBen Mares <services-git-throwaway1@tensorial.com>
上级 a8e0adcb
...@@ -73,7 +73,8 @@ jobs: ...@@ -73,7 +73,8 @@ jobs:
python-version: ["3.8", "3.11"] python-version: ["3.8", "3.11"]
fast-compile: [0,1] fast-compile: [0,1]
float32: [0,1] float32: [0,1]
install-numba: [1] install-numba: [0]
install-jax: [0]
part: part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse" - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
- "tests/scan" - "tests/scan"
...@@ -93,6 +94,27 @@ jobs: ...@@ -93,6 +94,27 @@ jobs:
part: "tests/tensor/test_math.py" part: "tests/tensor/test_math.py"
- fast-compile: 1 - fast-compile: 1
float32: 1 float32: 1
include:
- install-numba: 1
python-version: "3.8"
fast-compile: 0
float32: 0
part: "tests/link/numba"
- install-numba: 1
python-version: "3.11"
fast-compile: 0
float32: 0
part: "tests/link/numba"
- install-jax: 1
python-version: "3.8"
fast-compile: 0
float32: 0
part: "tests/link/jax"
- install-jax: 1
python-version: "3.11"
fast-compile: 0
float32: 0
part: "tests/link/jax"
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
with: with:
...@@ -118,15 +140,20 @@ jobs: ...@@ -118,15 +140,20 @@ jobs:
shell: bash -l {0} shell: bash -l {0}
run: | run: |
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy
if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi # numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.8, but
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro # not numpy, even though scipy 1.7 requires numpy<1.23. When installing
# PyTensor next, pip installs a lower version of numpy via the PyPI.
if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION == "3.8" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numpy<1.23" "numba>=0.57" numba-scipy; fi
if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION != "3.8" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi
if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro; fi
pip install -e ./ pip install -e ./
mamba list && pip freeze mamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))' python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
python -c 'import pytensor; assert(pytensor.config.blas__ldflags != "")' python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"'
env: env:
PYTHON_VERSION: ${{ matrix.python-version }} PYTHON_VERSION: ${{ matrix.python-version }}
INSTALL_NUMBA: ${{ matrix.install-numba }} INSTALL_NUMBA: ${{ matrix.install-numba }}
INSTALL_JAX: ${{ matrix.install-jax }}
- name: Run tests - name: Run tests
shell: bash -l {0} shell: bash -l {0}
...@@ -175,7 +202,7 @@ jobs: ...@@ -175,7 +202,7 @@ jobs:
pip install -e ./ pip install -e ./
mamba list && pip freeze mamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))' python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
python -c 'import pytensor; assert(pytensor.config.blas__ldflags != "")' python -c 'import pytensor; assert pytensor.config.blas__ldflags != "", "Blas flags are empty"'
env: env:
PYTHON_VERSION: 3.9 PYTHON_VERSION: 3.9
- name: Download previous benchmark data - name: Download previous benchmark data
......
...@@ -9,6 +9,8 @@ per-file-ignores = ...@@ -9,6 +9,8 @@ per-file-ignores =
pytensor/link/jax/jax_dispatch.py:E402,F403,F401 pytensor/link/jax/jax_dispatch.py:E402,F403,F401
pytensor/link/jax/jax_linker.py:E402,F403,F401 pytensor/link/jax/jax_linker.py:E402,F403,F401
pytensor/sparse/sandbox/sp2.py:F401 pytensor/sparse/sandbox/sp2.py:F401
tests/link/jax/*.py:E402
tests/link/numba/*.py:E402
tests/tensor/test_math_scipy.py:E402 tests/tensor/test_math_scipy.py:E402
tests/sparse/test_basic.py:E402 tests/sparse/test_basic.py:E402
tests/sparse/test_opt.py:E402 tests/sparse/test_opt.py:E402
......
import jax.errors
import numpy as np import numpy as np
import pytest import pytest
jax = pytest.importorskip("jax")
import jax.errors
import pytensor import pytensor
import pytensor.tensor.basic as at import pytensor.tensor.basic as at
from pytensor.configdefaults import config from pytensor.configdefaults import config
......
...@@ -3,10 +3,12 @@ import inspect ...@@ -3,10 +3,12 @@ import inspect
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Tuple, Union
from unittest import mock from unittest import mock
import numba
import numpy as np import numpy as np
import pytest import pytest
numba = pytest.importorskip("numba")
import pytensor.scalar as aes import pytensor.scalar as aes
import pytensor.scalar.math as aesm import pytensor.scalar.math as aesm
import pytensor.tensor as at import pytensor.tensor as at
......
import numpy as np import numpy as np
import pytest import pytest
import scipy.special.cython_special import scipy.special.cython_special
numba = pytest.importorskip("numba")
from numba.types import float32, float64, int32, int64 from numba.types import float32, float64, int32, int64
from pytensor.link.numba.dispatch.cython_support import Signature, wrap_cython_function from pytensor.link.numba.dispatch.cython_support import Signature, wrap_cython_function
......
...@@ -3,6 +3,9 @@ import timeit ...@@ -3,6 +3,9 @@ import timeit
import numpy as np import numpy as np
import pytest import pytest
pytest.importorskip("numba")
import pytensor.tensor as aet import pytensor.tensor as aet
from pytensor import config from pytensor import config
from pytensor.compile.function import function from pytensor.compile.function import function
...@@ -70,4 +73,5 @@ def test_careduce_performance(careduce_fn, numpy_fn, axis, inputs, input_vals): ...@@ -70,4 +73,5 @@ def test_careduce_performance(careduce_fn, numpy_fn, axis, inputs, input_vals):
mean_numpy_time = np.mean(numpy_times) mean_numpy_time = np.mean(numpy_times)
# mean_c_time = np.mean(c_times) # mean_c_time = np.mean(c_times)
# FIXME: Why are we asserting >=? Numba could be doing worse than numpy!
assert mean_numba_time / mean_numpy_time >= 0.75 assert mean_numba_time / mean_numpy_time >= 0.75
import numba
import numpy as np import numpy as np
import pytest import pytest
import scipy as sp import scipy as sp
numba = pytest.importorskip("numba")
# Make sure the Numba customizations are loaded # Make sure the Numba customizations are loaded
import pytensor.link.numba.dispatch.sparse # noqa: F401 import pytensor.link.numba.dispatch.sparse # noqa: F401
from pytensor import config from pytensor import config
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论