提交 0fced9aa authored 作者: Ricardo Vieira 提交者: Ricardo Vieira

Detect cases where `boolean_indexing_sum` does not apply

上级 62f84f53
@@ -48,7 +48,7 @@ optdb.register(
 @node_rewriter([Sum])
 def boolean_indexing_sum(fgraph, node):
-    """Replace the sum of `AdvancedSubtensor` with boolean indexing.
+    """Replace the sum of `AdvancedSubtensor` with exclusively boolean indexing.

     JAX cannot JIT-compile functions that use boolean indexing, but can compile
     those expressions that can be re-expressed using `jax.numpy.where`. This
@@ -61,14 +61,21 @@ def boolean_indexing_sum(fgraph, node):
     if not isinstance(operand, TensorVariable):
         return

+    # If it's not a scalar reduction, it couldn't have been a pure boolean mask
+    if node.outputs[0].ndim != 0:
+        return
+
     if operand.owner is None:
         return

     if not isinstance(operand.owner.op, AdvancedSubtensor):
         return

-    x = operand.owner.inputs[0]
-    cond = operand.owner.inputs[1]
+    # Get out if AdvancedSubtensor has more than a single indexing operation
+    if len(operand.owner.inputs) > 2:
+        return
+
+    [x, cond] = operand.owner.inputs

     if not isinstance(cond, TensorVariable):
         return
@@ -76,6 +83,8 @@ def boolean_indexing_sum(fgraph, node):
     if not cond.type.dtype == "bool":
         return

+    # Output must be a scalar, since pure boolean indexing returns a vector
+    # No need to worry about axis
     out = at.sum(at.where(cond, x, 0))
     return out.owner.outputs
......
@@ -5,6 +5,7 @@ import pytensor.tensor as at
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
 from pytensor.tensor import subtensor as at_subtensor
+from pytensor.tensor.rewriting.jax import boolean_indexing_sum
 from tests.link.jax.test_basic import compare_jax_and_py
@@ -93,10 +94,22 @@ def test_jax_Subtensor_boolean_mask_reexpressible():
     improvement over its user interface.
     """
-    x_at = at.vector("x")
+    x_at = at.matrix("x")
     out_at = x_at[x_at < 0].sum()
     out_fg = FunctionGraph([x_at], [out_at])
-    compare_jax_and_py(out_fg, [np.arange(-5, 5).astype(config.floatX)])
+    compare_jax_and_py(out_fg, [np.arange(25).reshape(5, 5).astype(config.floatX)])
+
+
+def test_boolean_indexing_sum_not_applicable():
+    """Test that boolean_indexing_sum does not return an invalid replacement in cases where it doesn't apply."""
+    x = at.matrix("x")
+    out = x[x[:, 0] < 0, :].sum(axis=-1)
+    fg = FunctionGraph([x], [out])
+    assert boolean_indexing_sum.transform(fg, fg.outputs[0].owner) is None
+
+    out = x[x[:, 0] < 0, 0].sum()
+    fg = FunctionGraph([x], [out])
+    assert boolean_indexing_sum.transform(fg, fg.outputs[0].owner) is None
+
+
 def test_jax_IncSubtensor():
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论