提交 7ed706ae authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement JAX dispatch for Split Op

上级 1dfb0455
import warnings
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np
from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
AllocDiag, AllocDiag,
...@@ -11,8 +15,11 @@ from pytensor.tensor.basic import ( ...@@ -11,8 +15,11 @@ from pytensor.tensor.basic import (
Join, Join,
MakeVector, MakeVector,
ScalarFromTensor, ScalarFromTensor,
Split,
TensorFromScalar, TensorFromScalar,
get_scalar_constant_value,
) )
from pytensor.tensor.exceptions import NotScalarConstantError
@jax_funcify.register(AllocDiag) @jax_funcify.register(AllocDiag)
...@@ -68,6 +75,53 @@ def jax_funcify_Join(op, **kwargs): ...@@ -68,6 +75,53 @@ def jax_funcify_Join(op, **kwargs):
return join return join
@jax_funcify.register(Split)
def jax_funcify_Split(op: Split, node, **kwargs):
    """Return a JAX-compatible callable implementing a ``Split`` node.

    While the callable is being built, the axis and the split sizes are
    looked up as graph constants.  Whichever of the two cannot be resolved
    to a constant is left to be taken from the runtime arguments, and a
    warning is emitted because the JAX implementation will likely fail in
    that situation.

    Parameters
    ----------
    op
        The ``Split`` op being converted; its ``len_splits`` attribute is
        used to validate the number of split sizes.
    node
        The node applying ``op``; its inputs are unpacked as
        ``(x, axis, splits)``.
    **kwargs
        Accepted for compatibility with the dispatch signature; unused here.

    Returns
    -------
    callable
        ``split(x, axis, splits)``, which returns ``jnp.split`` applied to
        ``x`` at the cumulative split positions along the chosen axis.
        It raises ``ValueError`` when the number of split sizes differs
        from ``op.len_splits``, when the sizes do not sum to the length of
        ``x`` along the axis, or when any size is negative.
    """
    _, axis_var, splits_var = node.inputs

    # Resolve the axis to a constant if the graph allows it.
    try:
        static_axis = get_scalar_constant_value(axis_var)
    except NotScalarConstantError:
        static_axis = None
        warnings.warn(
            "Split node does not have constant axis. Jax implementation will likely fail"
        )

    # Resolve every split size to a constant if the graph allows it.
    try:
        n_entries = get_vector_length(splits_var)
        static_splits = np.array(
            [get_scalar_constant_value(splits_var[i]) for i in range(n_entries)]
        )
    except (ValueError, NotScalarConstantError):
        static_splits = None
        warnings.warn(
            "Split node does not have constant split positions. Jax implementation will likely fail"
        )

    def split(x, axis, splits):
        # Constants found at dispatch time take precedence over runtime values.
        if static_axis is not None:
            axis = static_axis

        if static_splits is None:
            cumsum_splits = jnp.cumsum(splits[:-1])
        else:
            splits = static_splits
            cumsum_splits = np.cumsum(splits[:-1])

        if len(splits) != op.len_splits:
            raise ValueError("Length of splits is not equal to n_splits")

        # The sum is taken before the shape is indexed, matching the order
        # in which these two expressions were originally evaluated.
        total_size = np.sum(splits)
        axis_length = x.shape[axis]
        if total_size != axis_length:
            raise ValueError(
                f"Split sizes do not sum up to input length along axis: {axis_length}"
            )

        if np.any(splits < 0):
            raise ValueError("Split sizes cannot be negative")

        return jnp.split(x, cumsum_splits, axis=axis)

    return split
