提交 97bd1ac6 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Ricardo Vieira

Use constants or raise in JAX `Arange` implementation

上级 a824aee4
......@@ -3,6 +3,7 @@ import warnings
import jax.numpy as jnp
import numpy as np
from pytensor.graph.basic import Constant
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import (
......@@ -22,6 +23,15 @@ from pytensor.tensor.basic import (
from pytensor.tensor.exceptions import NotScalarConstantError
ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange`
to be constants. The graph that you defined thus cannot be JIT-compiled
by JAX. An example of a graph that can be compiled to JAX:
>>> import pytensor.tensor basic
>>> at.arange(1, 10, 2)
"""
@jax_funcify.register(AllocDiag)
def jax_funcify_AllocDiag(op, **kwargs):
offset = op.offset
......@@ -50,9 +60,26 @@ def jax_funcify_Alloc(op, **kwargs):
@jax_funcify.register(ARange)
def jax_funcify_ARange(op, **kwargs):
# XXX: This currently requires concrete arguments.
def arange(start, stop, step):
def jax_funcify_ARange(op, node, **kwargs):
"""Register a JAX implementation for `ARange`.
`jax.numpy.arange` requires concrete values for its arguments. Here we check
that the arguments are constant, and raise otherwise.
TODO: Handle other situations in which values are concrete (shape of an array).
"""
arange_args = node.inputs
constant_args = []
for arg in arange_args:
if not isinstance(arg, Constant):
raise NotImplementedError(ARANGE_CONCRETE_VALUE_ERROR)
constant_args.append(arg.value)
start, stop, step = constant_args
def arange(*_):
return jnp.arange(start, stop, step, dtype=op.dtype)
return arange
......
......@@ -54,13 +54,20 @@ def test_jax_MakeVector():
compare_jax_and_py(x_fg, [])
@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
def test_arange():
out = at.arange(1, 10, 2)
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [])
def test_arange_nonconcrete():
"""JAX cannot JIT-compile `jax.numpy.arange` when arguments are not concrete values."""
a = scalar("a")
a.tag.test_value = 10
out = at.arange(a)
with pytest.raises(NotImplementedError):
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论