提交 4cdd2905 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Group JAX random shape input tests

上级 b26cc8bf
...@@ -836,94 +836,90 @@ def test_random_custom_implementation(): ...@@ -836,94 +836,90 @@ def test_random_custom_implementation():
compare_jax_and_py([], [out], []) compare_jax_and_py([], [out], [])
def test_random_concrete_shape(): class TestRandomShapeInputs:
"""JAX should compile when a `RandomVariable` is passed a concrete shape. def test_random_concrete_shape(self):
"""JAX should compile when a `RandomVariable` is passed a concrete shape.
There are three quantities that JAX considers as concrete:
1. Constants known at compile time; There are three quantities that JAX considers as concrete:
2. The shape of an array. 1. Constants known at compile time;
3. `static_argnums` parameters 2. The shape of an array.
This test makes sure that graphs with `RandomVariable`s compile when the 3. `static_argnums` parameters
`size` parameter satisfies either of these criteria. This test makes sure that graphs with `RandomVariable`s compile when the
`size` parameter satisfies either of these criteria.
"""
rng = shared(np.random.default_rng(123)) """
x_pt = pt.dmatrix() rng = shared(np.random.default_rng(123))
out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng) x_pt = pt.dmatrix()
jax_fn = compile_random_function([x_pt], out) out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
assert jax_fn(np.ones((2, 3))).shape == (2, 3) jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
def test_random_concrete_shape_from_param(): def test_random_concrete_shape_from_param(self):
rng = shared(np.random.default_rng(123)) rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix() x_pt = pt.dmatrix()
out = pt.random.normal(x_pt, 1, rng=rng) out = pt.random.normal(x_pt, 1, rng=rng)
jax_fn = compile_random_function([x_pt], out) jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2, 3) assert jax_fn(np.ones((2, 3))).shape == (2, 3)
def test_random_concrete_shape_subtensor(self):
def test_random_concrete_shape_subtensor(): """JAX should compile when a concrete value is passed for the `size` parameter.
"""JAX should compile when a concrete value is passed for the `size` parameter.
This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
inputs into 1d vectors is replaced by an `Op` that turns concrete scalar inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
inputs into tuples of concrete values using the `jax_size_parameter_as_tuple` rewrite.
rewrite.
JAX does not accept scalars as `size` or `shape` arguments, so this is a
JAX does not accept scalars as `size` or `shape` arguments, so this is a slight improvement over their API.
slight improvement over their API.
"""
""" rng = shared(np.random.default_rng(123))
rng = shared(np.random.default_rng(123)) x_pt = pt.dmatrix()
x_pt = pt.dmatrix() out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng) jax_fn = compile_random_function([x_pt], out)
jax_fn = compile_random_function([x_pt], out) assert jax_fn(np.ones((2, 3))).shape == (3,)
assert jax_fn(np.ones((2, 3))).shape == (3,)
def test_random_concrete_shape_subtensor_tuple(self):
"""JAX should compile when a tuple of concrete values is passed for the `size` parameter.
def test_random_concrete_shape_subtensor_tuple():
"""JAX should compile when a tuple of concrete values is passed for the `size` parameter. This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple scalar inputs into tuples of concrete values using the
inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete `jax_size_parameter_as_tuple` rewrite.
scalar inputs into tuples of concrete values using the
`jax_size_parameter_as_tuple` rewrite. """
rng = shared(np.random.default_rng(123))
""" x_pt = pt.dmatrix()
rng = shared(np.random.default_rng(123)) out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
x_pt = pt.dmatrix() jax_fn = compile_random_function([x_pt], out)
out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng) assert jax_fn(np.ones((2, 3))).shape == (2,)
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2,) @pytest.mark.xfail(
reason="`size_pt` should be specified as a static argument", strict=True
)
@pytest.mark.xfail( def test_random_concrete_shape_graph_input(self):
reason="`size_pt` should be specified as a static argument", strict=True rng = shared(np.random.default_rng(123))
) size_pt = pt.scalar()
def test_random_concrete_shape_graph_input(): out = pt.random.normal(0, 1, size=size_pt, rng=rng)
rng = shared(np.random.default_rng(123)) jax_fn = compile_random_function([size_pt], out)
size_pt = pt.scalar() assert jax_fn(10).shape == (10,)
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
jax_fn = compile_random_function([size_pt], out) def test_constant_shape_after_graph_rewriting(self):
assert jax_fn(10).shape == (10,) size = pt.vector("size", shape=(2,), dtype=int)
x = pt.random.normal(size=size)
assert x.type.shape == (None, None)
def test_constant_shape_after_graph_rewriting():
size = pt.vector("size", shape=(2,), dtype=int) with pytest.raises(TypeError):
x = pt.random.normal(size=size) compile_random_function([size], x)([2, 5])
assert x.type.shape == (None, None)
# Rebuild with strict=False so output type is not updated
with pytest.raises(TypeError): # This reflects cases where size is constant folded during rewrites but the RV node is not recreated
compile_random_function([size], x)([2, 5]) new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
assert new_x.type.shape == (None, None)
# Rebuild with strict=False so output type is not updated assert compile_random_function([], new_x)().shape == (2, 5)
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True) # Rebuild with strict=True, so output type is updated
assert new_x.type.shape == (None, None) # This uses a different path in the dispatch implementation
assert compile_random_function([], new_x)().shape == (2, 5) new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
assert new_x.type.shape == (2, 5)
# Rebuild with strict=True, so output type is updated assert compile_random_function([], new_x)().shape == (2, 5)
# This uses a different path in the dispatch implementation
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
assert new_x.type.shape == (2, 5)
assert compile_random_function([], new_x)().shape == (2, 5)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论