提交 0b9bba1a authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Brandon T. Willard

Remove omnistaging warning

上级 1e63a48f
...@@ -18,19 +18,6 @@ if config.floatX == "float64": ...@@ -18,19 +18,6 @@ if config.floatX == "float64":
else: else:
jax.config.update("jax_enable_x64", False) jax.config.update("jax_enable_x64", False)
# XXX: Enabling this will break some shape-based functionality, and severely
# limit the types of graphs that can be converted.
# See https://github.com/google/jax/blob/4d556837cc9003492f674c012689efc3d68fdf5f/design_notes/omnistaging.md
# Older versions < 0.2.0 do not have this flag so we don't need to set it.
try:
jax.config.disable_omnistaging()
except AttributeError:
pass
except Exception as e:
# The version might be >= 0.2.12, which means that omnistaging can't be
# disabled
warnings.warn(f"JAX omnistaging couldn't be disabled: {e}")
@singledispatch @singledispatch
def jax_typify(data, dtype=None, **kwargs): def jax_typify(data, dtype=None, **kwargs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论