提交 4e2fc615 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Brandon T. Willard

Add JAX conversion for CheckAndRaiseOp

上级 e69f2d0a
...@@ -14,6 +14,7 @@ from aesara.configdefaults import config ...@@ -14,6 +14,7 @@ from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.ifelse import IfElse from aesara.ifelse import IfElse
from aesara.link.utils import fgraph_to_python from aesara.link.utils import fgraph_to_python
from aesara.raise_op import CheckAndRaise
from aesara.scalar import Softplus from aesara.scalar import Softplus
from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from aesara.scalar.math import Erf, Erfc, Erfinv, Psi from aesara.scalar.math import Erf, Erfc, Erfinv, Psi
...@@ -575,6 +576,21 @@ def jax_funcify_IfElse(op, **kwargs): ...@@ -575,6 +576,21 @@ def jax_funcify_IfElse(op, **kwargs):
return ifelse return ifelse
@jax_funcify.register(CheckAndRaise)
def jax_funcify_CheckAndRaise(op, **kwargs):
raise NotImplementedError(
f"""This exception is raised because you tried to convert an aesara graph with a `CheckAndRaise` Op (message: {op.msg}) to JAX.
JAX uses tracing to jit-compile functions, and assertions typically
don't do well with tracing. The appropriate workaround depends on what
you intended to do with the assertions in the first place.
Note that all assertions can be removed from the graph by adding
`local_remove_all_assert` to the rewrites."""
)
@jax_funcify.register(Subtensor) @jax_funcify.register(Subtensor)
def jax_funcify_Subtensor(op, **kwargs): def jax_funcify_Subtensor(op, **kwargs):
......
...@@ -17,6 +17,7 @@ from aesara.graph.op import Op, get_test_value ...@@ -17,6 +17,7 @@ from aesara.graph.op import Op, get_test_value
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import OptimizationQuery
from aesara.ifelse import ifelse from aesara.ifelse import ifelse
from aesara.link.jax import JAXLinker from aesara.link.jax import JAXLinker
from aesara.raise_op import assert_op
from aesara.scalar.basic import Composite from aesara.scalar.basic import Composite
from aesara.scan.basic import scan from aesara.scan.basic import scan
from aesara.tensor import basic as at from aesara.tensor import basic as at
...@@ -770,6 +771,17 @@ def test_jax_ifelse(): ...@@ -770,6 +771,17 @@ def test_jax_ifelse():
compare_jax_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs]) compare_jax_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs])
def test_jax_checkandraise():
p = scalar()
p.tag.test_value = 0
res = assert_op(p, p < 1.0)
res_fg = FunctionGraph([p], [res])
with pytest.raises(NotImplementedError):
compare_jax_and_py(res_fg, [1.0])
def test_jax_CAReduce(): def test_jax_CAReduce():
a_at = vector("a") a_at = vector("a")
a_at.tag.test_value = np.r_[1, 2, 3].astype(config.floatX) a_at.tag.test_value = np.r_[1, 2, 3].astype(config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论