提交 7ed706ae authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement JAX dispatch for Split Op

上级 1dfb0455
import warnings
import jax.numpy as jnp
import numpy as np
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import (
Alloc,
AllocDiag,
......@@ -11,8 +15,11 @@ from pytensor.tensor.basic import (
Join,
MakeVector,
ScalarFromTensor,
Split,
TensorFromScalar,
get_scalar_constant_value,
)
from pytensor.tensor.exceptions import NotScalarConstantError
@jax_funcify.register(AllocDiag)
......@@ -68,6 +75,53 @@ def jax_funcify_Join(op, **kwargs):
return join
@jax_funcify.register(Split)
def jax_funcify_Split(op: Split, node, **kwargs):
_, axis, splits = node.inputs
try:
constant_axis = get_scalar_constant_value(axis)
except NotScalarConstantError:
constant_axis = None
warnings.warn(
"Split node does not have constant axis. Jax implementation will likely fail"
)
try:
constant_splits = np.array(
[
get_scalar_constant_value(splits[i])
for i in range(get_vector_length(splits))
]
)
except (ValueError, NotScalarConstantError):
constant_splits = None
warnings.warn(
"Split node does not have constant split positions. Jax implementation will likely fail"
)
def split(x, axis, splits):
if constant_axis is not None:
axis = constant_axis
if constant_splits is not None:
splits = constant_splits
cumsum_splits = np.cumsum(splits[:-1])
else:
cumsum_splits = jnp.cumsum(splits[:-1])
if len(splits) != op.len_splits:
raise ValueError("Length of splits is not equal to n_splits")
if np.sum(splits) != x.shape[axis]:
raise ValueError(
f"Split sizes do not sum up to input length along axis: {x.shape[axis]}"
)
if np.any(splits < 0):
raise ValueError("Split sizes cannot be negative")
return jnp.split(x, cumsum_splits, axis=axis)
return split
@jax_funcify.register(ExtractDiag)
def jax_funcify_ExtractDiag(op, **kwargs):
offset = op.offset
......
import jax.errors
import numpy as np
import pytest
import pytensor
import pytensor.tensor.basic as at
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor.type import matrix, scalar, vector
from pytensor.tensor.type import iscalar, matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py
......@@ -102,6 +104,80 @@ def test_jax_Join():
)
class TestJaxSplit:
def test_basic(self):
a = matrix("a")
a_splits = at.split(a, splits_size=[1, 2, 3], n_splits=3, axis=0)
fg = FunctionGraph([a], a_splits)
compare_jax_and_py(
fg,
[
np.zeros((6, 4)).astype(config.floatX),
],
)
a = matrix("a", shape=(6, None))
a_splits = at.split(a, splits_size=[2, a.shape[0] - 2], n_splits=2, axis=0)
fg = FunctionGraph([a], a_splits)
compare_jax_and_py(
fg,
[
np.zeros((6, 4)).astype(config.floatX),
],
)
def test_runtime_errors(self):
a = matrix("a")
a_splits = at.split(a, splits_size=[2, 2, 2], n_splits=2, axis=0)
fn = pytensor.function([a], a_splits, mode="JAX")
with pytest.raises(
ValueError, match="Length of splits is not equal to n_splits"
):
fn(np.zeros((6, 4), dtype=pytensor.config.floatX))
a_splits = at.split(a, splits_size=[2, 4], n_splits=3, axis=0)
fn = pytensor.function([a], a_splits, mode="JAX")
with pytest.raises(
ValueError, match="Length of splits is not equal to n_splits"
):
fn(np.zeros((6, 4), dtype=pytensor.config.floatX))
a_splits = at.split(a, splits_size=[2, 4], n_splits=2, axis=0)
fn = pytensor.function([a], a_splits, mode="JAX")
with pytest.raises(
ValueError, match="Split sizes do not sum up to input length along axis: 7"
):
fn(np.zeros((7, 4), dtype=pytensor.config.floatX))
a_splits = at.split(a, splits_size=[2, -4, 8], n_splits=3, axis=0)
fn = pytensor.function([a], a_splits, mode="JAX")
with pytest.raises(
ValueError,
match="Split sizes cannot be negative",
):
fn(np.zeros((6, 4), dtype=pytensor.config.floatX))
def test_jax_split_not_supported(self):
a = matrix("a", shape=(6, None))
a_splits = at.split(a, splits_size=[2, a.shape[1] - 2], n_splits=2, axis=1)
with pytest.warns(
UserWarning, match="Split node does not have constant split positions."
):
fn = pytensor.function([a], a_splits, mode="JAX")
# It raises an informative ConcretizationTypeError, but there's an AttributeError that surpsasses it
with pytest.raises(AttributeError):
fn(np.zeros((6, 4), dtype=pytensor.config.floatX))
split_axis = iscalar("split_axis")
a_splits = at.split(a, splits_size=[2, 4], n_splits=2, axis=split_axis)
with pytest.warns(UserWarning, match="Split node does not have constant axis."):
fn = pytensor.function([a, split_axis], a_splits, mode="JAX")
with pytest.raises(jax.errors.TracerIntegerConversionError):
fn(np.zeros((6, 6), dtype=pytensor.config.floatX), 0)
def test_jax_eye():
"""Tests jaxification of the Eye operator"""
out = at.eye(3)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论