提交 e108fab5 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Ricardo Vieira

Fix the JAX implementation of `SpecifyShape`

上级 8cf3b20f
......@@ -94,10 +94,10 @@ def jax_funcify_Shape_i(op, **kwargs):
@jax_funcify.register(SpecifyShape)
def jax_funcify_SpecifyShape(op, **kwargs):
def jax_funcify_SpecifyShape(op, node, **kwargs):
def specifyshape(x, *shape):
assert x.ndim == len(shape)
assert jnp.all(x.shape == tuple(shape)), (
assert x.shape == tuple(shape), (
"got shape",
x.shape,
"expected",
......
import jax
import numpy as np
import pytest
from packaging.version import parse as version_parse
import pytensor.tensor as at
from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.shape import Shape, Shape_i, SpecifyShape, Unbroadcast, reshape
from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape
from pytensor.tensor.type import iscalar, vector
from tests.link.jax.test_basic import compare_jax_and_py
......@@ -25,24 +23,21 @@ def test_jax_shape_ops():
compare_jax_and_py(x_fg, [], must_be_device_array=False)
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_specify_shape():
x_np = np.zeros((20, 3))
x = SpecifyShape()(at.as_tensor_variable(x_np), (20, 3))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
with config.change_flags(compute_test_value="off"):
x = SpecifyShape()(at.as_tensor_variable(x_np), *(2, 3))
x_fg = FunctionGraph([], [x])
with pytest.raises(AssertionError):
compare_jax_and_py(x_fg, [])
in_at = at.matrix("in")
x = at.specify_shape(in_at, (4, 5))
x_fg = FunctionGraph([in_at], [x])
compare_jax_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)])
# When used to assert two arrays have similar shapes
in_at = at.matrix("in")
shape_at = at.matrix("shape")
x = at.specify_shape(in_at, shape_at.shape)
x_fg = FunctionGraph([in_at, shape_at], [x])
compare_jax_and_py(
x_fg,
[np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)],
)
def test_jax_Reshape_constant():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论