提交 1e63a48f authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Brandon T. Willard

Warn instead of raise when converting `CheckAndRaise`

上级 b35d6188
...@@ -10,7 +10,7 @@ from aesara.configdefaults import config ...@@ -10,7 +10,7 @@ from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.ifelse import IfElse from aesara.ifelse import IfElse
from aesara.link.utils import fgraph_to_python from aesara.link.utils import fgraph_to_python
from aesara.raise_op import CheckAndRaise from aesara.raise_op import Assert, CheckAndRaise
if config.floatX == "float64": if config.floatX == "float64":
...@@ -81,19 +81,18 @@ def jax_funcify_IfElse(op, **kwargs): ...@@ -81,19 +81,18 @@ def jax_funcify_IfElse(op, **kwargs):
return ifelse return ifelse
@jax_funcify.register(Assert)
@jax_funcify.register(CheckAndRaise) @jax_funcify.register(CheckAndRaise)
def jax_funcify_CheckAndRaise(op, **kwargs): def jax_funcify_CheckAndRaise(op, **kwargs):
warnings.warn(
f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as JAX tracing would remove it.""",
stacklevel=2,
)
raise NotImplementedError( def assert_fn(x, *inputs):
f"""This exception is raised because you tried to convert an aesara graph with a `CheckAndRaise` Op (message: {op.msg}) to JAX. return x
JAX uses tracing to jit-compile functions, and assertions typically
don't do well with tracing. The appropriate workaround depends on what
you intended to do with the assertions in the first place.
Note that all assertions can be removed from the graph by adding return assert_fn
`local_remove_all_assert` to the rewrites."""
)
def jnp_safe_copy(x): def jnp_safe_copy(x):
......
...@@ -205,7 +205,6 @@ def test_jax_checkandraise(): ...@@ -205,7 +205,6 @@ def test_jax_checkandraise():
p.tag.test_value = 0 p.tag.test_value = 0
res = assert_op(p, p < 1.0) res = assert_op(p, p < 1.0)
res_fg = FunctionGraph([p], [res])
with pytest.raises(NotImplementedError): with pytest.warns(UserWarning):
compare_jax_and_py(res_fg, [1.0]) function((p,), res, mode=jax_mode)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论