@jax_funcify.register(ExtractDiag) @jax_funcify.register(ExtractDiag)
def jax_funcify_ExtractDiag(op, **kwargs): def jax_funcify_ExtractDiag(op, **kwargs):
offset = op.offset offset = op.offset
......
import jax.errors
import numpy as np import numpy as np
import pytest import pytest
import pytensor
import pytensor.tensor.basic as at import pytensor.tensor.basic as at
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value from pytensor.graph.op import get_test_value
from pytensor.tensor.type import matrix, scalar, vector from pytensor.tensor.type import iscalar, matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -102,6 +104,80 @@ def test_jax_Join(): ...@@ -102,6 +104,80 @@ def test_jax_Join():
) )
class TestJaxSplit:
    """Tests for the JAX dispatch of the ``Split`` op."""

    def test_basic(self):
        """Compare JAX and Python results for splits with constant sizes."""
        test_input = np.zeros((6, 4)).astype(config.floatX)

        # Split sizes given as plain Python integers.
        x = matrix("a")
        outputs = at.split(x, splits_size=[1, 2, 3], n_splits=3, axis=0)
        compare_jax_and_py(FunctionGraph([x], outputs), [test_input])

        # One split size is derived from the statically known first dimension.
        x = matrix("a", shape=(6, None))
        outputs = at.split(x, splits_size=[2, x.shape[0] - 2], n_splits=2, axis=0)
        compare_jax_and_py(FunctionGraph([x], outputs), [test_input])

    def test_runtime_errors(self):
        """Check the ``ValueError`` raised for each kind of invalid split."""
        x = matrix("a")
        six_rows = np.zeros((6, 4), dtype=pytensor.config.floatX)
        seven_rows = np.zeros((7, 4), dtype=pytensor.config.floatX)

        # More split sizes than n_splits.
        outputs = at.split(x, splits_size=[2, 2, 2], n_splits=2, axis=0)
        jax_fn = pytensor.function([x], outputs, mode="JAX")
        with pytest.raises(
            ValueError, match="Length of splits is not equal to n_splits"
        ):
            jax_fn(six_rows)

        # Fewer split sizes than n_splits.
        outputs = at.split(x, splits_size=[2, 4], n_splits=3, axis=0)
        jax_fn = pytensor.function([x], outputs, mode="JAX")
        with pytest.raises(
            ValueError, match="Length of splits is not equal to n_splits"
        ):
            jax_fn(six_rows)

        # Split sizes add up to 6 but the input has 7 rows.
        outputs = at.split(x, splits_size=[2, 4], n_splits=2, axis=0)
        jax_fn = pytensor.function([x], outputs, mode="JAX")
        with pytest.raises(
            ValueError, match="Split sizes do not sum up to input length along axis: 7"
        ):
            jax_fn(seven_rows)

        # Sizes add up to the input length but one of them is negative.
        outputs = at.split(x, splits_size=[2, -4, 8], n_splits=3, axis=0)
        jax_fn = pytensor.function([x], outputs, mode="JAX")
        with pytest.raises(
            ValueError,
            match="Split sizes cannot be negative",
        ):
            jax_fn(six_rows)

    def test_jax_split_not_supported(self):
        """Check the warnings and failures for non-constant splits or axis."""
        x = matrix("a", shape=(6, None))

        # A split size that depends on a dimension with unknown static length.
        outputs = at.split(x, splits_size=[2, x.shape[1] - 2], n_splits=2, axis=1)
        with pytest.warns(
            UserWarning, match="Split node does not have constant split positions."
        ):
            jax_fn = pytensor.function([x], outputs, mode="JAX")
        # An informative ConcretizationTypeError is raised, but an
        # AttributeError surfaces instead of it.
        with pytest.raises(AttributeError):
            jax_fn(np.zeros((6, 4), dtype=pytensor.config.floatX))

        # The axis is a symbolic input rather than a constant.
        axis_input = iscalar("split_axis")
        outputs = at.split(x, splits_size=[2, 4], n_splits=2, axis=axis_input)
        with pytest.warns(UserWarning, match="Split node does not have constant axis."):
            jax_fn = pytensor.function([x, axis_input], outputs, mode="JAX")
        with pytest.raises(jax.errors.TracerIntegerConversionError):
            jax_fn(np.zeros((6, 6), dtype=pytensor.config.floatX), 0)
def test_jax_eye(): def test_jax_eye():
"""Tests jaxification of the Eye operator""" """Tests jaxification of the Eye operator"""
out = at.eye(3) out = at.eye(3)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论