提交 4cdd2905 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Group JAX random shape input tests

上级 b26cc8bf
......@@ -836,7 +836,8 @@ def test_random_custom_implementation():
compare_jax_and_py([], [out], [])
def test_random_concrete_shape():
class TestRandomShapeInputs:
def test_random_concrete_shape(self):
"""JAX should compile when a `RandomVariable` is passed a concrete shape.
There are three quantities that JAX considers as concrete:
......@@ -853,16 +854,14 @@ def test_random_concrete_shape():
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
def test_random_concrete_shape_from_param():
def test_random_concrete_shape_from_param(self):
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(x_pt, 1, rng=rng)
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
def test_random_concrete_shape_subtensor():
def test_random_concrete_shape_subtensor(self):
"""JAX should compile when a concrete value is passed for the `size` parameter.
This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
......@@ -880,8 +879,7 @@ def test_random_concrete_shape_subtensor():
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (3,)
def test_random_concrete_shape_subtensor_tuple():
def test_random_concrete_shape_subtensor_tuple(self):
"""JAX should compile when a tuple of concrete values is passed for the `size` parameter.
This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
......@@ -896,19 +894,17 @@ def test_random_concrete_shape_subtensor_tuple():
jax_fn = compile_random_function([x_pt], out)
assert jax_fn(np.ones((2, 3))).shape == (2,)
@pytest.mark.xfail(
@pytest.mark.xfail(
reason="`size_pt` should be specified as a static argument", strict=True
)
def test_random_concrete_shape_graph_input():
)
def test_random_concrete_shape_graph_input(self):
rng = shared(np.random.default_rng(123))
size_pt = pt.scalar()
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
jax_fn = compile_random_function([size_pt], out)
assert jax_fn(10).shape == (10,)
def test_constant_shape_after_graph_rewriting():
def test_constant_shape_after_graph_rewriting(self):
size = pt.vector("size", shape=(2,), dtype=int)
x = pt.random.normal(size=size)
assert x.type.shape == (None, None)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论