提交 28fc9acb authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not raise early when a Shape operation is an input to Arange in the JAX backend

上级 71c58f39
......@@ -21,6 +21,7 @@ from pytensor.tensor.basic import (
get_underlying_scalar_constant_value,
)
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.shape import Shape_i
ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange`
......@@ -61,14 +62,20 @@ def jax_funcify_ARange(op, node, **kwargs):
arange_args = node.inputs
constant_args = []
for arg in arange_args:
if not isinstance(arg, Constant):
if arg.owner and isinstance(arg.owner.op, Shape_i):
constant_args.append(None)
elif isinstance(arg, Constant):
constant_args.append(arg.value)
else:
# TODO: This might be failing without need (e.g., if arg = shape(x)[-1] + 1)!
raise NotImplementedError(ARANGE_CONCRETE_VALUE_ERROR)
constant_args.append(arg.value)
start, stop, step = constant_args
constant_start, constant_stop, constant_step = constant_args
def arange(*_):
def arange(start, stop, step):
start = start if constant_start is None else constant_start
stop = stop if constant_stop is None else constant_stop
step = step if constant_step is None else constant_step
return jnp.arange(start, stop, step, dtype=op.dtype)
return arange
......
......@@ -85,7 +85,7 @@ def test_jax_basic():
],
)
out = at.diag(at.specify_shape(b, shape=(10,)))
out = at.diag(b)
out_fg = FunctionGraph([b], [out])
compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)])
......
......@@ -63,6 +63,13 @@ def test_arange():
compare_jax_and_py(fgraph, [])
def test_arange_of_shape():
x = vector("x")
out = at.arange(1, x.shape[-1], 2)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [np.zeros((5,))])
def test_arange_nonconcrete():
"""JAX cannot JIT-compile `jax.numpy.arange` when arguments are not concrete values."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论