提交 8a8c7e7d authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Disable omnistaging by default

Doing this enables general use of `jax` >= 0.2.0; however, we still need to find a better fix for the shape-related problems introduced by omnistaging. Closes #43.
上级 0e4bf69b
...@@ -60,7 +60,7 @@ install: ...@@ -60,7 +60,7 @@ install:
- conda create --yes -q -n pyenv python=$TRAVIS_PYTHON_VERSION - conda create --yes -q -n pyenv python=$TRAVIS_PYTHON_VERSION
- conda activate pyenv - conda activate pyenv
- conda install --yes -q mkl numpy scipy pip mkl-service graphviz cython libgpuarray pygpu - conda install --yes -q mkl numpy scipy pip mkl-service graphviz cython libgpuarray pygpu
- if [[ "$TRAVIS_PYTHON_VERSION" != "3.6" ]]; then conda install --yes -q -c conda-forge 'jax<0.2.0' 'jaxlib'; fi - if [[ "$TRAVIS_PYTHON_VERSION" != "3.6" ]]; then conda install --yes -q -c conda-forge jax jaxlib; fi
- pip install -q -r requirements.txt - pip install -q -r requirements.txt
- conda list && pip freeze - conda list && pip freeze
- python -c 'import theano; print(theano.config.__str__(print_doc=False))' - python -c 'import theano; print(theano.config.__str__(print_doc=False))'
......
...@@ -10,5 +10,5 @@ coveralls ...@@ -10,5 +10,5 @@ coveralls
cython cython
sympy sympy
versioneer versioneer
jax<0.2.0; python_version > '3.6' jax; python_version > '3.6'
jaxlib; python_version > '3.6' jaxlib; python_version > '3.6'
...@@ -51,6 +51,10 @@ from theano.compile.ops import ( ...@@ -51,6 +51,10 @@ from theano.compile.ops import (
from theano.tensor.opt import MakeVector from theano.tensor.opt import MakeVector
# XXX: Enabling this will break some shape-based functionality, and severely
# limit the types of graphs that can be converted.
# See https://github.com/google/jax/blob/4d556837cc9003492f674c012689efc3d68fdf5f/design_notes/omnistaging.md
jax.config.disable_omnistaging()
jax.config.update("jax_enable_x64", True) jax.config.update("jax_enable_x64", True)
subtensor_ops = (Subtensor, AdvancedSubtensor1, BaseAdvancedSubtensor) subtensor_ops = (Subtensor, AdvancedSubtensor1, BaseAdvancedSubtensor)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论