Unverified 提交 d5122713 作者: Pham Nguyen Hung 提交者: GitHub

Fix JAX implementation of Argmax (#809)

上级 31bf6822
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np
from pytensor.link.jax.dispatch import jax_funcify from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blas import BatchedDot from pytensor.tensor.blas import BatchedDot
...@@ -137,12 +138,10 @@ def jax_funcify_Argmax(op, **kwargs): ...@@ -137,12 +138,10 @@ def jax_funcify_Argmax(op, **kwargs):
# NumPy does not support multiple axes for argmax; this is a # NumPy does not support multiple axes for argmax; this is a
# work-around # work-around
keep_axes = jnp.array( keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
[i for i in range(x.ndim) if i not in axes], dtype="int64"
)
# Not-reduced axes in front # Not-reduced axes in front
transposed_x = jnp.transpose( transposed_x = jnp.transpose(
x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64"))) x, tuple(np.concatenate((keep_axes, np.array(axes, dtype="int64"))))
) )
kept_shape = transposed_x.shape[: len(keep_axes)] kept_shape = transposed_x.shape[: len(keep_axes)]
reduced_shape = transposed_x.shape[len(keep_axes) :] reduced_shape = transposed_x.shape[len(keep_axes) :]
...@@ -151,9 +150,9 @@ def jax_funcify_Argmax(op, **kwargs): ...@@ -151,9 +150,9 @@ def jax_funcify_Argmax(op, **kwargs):
# Otherwise reshape would complain citing float arg # Otherwise reshape would complain citing float arg
new_shape = ( new_shape = (
*kept_shape, *kept_shape,
jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"), np.prod(np.array(reduced_shape, dtype="int64"), dtype="int64"),
) )
reshaped_x = transposed_x.reshape(new_shape) reshaped_x = transposed_x.reshape(tuple(new_shape))
max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64") max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
......
...@@ -65,9 +65,9 @@ def test_extra_ops(): ...@@ -65,9 +65,9 @@ def test_extra_ops():
@pytest.mark.xfail( @pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"), version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled", reason="JAX Numpy API does not support dynamic shapes",
) )
def test_extra_ops_omni(): def test_extra_ops_dynamic_shapes():
a = matrix("a") a = matrix("a")
a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
......
import numpy as np import numpy as np
import pytest import pytest
from packaging.version import parse as version_parse
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode
...@@ -80,11 +79,7 @@ def test_jax_basic_multiout(): ...@@ -80,11 +79,7 @@ def test_jax_basic_multiout():
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
@pytest.mark.xfail( def test_jax_max_and_argmax():
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_basic_multiout_omni():
# Test that a single output of a multi-output `Op` can be used as input to # Test that a single output of a multi-output `Op` can be used as input to
# another `Op` # another `Op`
x = dvector() x = dvector()
...@@ -95,10 +90,6 @@ def test_jax_basic_multiout_omni(): ...@@ -95,10 +90,6 @@ def test_jax_basic_multiout_omni():
compare_jax_and_py(out_fg, [np.r_[1, 2]]) compare_jax_and_py(out_fg, [np.r_[1, 2]])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_tensor_basics(): def test_tensor_basics():
y = vector("y") y = vector("y")
y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论