Unverified commit d5122713, authored by Pham Nguyen Hung, committed via GitHub

Fix JAX implementation of Argmax (#809)

Parent commit: 31bf6822
import jax.numpy as jnp
import numpy as np
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blas import BatchedDot
......@@ -137,12 +138,10 @@ def jax_funcify_Argmax(op, **kwargs):
# NumPy does not support multiple axes for argmax; this is a
# work-around
keep_axes = jnp.array(
[i for i in range(x.ndim) if i not in axes], dtype="int64"
)
keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
# Not-reduced axes in front
transposed_x = jnp.transpose(
x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
x, tuple(np.concatenate((keep_axes, np.array(axes, dtype="int64"))))
)
kept_shape = transposed_x.shape[: len(keep_axes)]
reduced_shape = transposed_x.shape[len(keep_axes) :]
......@@ -151,9 +150,9 @@ def jax_funcify_Argmax(op, **kwargs):
# Otherwise reshape would complain citing float arg
new_shape = (
*kept_shape,
jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
np.prod(np.array(reduced_shape, dtype="int64"), dtype="int64"),
)
reshaped_x = transposed_x.reshape(new_shape)
reshaped_x = transposed_x.reshape(tuple(new_shape))
max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
......
......@@ -65,9 +65,9 @@ def test_extra_ops():
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
reason="JAX Numpy API does not support dynamic shapes",
)
def test_extra_ops_omni():
def test_extra_ops_dynamic_shapes():
a = matrix("a")
a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
......
import numpy as np
import pytest
from packaging.version import parse as version_parse
from pytensor.compile.function import function
from pytensor.compile.mode import Mode
......@@ -80,11 +79,7 @@ def test_jax_basic_multiout():
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_basic_multiout_omni():
def test_jax_max_and_argmax():
# Test that a single output of a multi-output `Op` can be used as input to
# another `Op`
x = dvector()
......@@ -95,10 +90,6 @@ def test_jax_basic_multiout_omni():
compare_jax_and_py(out_fg, [np.r_[1, 2]])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_tensor_basics():
y = vector("y")
y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
......
Markdown formatting is supported
0%
You have added 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment