Commit 2086aeb8 authored by Ricardo Vieira, committed by Ricardo Vieira

Generalize and simplify `local_reduce_join`

Parent b2c62589
...
@@ -91,6 +91,7 @@ from pytensor.tensor.rewriting.basic import (
     register_uncanonicalize,
     register_useless,
 )
+from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
 from pytensor.tensor.shape import Shape, Shape_i
 from pytensor.tensor.subtensor import Subtensor
 from pytensor.tensor.type import (
...
@@ -1628,68 +1629,55 @@ def local_reduce_chain(fgraph, node) -> list[TensorVariable] | None:
 @node_rewriter([CAReduce])
 def local_reduce_join(fgraph, node):
     """
-    CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
+    CAReduce{scalar.op}(Join(axis=x, a, b), axis=x) -> Elemwise{scalar.op}(a, b)

-    Notes
-    -----
-    Supported scalar.op are Maximum, Minimum in some cases and Add and Mul in
-    all cases.
-
-    Currently we must reduce on axis 0. It is probably extensible to the case
-    where we join and reduce on the same set of axis.
+    Applies when each of a, b has a dim length of 1 along the join axis.
     """
-    if node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join):
-        join_node = node.inputs[0].owner
-        if extract_constant(join_node.inputs[0], only_process_constants=True) != 0:
-            return
-
-        if isinstance(node.op.scalar_op, ps.ScalarMaximum | ps.ScalarMinimum):
-            # Support only 2 inputs for now
-            if len(join_node.inputs) != 3:
-                return
-        elif not isinstance(node.op.scalar_op, ps.Add | ps.Mul):
-            return
-        elif len(join_node.inputs) <= 2:
-            # This is a useless join that should get removed by another rewrite?
-            return
-
-        new_inp = []
-        for inp in join_node.inputs[1:]:
-            inp = inp.owner
-            if not inp:
-                return
-            if not isinstance(inp.op, DimShuffle) or inp.op.new_order != (
-                "x",
-                *range(inp.inputs[0].ndim),
-            ):
-                return
-            new_inp.append(inp.inputs[0])
-        ret = Elemwise(node.op.scalar_op)(*new_inp)
-
-        if ret.dtype != node.outputs[0].dtype:
-            # The reduction do something about the dtype.
-            return
-
-        reduce_axis = node.op.axis
-        if reduce_axis is None:
-            reduce_axis = tuple(range(node.inputs[0].ndim))
-
-        if len(reduce_axis) != 1 or 0 not in reduce_axis:
-            return
-
-        # We add the new check late to don't add extra warning.
-        try:
-            join_axis = get_underlying_scalar_constant_value(
-                join_node.inputs[0], only_process_constants=True
-            )
-
-            if join_axis != reduce_axis[0]:
-                return
-        except NotScalarConstantError:
-            return
-
-        return [ret]
+    if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Join)):
+        return None
+
+    [joined_out] = node.inputs
+    joined_node = joined_out.owner
+    join_axis_tensor, *joined_inputs = joined_node.inputs
+
+    n_joined_inputs = len(joined_inputs)
+    if n_joined_inputs < 2:
+        # Let some other rewrite get rid of this useless Join
+        return None
+    if n_joined_inputs > 2 and not isinstance(node.op.scalar_op, ps.Add | ps.Mul):
+        # We don't rewrite if a single Elemwise cannot take all inputs at once
+        return None
+
+    if not isinstance(join_axis_tensor, Constant):
+        return None
+    join_axis = join_axis_tensor.data
+
+    # Check whether the reduction happens on the joined axis
+    reduce_op = node.op
+    reduce_axis = reduce_op.axis
+    if reduce_axis is None:
+        if joined_out.type.ndim > 1:
+            return None
+    elif reduce_axis != (join_axis,):
+        return None
+
+    # Check all inputs are broadcastable along the join axis and squeeze those dims away
+    new_inputs = []
+    for inp in joined_inputs:
+        if not inp.type.broadcastable[join_axis]:
+            return None
+        # Most of the time inputs to join have an expand_dims; we eagerly clean those up here
+        new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
+        new_inputs.append(new_input)
+
+    ret = Elemwise(node.op.scalar_op)(*new_inputs)
+
+    if ret.dtype != node.outputs[0].dtype:
+        # The reduction does something to the dtype
+        return None
+
+    return [ret]
 @register_infer_shape
...
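For context, a minimal sketch (not part of the commit; names are illustrative) of the graph pattern the generalized rewrite targets. `pt.stack` builds `Join(0, expand_dims(a, 0), expand_dims(b, 0))`, so every joined input has length 1 along the join axis and the reduction can collapse into a single `Elemwise`, assuming the rewrite runs as part of the default compilation mode:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

a = pt.vector("a")
b = pt.vector("b")

# stack joins a and b along a new length-1 axis; reducing over that same
# axis is exactly the pattern local_reduce_join looks for
out = pt.max(pt.stack([a, b]), axis=0)

f = pytensor.function([a, b], out)
pytensor.dprint(f)  # expect a single Elemwise maximum, no Join/CAReduce

x = np.array([1.0, 5.0, 2.0])
y = np.array([4.0, 0.0, 3.0])
np.testing.assert_allclose(f(x, y), np.maximum(x, y))
```

With more than two joined inputs the rewrite only fires for `Add` and `Mul` (sum/prod), since a `maximum`/`minimum` `Elemwise` takes just two inputs at a time.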
...
@@ -3231,7 +3231,7 @@ class TestLocalSumProd:
 class TestLocalReduce:
     def setup_method(self):
         self.mode = get_default_mode().including(
-            "canonicalize", "specialize", "uncanonicalize", "local_max_and_argmax"
+            "canonicalize", "specialize", "uncanonicalize"
         )

     def test_local_reduce_broadcast_all_0(self):
...
@@ -3304,62 +3304,94 @@ class TestLocalReduce:
             isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort()
         )

-    def test_local_reduce_join(self):
+
+class TestReduceJoin:
+    def setup_method(self):
+        self.mode = get_default_mode().including(
+            "canonicalize", "specialize", "uncanonicalize"
+        )
+
+    @pytest.mark.parametrize(
+        "op, nin", [(pt_sum, 3), (pt_max, 2), (pt_min, 2), (prod, 3)]
+    )
+    def test_local_reduce_join(self, op, nin):
         vx = matrix()
         vy = matrix()
         vz = matrix()
         x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
         y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
         z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)

-        # Test different reduction scalar operation
-        for out, res in [
-            (pt_max((vx, vy), 0), np.max((x, y), 0)),
-            (pt_min((vx, vy), 0), np.min((x, y), 0)),
-            (pt_sum((vx, vy, vz), 0), np.sum((x, y, z), 0)),
-            (prod((vx, vy, vz), 0), np.prod((x, y, z), 0)),
-            (prod((vx, vy.T, vz), 0), np.prod((x, y.T, z), 0)),
-        ]:
-            f = function([vx, vy, vz], out, on_unused_input="ignore", mode=self.mode)
-            assert (f(x, y, z) == res).all(), out
-            topo = f.maker.fgraph.toposort()
-            assert len(topo) <= 2, out
-            assert isinstance(topo[-1].op, Elemwise), out
+        inputs = (vx, vy, vz)[:nin]
+        test_values = (x, y, z)[:nin]
+
+        out = op(inputs, axis=0)
+        f = function(inputs, out, mode=self.mode)
+        np.testing.assert_allclose(
+            f(*test_values), getattr(np, op.__name__)(test_values, axis=0)
+        )
+        topo = f.maker.fgraph.toposort()
+        assert len(topo) <= 2
+        assert isinstance(topo[-1].op, Elemwise)

+    def test_type(self):
         # Test different axis for the join and the reduction
-        # We must force the dtype, of otherwise, this tests will fail
+        # We must force the dtype, as otherwise this test will fail
         # on 32 bit systems
         A = shared(np.array([1, 2, 3, 4, 5], dtype="int64"))

         f = function([], pt_sum(pt.stack([A, A]), axis=0), mode=self.mode)
-        utt.assert_allclose(f(), [2, 4, 6, 8, 10])
+        np.testing.assert_allclose(f(), [2, 4, 6, 8, 10])
         topo = f.maker.fgraph.toposort()
         assert isinstance(topo[-1].op, Elemwise)

-        # Test a case that was bugged in a old PyTensor bug
+        # Test a case that was affected by an old PyTensor bug
         f = function([], pt_sum(pt.stack([A, A]), axis=1), mode=self.mode)
-        utt.assert_allclose(f(), [15, 15])
+        np.testing.assert_allclose(f(), [15, 15])
         topo = f.maker.fgraph.toposort()
         assert not isinstance(topo[-1].op, Elemwise)

         # This case could be rewritten
         A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
         f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=1), mode=self.mode)
-        utt.assert_allclose(f(), [2, 4, 6, 8, 10])
+        np.testing.assert_allclose(f(), [2, 4, 6, 8, 10])
         topo = f.maker.fgraph.toposort()
         assert not isinstance(topo[-1].op, Elemwise)

         A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
         f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=0), mode=self.mode)
-        utt.assert_allclose(f(), [15, 15])
+        np.testing.assert_allclose(f(), [15, 15])
         topo = f.maker.fgraph.toposort()
         assert not isinstance(topo[-1].op, Elemwise)

+    def test_not_supported_axis_none(self):
         # Test that the rewrite does not crash in one case where it
         # is not applied. Reported at
         # https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion
+        vx = matrix()
+        vy = matrix()
+        vz = matrix()
+        x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
+        y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
+        z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)
+
         out = pt_sum([vx, vy, vz], axis=None)
-        f = function([vx, vy, vz], out)
+        f = function([vx, vy, vz], out, mode=self.mode)
+        np.testing.assert_allclose(f(x, y, z), np.sum([x, y, z]))
+
+    def test_not_supported_unequal_shapes(self):
+        # Not the same shape along the join axis
+        vx = matrix(shape=(1, 3))
+        vy = matrix(shape=(2, 3))
+        x = np.asarray([[1, 0, 1]], dtype=config.floatX)
+        y = np.asarray([[4, 0, 1], [2, 1, 1]], dtype=config.floatX)
+        out = pt_sum(join(0, vx, vy), axis=0)
+        f = function([vx, vy], out, mode=self.mode)
+        np.testing.assert_allclose(
+            f(x, y), np.sum(np.concatenate([x, y], axis=0), axis=0)
+        )
 def test_local_useless_adds():
...
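One more note on the squeeze step in the rewrite above: inputs to `stack` usually arrive wrapped in an expand_dims (`DimShuffle`), and squeezing the join axis stacks a second `DimShuffle` on top, which is why the commit eagerly calls `apply_local_dimshuffle_lift` to collapse the pair. A rough sketch of that effect (illustrative only, reusing the internal helper the commit imports; exact printed output may differ by version):

```python
import pytensor.tensor as pt
from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift

a = pt.vector("a")

# What stack does internally: insert a new broadcastable axis up front
expanded = pt.expand_dims(a, 0)  # DimShuffle('x', 0) applied to a

# What the rewrite does: squeeze the join axis back out, leaving a
# DimShuffle-of-DimShuffle wrapper around the original variable
squeezed = expanded.squeeze(0)

# The eager lift merges the nested DimShuffles and drops the resulting
# identity, so this should hand back `a` itself rather than a wrapper
cleaned = apply_local_dimshuffle_lift(None, squeezed)
print(cleaned)  # expected: a
```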