提交 2086aeb8 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Generalize and simplify `local_reduce_join`

上级 b2c62589
......@@ -91,6 +91,7 @@ from pytensor.tensor.rewriting.basic import (
register_uncanonicalize,
register_useless,
)
from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
from pytensor.tensor.shape import Shape, Shape_i
from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.type import (
......@@ -1628,66 +1629,53 @@ def local_reduce_chain(fgraph, node) -> list[TensorVariable] | None:
@node_rewriter([CAReduce])
def local_reduce_join(fgraph, node):
"""
CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
CAReduce{scalar.op}(Join(axis=x, a, b), axis=x) -> Elemwise{scalar.op}(a, b)
Notes
-----
Supported scalar.op are Maximum, Minimum in some cases and Add and Mul in
all cases.
Currently we must reduce on axis 0. It is probably extensible to the case
where we join and reduce on the same set of axes.
When a, b have a dim length of 1 along the join axis
"""
if node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join):
join_node = node.inputs[0].owner
if extract_constant(join_node.inputs[0], only_process_constants=True) != 0:
return
if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join)):
return None
if isinstance(node.op.scalar_op, ps.ScalarMaximum | ps.ScalarMinimum):
# Support only 2 inputs for now
if len(join_node.inputs) != 3:
return
elif not isinstance(node.op.scalar_op, ps.Add | ps.Mul):
return
elif len(join_node.inputs) <= 2:
# This is a useless join that should get removed by another rewrite?
return
[joined_out] = node.inputs
joined_node = joined_out.owner
join_axis_tensor, *joined_inputs = joined_node.inputs
new_inp = []
for inp in join_node.inputs[1:]:
inp = inp.owner
if not inp:
return
if not isinstance(inp.op, DimShuffle) or inp.op.new_order != (
"x",
*range(inp.inputs[0].ndim),
):
return
new_inp.append(inp.inputs[0])
ret = Elemwise(node.op.scalar_op)(*new_inp)
n_joined_inputs = len(joined_inputs)
if n_joined_inputs < 2:
# Let some other rewrite get rid of this useless Join
return None
if n_joined_inputs > 2 and not isinstance(node.op.scalar_op, ps.Add | ps.Mul):
# We don't rewrite if a single Elemwise cannot take all inputs at once
return None
if ret.dtype != node.outputs[0].dtype:
# The reduction changes the output dtype, so the Elemwise result would not match.
return
if not isinstance(join_axis_tensor, Constant):
return None
join_axis = join_axis_tensor.data
reduce_axis = node.op.axis
# Check whether reduction happens on joined axis
reduce_op = node.op
reduce_axis = reduce_op.axis
if reduce_axis is None:
reduce_axis = tuple(range(node.inputs[0].ndim))
if joined_out.type.ndim > 1:
return None
elif reduce_axis != (join_axis,):
return None
if len(reduce_axis) != 1 or 0 not in reduce_axis:
return
# Check all inputs are broadcastable along the join axis and squeeze those dims away
new_inputs = []
for inp in joined_inputs:
if not inp.type.broadcastable[join_axis]:
return None
# Most times inputs to join have an expand_dims, we eagerly clean up those here
new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
new_inputs.append(new_input)
# We add the new check late so as not to add an extra warning.
try:
join_axis = get_underlying_scalar_constant_value(
join_node.inputs[0], only_process_constants=True
)
ret = Elemwise(node.op.scalar_op)(*new_inputs)
if join_axis != reduce_axis[0]:
return
except NotScalarConstantError:
return
if ret.dtype != node.outputs[0].dtype:
# The reduction changes the output dtype, so the Elemwise result would not match.
return None
return [ret]
......
......@@ -3231,7 +3231,7 @@ class TestLocalSumProd:
class TestLocalReduce:
def setup_method(self):
self.mode = get_default_mode().including(
"canonicalize", "specialize", "uncanonicalize", "local_max_and_argmax"
"canonicalize", "specialize", "uncanonicalize"
)
def test_local_reduce_broadcast_all_0(self):
......@@ -3304,62 +3304,94 @@ class TestLocalReduce:
isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort()
)
def test_local_reduce_join(self):
class TestReduceJoin:
def setup_method(self):
self.mode = get_default_mode().including(
"canonicalize", "specialize", "uncanonicalize"
)
@pytest.mark.parametrize(
"op, nin", [(pt_sum, 3), (pt_max, 2), (pt_min, 2), (prod, 3)]
)
def test_local_reduce_join(self, op, nin):
vx = matrix()
vy = matrix()
vz = matrix()
x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)
# Test different reduction scalar operation
for out, res in [
(pt_max((vx, vy), 0), np.max((x, y), 0)),
(pt_min((vx, vy), 0), np.min((x, y), 0)),
(pt_sum((vx, vy, vz), 0), np.sum((x, y, z), 0)),
(prod((vx, vy, vz), 0), np.prod((x, y, z), 0)),
(prod((vx, vy.T, vz), 0), np.prod((x, y.T, z), 0)),
]:
f = function([vx, vy, vz], out, on_unused_input="ignore", mode=self.mode)
assert (f(x, y, z) == res).all(), out
inputs = (vx, vy, vz)[:nin]
test_values = (x, y, z)[:nin]
out = op(inputs, axis=0)
f = function(inputs, out, mode=self.mode)
np.testing.assert_allclose(
f(*test_values), getattr(np, op.__name__)(test_values, axis=0)
)
topo = f.maker.fgraph.toposort()
assert len(topo) <= 2, out
assert isinstance(topo[-1].op, Elemwise), out
assert len(topo) <= 2
assert isinstance(topo[-1].op, Elemwise)
def test_type(self):
# Test different axis for the join and the reduction
# We must force the dtype, otherwise this test will fail
# on 32 bit systems
A = shared(np.array([1, 2, 3, 4, 5], dtype="int64"))
f = function([], pt_sum(pt.stack([A, A]), axis=0), mode=self.mode)
utt.assert_allclose(f(), [2, 4, 6, 8, 10])
np.testing.assert_allclose(f(), [2, 4, 6, 8, 10])
topo = f.maker.fgraph.toposort()
assert isinstance(topo[-1].op, Elemwise)
# Test a case that was broken by an old PyTensor bug
f = function([], pt_sum(pt.stack([A, A]), axis=1), mode=self.mode)
utt.assert_allclose(f(), [15, 15])
np.testing.assert_allclose(f(), [15, 15])
topo = f.maker.fgraph.toposort()
assert not isinstance(topo[-1].op, Elemwise)
# This case could be rewritten
A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=1), mode=self.mode)
utt.assert_allclose(f(), [2, 4, 6, 8, 10])
np.testing.assert_allclose(f(), [2, 4, 6, 8, 10])
topo = f.maker.fgraph.toposort()
assert not isinstance(topo[-1].op, Elemwise)
A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=0), mode=self.mode)
utt.assert_allclose(f(), [15, 15])
np.testing.assert_allclose(f(), [15, 15])
topo = f.maker.fgraph.toposort()
assert not isinstance(topo[-1].op, Elemwise)
def test_not_supported_axis_none(self):
# Test that the rewrite does not crash in one case where it
# is not applied. Reported at
# https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion
vx = matrix()
vy = matrix()
vz = matrix()
x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)
out = pt_sum([vx, vy, vz], axis=None)
f = function([vx, vy, vz], out)
f = function([vx, vy, vz], out, mode=self.mode)
np.testing.assert_allclose(f(x, y, z), np.sum([x, y, z]))
def test_not_supported_unequal_shapes(self):
# Not the same shape along the join axis
vx = matrix(shape=(1, 3))
vy = matrix(shape=(2, 3))
x = np.asarray([[1, 0, 1]], dtype=config.floatX)
y = np.asarray([[4, 0, 1], [2, 1, 1]], dtype=config.floatX)
out = pt_sum(join(0, vx, vy), axis=0)
f = function([vx, vy], out, mode=self.mode)
np.testing.assert_allclose(
f(x, y), np.sum(np.concatenate([x, y], axis=0), axis=0)
)
def test_local_useless_adds():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论