Commit 37474223, authored by Rémi Louf, committed by Ricardo Vieira

Simplify the `IncSubtensor` dispatcher

Parent commit: b3f12b26
import jax
from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
...@@ -33,7 +31,7 @@ slice length. ...@@ -33,7 +31,7 @@ slice length.
""" """
def assert_indices_jax_compatible(node, idx_list): def subtensor_assert_indices_jax_compatible(node, idx_list):
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.tensor.var import TensorVariable from pytensor.tensor.var import TensorVariable
...@@ -55,7 +53,7 @@ def assert_indices_jax_compatible(node, idx_list): ...@@ -55,7 +53,7 @@ def assert_indices_jax_compatible(node, idx_list):
def jax_funcify_Subtensor(op, node, **kwargs): def jax_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None) idx_list = getattr(op, "idx_list", None)
assert_indices_jax_compatible(node, idx_list) subtensor_assert_indices_jax_compatible(node, idx_list)
def subtensor_constant(x, *ilists): def subtensor_constant(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list) indices = indices_from_subtensor(ilists, idx_list)
...@@ -69,25 +67,19 @@ def jax_funcify_Subtensor(op, node, **kwargs): ...@@ -69,25 +67,19 @@ def jax_funcify_Subtensor(op, node, **kwargs):
@jax_funcify.register(IncSubtensor) @jax_funcify.register(IncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor1) @jax_funcify.register(AdvancedIncSubtensor1)
def jax_funcify_IncSubtensor(op, **kwargs): def jax_funcify_IncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None) idx_list = getattr(op, "idx_list", None)
if getattr(op, "set_instead_of_inc", False): if getattr(op, "set_instead_of_inc", False):
jax_fn = getattr(jax.ops, "index_update", None)
if jax_fn is None:
def jax_fn(x, indices, y): def jax_fn(x, indices, y):
return x.at[indices].set(y) return x.at[indices].set(y)
else: else:
jax_fn = getattr(jax.ops, "index_add", None)
if jax_fn is None:
def jax_fn(x, indices, y): def jax_fn(x, indices, y):
return x.at[indices].add(y) return x.at[indices].add(y)
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list): def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
indices = indices_from_subtensor(ilist, idx_list) indices = indices_from_subtensor(ilist, idx_list)
...@@ -100,23 +92,17 @@ def jax_funcify_IncSubtensor(op, **kwargs): ...@@ -100,23 +92,17 @@ def jax_funcify_IncSubtensor(op, **kwargs):
@jax_funcify.register(AdvancedIncSubtensor) @jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op, **kwargs): def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
if getattr(op, "set_instead_of_inc", False): if getattr(op, "set_instead_of_inc", False):
jax_fn = getattr(jax.ops, "index_update", None)
if jax_fn is None: def jax_fn(x, indices, y):
return x.at[indices].set(y)
def jax_fn(x, indices, y):
return x.at[indices].set(y)
else: else:
jax_fn = getattr(jax.ops, "index_add", None)
if jax_fn is None:
def jax_fn(x, indices, y): def jax_fn(x, indices, y):
return x.at[indices].add(y) return x.at[indices].add(y)
def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn): def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
return jax_fn(x, ilist, y) return jax_fn(x, ilist, y)
......
import numpy as np import numpy as np
import pytest import pytest
from jax._src.errors import NonConcreteBooleanIndexError
import pytensor.tensor as at import pytensor.tensor as at
from pytensor.configdefaults import config from pytensor.configdefaults import config
...@@ -179,7 +178,11 @@ def test_jax_IncSubtensor(): ...@@ -179,7 +178,11 @@ def test_jax_IncSubtensor():
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
def test_jax_IncSubtensors_unsupported(): @pytest.mark.xfail(
reason="Re-expressible boolean logic. We need a rewrite PyTensor-side to remove the DimShuffle."
)
def test_jax_IncSubtensor_boolean_mask_reexpressible():
"""Some boolean logic can be re-expressed and JIT-compiled"""
rng = np.random.default_rng(213234) rng = np.random.default_rng(213234)
x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX) x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)) x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX))
...@@ -188,30 +191,28 @@ def test_jax_IncSubtensors_unsupported(): ...@@ -188,30 +191,28 @@ def test_jax_IncSubtensors_unsupported():
out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0) out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at]) out_fg = FunctionGraph([], [out_at])
with pytest.raises( compare_jax_and_py(out_fg, [])
NonConcreteBooleanIndexError, match="Array boolean indices must be concrete"
):
compare_jax_and_py(out_fg, [])
mask_at = at.as_tensor_variable(x_np) > 0 mask_at = at.as_tensor(x_np) > 0
out_at = at_subtensor.set_subtensor(x_at[mask_at], 1.0) out_at = at_subtensor.inc_subtensor(x_at[mask_at], 1.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at]) out_fg = FunctionGraph([], [out_at])
with pytest.raises( compare_jax_and_py(out_fg, [])
NonConcreteBooleanIndexError, match="Array boolean indices must be concrete"
):
def test_jax_IncSubtensors_unsupported():
    """`set_subtensor`/`inc_subtensor` with mixed advanced+basic indexing compile to JAX.

    NOTE(review): the name says "unsupported", but after this commit both cases
    are expected to compile and match the Python backend (the old
    ``pytest.raises(IndexError, ...)`` guards were removed) — presumably the
    name was kept for git history; confirm whether it should be renamed.
    """
    rng = np.random.default_rng(213234)
    x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
    x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX))

    # set_subtensor through an advanced index combined with a basic slice:
    # x[[0, 2], 0, :3] -> AdvancedIncSubtensor with set_instead_of_inc=True
    st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
    out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at)
    assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_at])
    compare_jax_and_py(out_fg, [])

    # Same indexing pattern, but incrementing instead of setting.
    st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
    out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at)
    assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_at])
    compare_jax_and_py(out_fg, [])
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment