提交 0087e562 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Ricardo Vieira

Add rewrites to re-express boolean indexing logic

上级 37474223
import pytensor.tensor.rewriting.basic
import pytensor.tensor.rewriting.elemwise
import pytensor.tensor.rewriting.extra_ops

# Register JAX specializations
import pytensor.tensor.rewriting.jax

import pytensor.tensor.rewriting.math
import pytensor.tensor.rewriting.shape
import pytensor.tensor.rewriting.special
......
from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.tensor.var import TensorVariable
import pytensor.tensor as at
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor
from pytensor.tensor.math import Sum
@node_rewriter([AdvancedIncSubtensor])
def boolean_indexing_set_or_inc(fgraph, node):
    """Re-express boolean set/inc-subtensor indexing with a `Switch`.

    JAX cannot JIT-compile functions that use boolean indexing to set values
    in an array. A workaround is to re-express this logic using
    `jax.numpy.where`. This rewrite allows to improve upon JAX's API.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (unused; required by the
        `node_rewriter` interface).
    node
        The `AdvancedIncSubtensor` node under consideration.

    Returns
    -------
    The outputs of the replacement graph, or ``None`` when the rewrite
    does not apply.
    """
    op = node.op

    # Inputs are (x, y, *indices). This rewrite only handles the case of a
    # single index that is a full boolean mask; with multiple advanced
    # indices, inputs[2] would be just the first index and the replacement
    # below would be incorrect.
    if len(node.inputs) != 3:
        return None
    x, y, cond = node.inputs

    # Only boolean-tensor indexing can be re-expressed as a `where`.
    if not isinstance(cond, TensorVariable):
        return None
    if cond.type.dtype != "bool":
        return None

    if op.set_instead_of_inc:
        # x[cond] = y  ->  where(cond, y, x)
        out = at.where(cond, y, x)
    else:
        # x[cond] += y  ->  where(cond, x + y, x)
        out = at.where(cond, x + y, x)
    return out.owner.outputs
# Register the rewrite under the "jax" tag so it is applied when compiling
# for the JAX backend.  `in2out` lifts the node rewriter into a graph
# rewriter; position=100 places it in the rewrite database — presumably
# after the main optimization passes (confirm against optdb ordering).
optdb.register(
    "jax_boolean_indexing_set_or_inc", in2out(boolean_indexing_set_or_inc), "jax", position=100
)
@node_rewriter([Sum])
def boolean_indexing_sum(fgraph, node):
    """Replace the sum of a boolean-indexed `AdvancedSubtensor` with a `Switch`.

    JAX cannot JIT-compile functions that use boolean indexing, but can
    compile those expressions that can be re-expressed using
    `jax.numpy.where`. This rewrite re-expresses the model on the behalf of
    the user and thus allows to improve upon JAX's API.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (unused; required by the
        `node_rewriter` interface).
    node
        The `Sum` node under consideration.

    Returns
    -------
    The outputs of the replacement graph, or ``None`` when the rewrite
    does not apply.
    """
    operand = node.inputs[0]
    if not isinstance(operand, TensorVariable):
        return None

    # The summed operand must itself be the output of an `AdvancedSubtensor`.
    if operand.owner is None:
        return None
    if not isinstance(operand.owner.op, AdvancedSubtensor):
        return None

    # Only handle a single index that is a full boolean mask; with multiple
    # advanced indices, inputs[1] would be just the first index and the
    # replacement below would be incorrect.
    if len(operand.owner.inputs) != 2:
        return None
    x, cond = operand.owner.inputs

    if not isinstance(cond, TensorVariable):
        return None
    if cond.type.dtype != "bool":
        return None

    # sum(x[cond])  ->  sum(where(cond, x, 0)): masked-out entries
    # contribute zero to the sum.
    out = at.sum(at.where(cond, x, 0))
    return out.owner.outputs
# Register the rewrite under the "jax" tag so it is applied when compiling
# for the JAX backend.  `in2out` lifts the node rewriter into a graph
# rewriter; position=100 places it in the rewrite database — presumably
# after the main optimization passes (confirm against optdb ordering).
optdb.register(
    "jax_boolean_indexing_sum", in2out(boolean_indexing_sum), "jax", position=100
)
...@@ -80,15 +80,21 @@ def test_jax_Subtensor_boolean_mask(): ...@@ -80,15 +80,21 @@ def test_jax_Subtensor_boolean_mask():
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
@pytest.mark.xfail(
reason="Re-expressible boolean logic. We need a rewrite PyTensor-side."
)
def test_jax_Subtensor_boolean_mask_reexpressible(): def test_jax_Subtensor_boolean_mask_reexpressible():
"""Some boolean logic can be re-expressed and JIT-compiled""" """Summing values with boolean indexing.
x_at = at.arange(-5, 5)
This test ensures that the sum of an `AdvancedSubtensor` `Op`s with boolean
indexing is replaced with the sum of an equivalent `Switch` `Op`, using the
`jax_boolean_indexing_sum` rewrite.
JAX forces users to re-express this logic manually, so this is an
improvement over its user interface.
"""
x_at = at.vector("x")
out_at = x_at[x_at < 0].sum() out_at = x_at[x_at < 0].sum()
out_fg = FunctionGraph([], [out_at]) out_fg = FunctionGraph([x_at], [out_at])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [np.arange(-5, 5).astype(config.floatX)])
def test_jax_IncSubtensor(): def test_jax_IncSubtensor():
...@@ -177,42 +183,42 @@ def test_jax_IncSubtensor(): ...@@ -177,42 +183,42 @@ def test_jax_IncSubtensor():
out_fg = FunctionGraph([], [out_at]) out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
@pytest.mark.xfail( out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at)
reason="Re-expressible boolean logic. We need a rewrite PyTensor-side to remove the DimShuffle."
)
def test_jax_IncSubtensor_boolean_mask_reexpressible():
"""Some boolean logic can be re-expressed and JIT-compiled"""
rng = np.random.default_rng(213234)
x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX))
mask_at = at.as_tensor(x_np) > 0
out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at]) out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
mask_at = at.as_tensor(x_np) > 0 st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
out_at = at_subtensor.inc_subtensor(x_at[mask_at], 1.0) out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at]) out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
def test_jax_IncSubtensors_unsupported(): def test_jax_IncSubtensor_boolean_indexing_reexpressible():
"""Setting or incrementing values with boolean indexing.
This test ensures that `AdvancedIncSubtensor` `Op`s with boolean indexing is
replaced with an equivalent `Switch` `Op`, using the
`jax_boolean_indexing_set_of_inc` rewrite.
JAX forces users to re-express this logic manually, so this is an
improvement over its user interface.
"""
rng = np.random.default_rng(213234) rng = np.random.default_rng(213234)
x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX) x_np = rng.uniform(-1, 1, size=(4, 5)).astype(config.floatX)
x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX))
st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3]) x_at = at.matrix("x")
out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at) mask_at = at.as_tensor(x_at) > 0
out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at]) out_fg = FunctionGraph([x_at], [out_at])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [x_np])
st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3]) mask_at = at.as_tensor(x_at) > 0
out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at) out_at = at_subtensor.inc_subtensor(x_at[mask_at], 1.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at]) out_fg = FunctionGraph([x_at], [out_at])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [x_np])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论