提交 1e63a48f authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Brandon T. Willard

Warn instead of raise when converting `CheckAndRaise`

上级 b35d6188
......@@ -10,7 +10,7 @@ from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.ifelse import IfElse
from aesara.link.utils import fgraph_to_python
from aesara.raise_op import CheckAndRaise
from aesara.raise_op import Assert, CheckAndRaise
if config.floatX == "float64":
......@@ -81,19 +81,18 @@ def jax_funcify_IfElse(op, **kwargs):
return ifelse
@jax_funcify.register(Assert)
@jax_funcify.register(CheckAndRaise)
def jax_funcify_CheckAndRaise(op, **kwargs):
warnings.warn(
f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as JAX tracing would remove it.""",
stacklevel=2,
)
raise NotImplementedError(
f"""This exception is raised because you tried to convert an aesara graph with a `CheckAndRaise` Op (message: {op.msg}) to JAX.
JAX uses tracing to jit-compile functions, and assertions typically
don't do well with tracing. The appropriate workaround depends on what
you intended to do with the assertions in the first place.
def assert_fn(x, *inputs):
return x
Note that all assertions can be removed from the graph by adding
`local_remove_all_assert` to the rewrites."""
)
return assert_fn
def jnp_safe_copy(x):
......
......@@ -205,7 +205,6 @@ def test_jax_checkandraise():
p.tag.test_value = 0
res = assert_op(p, p < 1.0)
res_fg = FunctionGraph([p], [res])
with pytest.raises(NotImplementedError):
compare_jax_and_py(res_fg, [1.0])
with pytest.warns(UserWarning):
function((p,), res, mode=jax_mode)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论