提交 2b7f95cf authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in `assert_size_argument_jax_compatible`

上级 bb7d70fb
......@@ -45,8 +45,10 @@ def assert_size_argument_jax_compatible(node):
"""
size = node.inputs[1]
size_op = size.owner.op
if not isinstance(size_op, (Shape, Shape_i, JAXShapeTuple)):
size_node = size.owner
if (size_node is not None) and (
not isinstance(size_node.op, (Shape, Shape_i, JAXShapeTuple))
):
raise NotImplementedError(SIZE_NOT_COMPATIBLE)
......
......@@ -693,6 +693,18 @@ def test_random_concrete_shape():
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
def test_random_concrete_shape_from_param():
rng = shared(np.random.RandomState(123))
x_at = at.dmatrix()
out = at.random.normal(x_at, 1, rng=rng)
with pytest.warns(
UserWarning,
match="The RandomType SharedVariables \[.+\] will not be used"
):
jax_fn = function([x_at], out, mode=jax_mode)
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
def test_random_concrete_shape_subtensor():
"""JAX should compile when a concrete value is passed for the `size` parameter.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论