提交 4cdd2905 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Group JAX random shape input tests

上级 b26cc8bf
......@@ -836,94 +836,90 @@ def test_random_custom_implementation():
compare_jax_and_py([], [out], [])
def test_random_concrete_shape():
"""JAX should compile when a `RandomVariable` is passed a concrete shape.
There are three quantities that JAX considers as concrete:
1. Constants known at compile time;
2. The shape of an array.
3. `static_argnums` parameters
This test makes sure that graphs with `RandomVariable`s compile when the
`size` parameter satisfies either of these criteria.
"""
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
def test_random_concrete_shape_from_param():
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(x_pt, 1, rng=rng)
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
def test_random_concrete_shape_subtensor():
"""JAX should compile when a concrete value is passed for the `size` parameter.
This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
rewrite.
JAX does not accept scalars as `size` or `shape` arguments, so this is a
slight improvement over their API.
"""
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (3,)
def test_random_concrete_shape_subtensor_tuple():
"""JAX should compile when a tuple of concrete values is passed for the `size` parameter.
This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
scalar inputs into tuples of concrete values using the
`jax_size_parameter_as_tuple` rewrite.
"""
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2,)
@pytest.mark.xfail(
reason="`size_pt` should be specified as a static argument", strict=True
)
def test_random_concrete_shape_graph_input():
rng = shared(np.random.default_rng(123))
size_pt = pt.scalar()
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
jax_fn = compile_random_function([size_pt], out)
assert jax_fn(10).shape == (10,)
def test_constant_shape_after_graph_rewriting():
size = pt.vector("size", shape=(2,), dtype=int)
x = pt.random.normal(size=size)
assert x.type.shape == (None, None)
with pytest.raises(TypeError):
compile_random_function([size], x)([2, 5])
# Rebuild with strict=False so output type is not updated
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
assert new_x.type.shape == (None, None)
assert compile_random_function([], new_x)().shape == (2, 5)
# Rebuild with strict=True, so output type is updated
# This uses a different path in the dispatch implementation
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
assert new_x.type.shape == (2, 5)
assert compile_random_function([], new_x)().shape == (2, 5)
class TestRandomShapeInputs:
def test_random_concrete_shape(self):
"""JAX should compile when a `RandomVariable` is passed a concrete shape.
There are three quantities that JAX considers as concrete:
1. Constants known at compile time;
2. The shape of an array.
3. `static_argnums` parameters
This test makes sure that graphs with `RandomVariable`s compile when the
`size` parameter satisfies either of these criteria.
"""
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
def test_random_concrete_shape_from_param(self):
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(x_pt, 1, rng=rng)
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
def test_random_concrete_shape_subtensor(self):
"""JAX should compile when a concrete value is passed for the `size` parameter.
This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
rewrite.
JAX does not accept scalars as `size` or `shape` arguments, so this is a
slight improvement over their API.
"""
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (3,)
def test_random_concrete_shape_subtensor_tuple(self):
"""JAX should compile when a tuple of concrete values is passed for the `size` parameter.
This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
scalar inputs into tuples of concrete values using the
`jax_size_parameter_as_tuple` rewrite.
"""
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2,)
@pytest.mark.xfail(
reason="`size_pt` should be specified as a static argument", strict=True
)
def test_random_concrete_shape_graph_input(self):
rng = shared(np.random.default_rng(123))
size_pt = pt.scalar()
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
jax_fn = compile_random_function([size_pt], out)
assert jax_fn(10).shape == (10,)
def test_constant_shape_after_graph_rewriting(self):
size = pt.vector("size", shape=(2,), dtype=int)
x = pt.random.normal(size=size)
assert x.type.shape == (None, None)
with pytest.raises(TypeError):
compile_random_function([size], x)([2, 5])
# Rebuild with strict=False so output type is not updated
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
assert new_x.type.shape == (None, None)
assert compile_random_function([], new_x)().shape == (2, 5)
# Rebuild with strict=True, so output type is updated
# This uses a different path in the dispatch implementation
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
assert new_x.type.shape == (2, 5)
assert compile_random_function([], new_x)().shape == (2, 5)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论