提交 92eef5ed authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow running JAX functions with scalar inputs for RV shapes

上级 4cdd2905
...@@ -9,8 +9,13 @@ from pytensor.link.basic import JITLinker ...@@ -9,8 +9,13 @@ from pytensor.link.basic import JITLinker
class JAXLinker(JITLinker): class JAXLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX.""" """A `Linker` that JIT-compiles NumPy-based operations using JAX."""
def __init__(self, *args, **kwargs):
    """Initialize the linker with no scalar shape inputs registered yet.

    ``scalar_shape_inputs`` holds the positional indices of ``fgraph``
    inputs that are used exclusively as shapes; it is populated by
    ``fgraph_convert`` and later passed to ``jax.jit`` as
    ``static_argnums`` in ``jit_compile``.
    """
    # Variable-length tuple of indices is ``tuple[int, ...]``;
    # ``tuple[int]`` would (incorrectly) mean a 1-tuple.
    self.scalar_shape_inputs: tuple[int, ...] = ()  # type: ignore[annotation-unchecked]
    super().__init__(*args, **kwargs)
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
from pytensor.link.jax.dispatch import jax_funcify from pytensor.link.jax.dispatch import jax_funcify
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
from pytensor.tensor.random.type import RandomType from pytensor.tensor.random.type import RandomType
shared_rng_inputs = [ shared_rng_inputs = [
...@@ -64,6 +69,23 @@ class JAXLinker(JITLinker): ...@@ -64,6 +69,23 @@ class JAXLinker(JITLinker):
fgraph.inputs.remove(new_inp) fgraph.inputs.remove(new_inp)
fgraph.inputs.insert(old_inp_fgrap_index, new_inp) fgraph.inputs.insert(old_inp_fgrap_index, new_inp)
fgraph_inputs = fgraph.inputs
clients = fgraph.clients
# Detect scalar shape inputs that are used only in JAXShapeTuple nodes
scalar_shape_inputs = [
inp
for node in fgraph.apply_nodes
if isinstance(node.op, JAXShapeTuple)
for inp in node.inputs
if inp in fgraph_inputs
and all(
isinstance(cl_node.op, JAXShapeTuple) for cl_node, _ in clients[inp]
)
]
self.scalar_shape_inputs = tuple(
fgraph_inputs.index(inp) for inp in scalar_shape_inputs
)
return jax_funcify( return jax_funcify(
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
) )
...@@ -71,7 +93,22 @@ class JAXLinker(JITLinker): ...@@ -71,7 +93,22 @@ class JAXLinker(JITLinker):
def jit_compile(self, fn):
    """JIT-compile ``fn`` with JAX, treating scalar shape inputs as static.

    The indices recorded in ``self.scalar_shape_inputs`` are handed to
    ``jax.jit`` as ``static_argnums``.  Since static arguments must be
    hashable, the returned wrapper converts the corresponding 0d-array
    arguments to plain Python ints before each call.
    """
    import jax

    compiled = jax.jit(fn, static_argnums=self.scalar_shape_inputs)

    if not self.scalar_shape_inputs:
        # No static shape arguments: the jitted function is usable as is.
        return compiled

    # Bind the index set once via a default argument so each call only
    # pays for a membership test per positional argument.
    def convert_scalar_shape_inputs(
        *args, scalar_shape_inputs=set(self.scalar_shape_inputs)
    ):
        converted_args = [
            int(arg) if idx in scalar_shape_inputs else arg
            for idx, arg in enumerate(args)
        ]
        return compiled(*converted_args)

    return convert_scalar_shape_inputs
def create_thunk_inputs(self, storage_map): def create_thunk_inputs(self, storage_map):
from pytensor.link.jax.dispatch import jax_typify from pytensor.link.jax.dispatch import jax_typify
......
...@@ -894,15 +894,55 @@ class TestRandomShapeInputs: ...@@ -894,15 +894,55 @@ class TestRandomShapeInputs:
jax_fn = compile_random_function([x_pt], out) jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2,) assert jax_fn(np.ones((2, 3))).shape == (2,)
def test_random_scalar_shape_input(self):
    """Scalar graph inputs used only as RV sizes work as static shapes."""
    d0 = pt.scalar("dim0", dtype=int)
    d1 = pt.scalar("dim1", dtype=int)

    # 1d case: a single scalar input determines the draw's length.
    rv = pt.random.normal(0, 1, size=d0)
    fn = compile_random_function([d0], rv)
    for length in (2, 3):
        assert fn(np.array(length)).shape == (length,)

    # 2d case: two scalar inputs determine both dimensions.
    rv = pt.random.normal(0, 1, size=[d0, d1])
    fn = compile_random_function([d0, d1], rv)
    for rows, cols in ((2, 3), (4, 5)):
        assert fn(np.array(rows), np.array(cols)).shape == (rows, cols)
@pytest.mark.xfail(
    raises=TypeError, reason="Cannot convert scalar input to integer"
)
def test_random_scalar_shape_input_not_supported(self):
    """A scalar input used both as a shape and as a value cannot be made
    static, so this compilation is expected to fail."""
    dim = pt.scalar("dim", dtype=int)
    draw = pt.random.normal(0, 1, size=dim)
    # An operation that wouldn't work if we replaced 0d array by integer
    written = dim[...].set(1)
    jax_fn = compile_random_function([dim], [draw, written])
    draw_val, written_val = jax_fn(np.array(2))
    assert draw_val.shape == (2,)
    assert written_val == 1
@pytest.mark.xfail(
    raises=TypeError, reason="Cannot convert scalar input to integer"
)
def test_random_scalar_shape_input_not_supported2(self):
    """A *derived* scalar shape (``dim * 2``) is not treated as static."""
    dim = pt.scalar("dim", dtype=int)
    # This could theoretically be supported
    # but would require knowing that * 2 is a safe operation for a python integer
    rv = pt.random.normal(0, 1, size=dim * 2)
    fn = compile_random_function([dim], rv)
    assert fn(np.array(2)).shape == (4,)
@pytest.mark.xfail(
    raises=TypeError, reason="Cannot convert tensor input to shape tuple"
)
def test_random_vector_shape_graph_input(self):
    """Vector (non-scalar) shape inputs are not supported as static shapes."""
    size_vec = pt.vector("shape", shape=(2,), dtype=int)
    rv = pt.random.normal(0, 1, size=size_vec)
    fn = compile_random_function([size_vec], rv)
    for dims in ([2, 3], [4, 5]):
        assert fn(np.array(dims)).shape == tuple(dims)
def test_constant_shape_after_graph_rewriting(self): def test_constant_shape_after_graph_rewriting(self):
size = pt.vector("size", shape=(2,), dtype=int) size = pt.vector("size", shape=(2,), dtype=int)
...@@ -912,13 +952,13 @@ class TestRandomShapeInputs: ...@@ -912,13 +952,13 @@ class TestRandomShapeInputs:
with pytest.raises(TypeError): with pytest.raises(TypeError):
compile_random_function([size], x)([2, 5]) compile_random_function([size], x)([2, 5])
# Rebuild with strict=False so output type is not updated # Rebuild with strict=True so output type is not updated
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated # This reflects cases where size is constant folded during rewrites but the RV node is not recreated
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True) new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
assert new_x.type.shape == (None, None) assert new_x.type.shape == (None, None)
assert compile_random_function([], new_x)().shape == (2, 5) assert compile_random_function([], new_x)().shape == (2, 5)
# Rebuild with strict=True, so output type is updated # Rebuild with strict=False, so output type is updated
# This uses a different path in the dispatch implementation # This uses a different path in the dispatch implementation
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False) new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
assert new_x.type.shape == (2, 5) assert new_x.type.shape == (2, 5)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论