提交 97bd1ac6 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Ricardo Vieira

Use constants or raise in JAX `Arange` implementation

上级 a824aee4
...@@ -3,6 +3,7 @@ import warnings ...@@ -3,6 +3,7 @@ import warnings
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
from pytensor.graph.basic import Constant
from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor import get_vector_length from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
...@@ -22,6 +23,15 @@ from pytensor.tensor.basic import ( ...@@ -22,6 +23,15 @@ from pytensor.tensor.basic import (
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange`
to be constants. The graph that you defined thus cannot be JIT-compiled
by JAX. An example of a graph that can be compiled to JAX:
>>> import pytensor.tensor basic
>>> at.arange(1, 10, 2)
"""
@jax_funcify.register(AllocDiag) @jax_funcify.register(AllocDiag)
def jax_funcify_AllocDiag(op, **kwargs): def jax_funcify_AllocDiag(op, **kwargs):
offset = op.offset offset = op.offset
...@@ -50,9 +60,26 @@ def jax_funcify_Alloc(op, **kwargs): ...@@ -50,9 +60,26 @@ def jax_funcify_Alloc(op, **kwargs):
@jax_funcify.register(ARange) @jax_funcify.register(ARange)
def jax_funcify_ARange(op, **kwargs): def jax_funcify_ARange(op, node, **kwargs):
# XXX: This currently requires concrete arguments. """Register a JAX implementation for `ARange`.
def arange(start, stop, step):
`jax.numpy.arange` requires concrete values for its arguments. Here we check
that the arguments are constant, and raise otherwise.
TODO: Handle other situations in which values are concrete (shape of an array).
"""
arange_args = node.inputs
constant_args = []
for arg in arange_args:
if not isinstance(arg, Constant):
raise NotImplementedError(ARANGE_CONCRETE_VALUE_ERROR)
constant_args.append(arg.value)
start, stop, step = constant_args
def arange(*_):
return jnp.arange(start, stop, step, dtype=op.dtype) return jnp.arange(start, stop, step, dtype=op.dtype)
return arange return arange
......
...@@ -54,13 +54,20 @@ def test_jax_MakeVector(): ...@@ -54,13 +54,20 @@ def test_jax_MakeVector():
compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs") def test_arange():
out = at.arange(1, 10, 2)
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [])
def test_arange_nonconcrete(): def test_arange_nonconcrete():
"""JAX cannot JIT-compile `jax.numpy.arange` when arguments are not concrete values."""
a = scalar("a") a = scalar("a")
a.tag.test_value = 10 a.tag.test_value = 10
out = at.arange(a) out = at.arange(a)
with pytest.raises(NotImplementedError):
fgraph = FunctionGraph([a], [out]) fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论