提交 0fced9aa authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Detect cases where `boolean_indexing_sum` does not apply

上级 62f84f53
......@@ -48,7 +48,7 @@ optdb.register(
@node_rewriter([Sum])
def boolean_indexing_sum(fgraph, node):
"""Replace the sum of `AdvancedSubtensor` with boolean indexing.
"""Replace the sum of `AdvancedSubtensor` with exclusively boolean indexing.
JAX cannot JIT-compile functions that use boolean indexing, but can compile
those expressions that can be re-expressed using `jax.numpy.where`. This
......@@ -61,14 +61,21 @@ def boolean_indexing_sum(fgraph, node):
if not isinstance(operand, TensorVariable):
return
# If it's not a scalar reduction, it couldn't have been a pure boolean mask
if node.outputs[0].ndim != 0:
return
if operand.owner is None:
return
if not isinstance(operand.owner.op, AdvancedSubtensor):
return
x = operand.owner.inputs[0]
cond = operand.owner.inputs[1]
# Get out if AdvancedSubtensor has more than a single indexing operation
if len(operand.owner.inputs) > 2:
return
[x, cond] = operand.owner.inputs
if not isinstance(cond, TensorVariable):
return
......@@ -76,6 +83,8 @@ def boolean_indexing_sum(fgraph, node):
if not cond.type.dtype == "bool":
return
# Output must be a scalar, since pure boolean indexing returns a vector
# No need to worry about axis
out = at.sum(at.where(cond, x, 0))
return out.owner.outputs
......
......@@ -5,6 +5,7 @@ import pytensor.tensor as at
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import subtensor as at_subtensor
from pytensor.tensor.rewriting.jax import boolean_indexing_sum
from tests.link.jax.test_basic import compare_jax_and_py
......@@ -93,10 +94,22 @@ def test_jax_Subtensor_boolean_mask_reexpressible():
improvement over its user interface.
"""
x_at = at.vector("x")
x_at = at.matrix("x")
out_at = x_at[x_at < 0].sum()
out_fg = FunctionGraph([x_at], [out_at])
compare_jax_and_py(out_fg, [np.arange(-5, 5).astype(config.floatX)])
compare_jax_and_py(out_fg, [np.arange(25).reshape(5, 5).astype(config.floatX)])
def test_boolean_indexing_sum_not_applicable():
"""Test that boolean_indexing_sum does not return an invalid replacement in cases where it doesn't apply."""
x = at.matrix("x")
out = x[x[:, 0] < 0, :].sum(axis=-1)
fg = FunctionGraph([x], [out])
assert boolean_indexing_sum.transform(fg, fg.outputs[0].owner) is None
out = x[x[:, 0] < 0, 0].sum()
fg = FunctionGraph([x], [out])
assert boolean_indexing_sum.transform(fg, fg.outputs[0].owner) is None
def test_jax_IncSubtensor():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论