提交 e108fab5 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Ricardo Vieira

Fix the JAX implementation of `SpecifyShape`

上级 8cf3b20f
...@@ -94,10 +94,10 @@ def jax_funcify_Shape_i(op, **kwargs): ...@@ -94,10 +94,10 @@ def jax_funcify_Shape_i(op, **kwargs):
@jax_funcify.register(SpecifyShape) @jax_funcify.register(SpecifyShape)
def jax_funcify_SpecifyShape(op, **kwargs): def jax_funcify_SpecifyShape(op, node, **kwargs):
def specifyshape(x, *shape): def specifyshape(x, *shape):
assert x.ndim == len(shape) assert x.ndim == len(shape)
assert jnp.all(x.shape == tuple(shape)), ( assert x.shape == tuple(shape), (
"got shape", "got shape",
x.shape, x.shape,
"expected", "expected",
......
import jax
import numpy as np import numpy as np
import pytest import pytest
from packaging.version import parse as version_parse
import pytensor.tensor as at import pytensor.tensor as at
from pytensor.compile.ops import DeepCopyOp, ViewOp from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.shape import Shape, Shape_i, SpecifyShape, Unbroadcast, reshape from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape
from pytensor.tensor.type import iscalar, vector from pytensor.tensor.type import iscalar, vector
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -25,24 +23,21 @@ def test_jax_shape_ops(): ...@@ -25,24 +23,21 @@ def test_jax_shape_ops():
compare_jax_and_py(x_fg, [], must_be_device_array=False) compare_jax_and_py(x_fg, [], must_be_device_array=False)
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_specify_shape(): def test_jax_specify_shape():
x_np = np.zeros((20, 3)) in_at = at.matrix("in")
x = SpecifyShape()(at.as_tensor_variable(x_np), (20, 3)) x = at.specify_shape(in_at, (4, 5))
x_fg = FunctionGraph([], [x]) x_fg = FunctionGraph([in_at], [x])
compare_jax_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)])
compare_jax_and_py(x_fg, [])
# When used to assert two arrays have similar shapes
with config.change_flags(compute_test_value="off"): in_at = at.matrix("in")
shape_at = at.matrix("shape")
x = SpecifyShape()(at.as_tensor_variable(x_np), *(2, 3)) x = at.specify_shape(in_at, shape_at.shape)
x_fg = FunctionGraph([], [x]) x_fg = FunctionGraph([in_at, shape_at], [x])
compare_jax_and_py(
with pytest.raises(AssertionError): x_fg,
compare_jax_and_py(x_fg, []) [np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)],
)
def test_jax_Reshape_constant(): def test_jax_Reshape_constant():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论