提交 b2c62589 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Generalize and rename `local_reduce_chain`

上级 5b9c07ec
......@@ -100,7 +100,11 @@ from pytensor.tensor.type import (
values_eq_approx_remove_inf_nan,
values_eq_approx_remove_nan,
)
from pytensor.tensor.variable import TensorConstant, get_unique_constant_value
from pytensor.tensor.variable import (
TensorConstant,
TensorVariable,
get_unique_constant_value,
)
def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
......@@ -1575,42 +1579,48 @@ def local_sum_prod_all_to_none(fgraph, node):
@register_canonicalize
@node_rewriter([CAReduce])
def local_reduce_chain(fgraph, node) -> list[TensorVariable] | None:
    """Merge two nested `CAReduce` nodes that apply the same scalar op.

    Prod(Prod()) -> single Prod()
    or
    Sum(Sum()) -> single Sum()
    or any CAReduce(CAReduce(x)) of the same type.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (used to inspect clients).
    node
        The outer `CAReduce` apply node.

    Returns
    -------
    list[TensorVariable] | None
        A single-element list with the merged reduction, or ``None`` when
        the rewrite does not apply.
    """
    [inner_reduce] = node.inputs
    if not (inner_reduce.owner and isinstance(inner_reduce.owner.op, CAReduce)):
        return None

    # Don't apply rewrite if inner_reduce is used elsewhere: merging would
    # force the other clients' computation to be duplicated.
    if len(fgraph.clients[inner_reduce]) > 1:
        return None

    # Only reductions with the same scalar op (e.g. Sum/Sum, Prod/Prod,
    # All/All) can be merged; Sum(Prod(x)) must be left alone.
    outer_op: CAReduce = node.op
    inner_op = inner_reduce.owner.op
    if outer_op.scalar_op != inner_op.scalar_op:
        return None

    outer_axis = outer_op.axis
    inner_axis = inner_op.axis
    [x] = inner_reduce.owner.inputs
    # If either reduction is over all axes (axis=None), the combination
    # reduces over everything, so a single full reduction suffices.
    if outer_axis is None or inner_axis is None:
        return [outer_op.clone(axis=None)(x)]

    # Merge axes: the outer axes index the already-reduced tensor, so each
    # one must be shifted past the inner axes to recover its position in
    # the original input `x`.
    newaxis = list(inner_axis)
    for i in outer_axis:
        new_i = i
        for ii in inner_axis:
            if new_i >= ii:
                new_i += 1
        assert new_i not in newaxis
        newaxis.append(new_i)

    assert len(newaxis) == len(inner_axis) + len(outer_axis)
    return [outer_op.clone(axis=sorted(newaxis))(x)]
@register_canonicalize
......
......@@ -101,6 +101,7 @@ from pytensor.tensor.rewriting.math import (
local_grad_log_erfc_neg,
local_greedy_distributor,
local_mul_canonizer,
local_reduce_chain,
local_sum_prod_of_mul_or_div,
mul_canonizer,
parse_mul_tree,
......@@ -2497,6 +2498,168 @@ class TestLocalMergeSwitchSameCond:
assert debugprint(g, file="str").count("Switch") == 1
class TestReduceChain:
    """Tests for `local_reduce_chain`: merging nested same-op `CAReduce` nodes."""

    def setup_method(self):
        # Both canonicalize and specialize are needed for the chain rewrite
        # (and follow-ups) to fully collapse the graph.
        self.mode = get_default_mode().including("canonicalize", "specialize")

    def test_local_sum_prod_all_to_none(self):
        """Full reductions (axis=None or all axes listed) collapse to one node."""
        a = tensor3()
        input = np.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
        # test sum
        f = function([a], a.sum(), mode=self.mode)
        assert len(f.maker.fgraph.apply_nodes) == 1
        utt.assert_allclose(f(input), input.sum())
        # test prod
        f = function([a], a.prod(), mode=self.mode)
        assert len(f.maker.fgraph.apply_nodes) == 1
        utt.assert_allclose(f(input), input.prod())
        # test sum over an explicit list of all axes
        f = function([a], a.sum([0, 1, 2]), mode=self.mode)
        assert len(f.maker.fgraph.apply_nodes) == 1
        utt.assert_allclose(f(input), input.sum())
        # test prod over an explicit list of all axes
        f = function([a], a.prod([0, 1, 2]), mode=self.mode)
        assert len(f.maker.fgraph.apply_nodes) == 1
        utt.assert_allclose(f(input), input.prod())

        # A chain of three partial sums over everything also collapses to one.
        f = function([a], a.sum(0).sum(0).sum(0), mode=self.mode)
        assert len(f.maker.fgraph.apply_nodes) == 1
        utt.assert_allclose(f(input), input.sum())

    def test_local_sum_sum_prod_prod(self):
        """Nested sum/sum and prod/prod merge; mixed sum/prod must not."""
        a = tensor3()
        input = np.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
        # Pairs of (outer axes, inner axes) to chain, including tuple axes.
        dims = [
            (0, 0),
            (1, 0),
            (2, 0),
            (0, 1),
            (1, 1),
            (2, 1),
            ((0, 1), 0),
            ((1, 2), 0),
            (0, (0, 1)),
            (1, (0, 1)),
            (2, (0, 1)),
        ]

        def my_prod(data, d, dd):
            # Reference prod; handles d or dd being a tuple of 2 dimensions.
            if not isinstance(d, tuple) and not isinstance(dd, tuple):
                return data.prod(d).prod(dd)
            if isinstance(d, tuple):
                d = sorted(d)
                return data.prod(d[1]).prod(d[0]).prod(dd)
            else:
                dd = sorted(dd)
                return data.prod(d).prod(dd[1]).prod(dd[0])

        def my_sum(data, d, dd):
            # Reference sum; handles d or dd being a tuple of 2 dimensions.
            if not isinstance(d, tuple) and not isinstance(dd, tuple):
                return data.sum(d).sum(dd)
            if isinstance(d, tuple):
                d = sorted(d)
                return data.sum(d[1]).sum(d[0]).sum(dd)
            else:
                dd = sorted(dd)
                return data.sum(d).sum(dd[1]).sum(dd[0])

        def my_sum_prod(data, d, dd):
            # Reference sum-then-prod; handles d or dd being a tuple of 2 dimensions.
            if not isinstance(d, tuple) and not isinstance(dd, tuple):
                return data.sum(d).prod(dd)
            if isinstance(d, tuple):
                d = sorted(d)
                return data.sum(d[1]).sum(d[0]).prod(dd)
            else:
                dd = sorted(dd)
                return data.sum(d).prod(dd[1]).prod(dd[0])

        for d, dd in dims:
            expected = my_sum(input, d, dd)
            f = function([a], a.sum(d).sum(dd), mode=self.mode)
            utt.assert_allclose(f(input), expected)
            assert len(f.maker.fgraph.apply_nodes) == 1
        for d, dd in dims[:6]:
            f = function([a], a.sum(d).sum(dd).sum(0), mode=self.mode)
            utt.assert_allclose(f(input), input.sum(d).sum(dd).sum(0))
            assert len(f.maker.fgraph.apply_nodes) == 1
        for d in [0, 1, 2]:
            f = function([a], a.sum(d).sum(None), mode=self.mode)
            utt.assert_allclose(f(input), input.sum(d).sum())
            assert len(f.maker.fgraph.apply_nodes) == 1
        f = function([a], a.sum(None).sum(), mode=self.mode)
        utt.assert_allclose(f(input), input.sum())
        assert len(f.maker.fgraph.apply_nodes) == 1

        # test prod
        for d, dd in dims:
            expected = my_prod(input, d, dd)
            f = function([a], a.prod(d).prod(dd), mode=self.mode)
            utt.assert_allclose(f(input), expected)
            assert len(f.maker.fgraph.apply_nodes) == 1
        for d, dd in dims[:6]:
            f = function([a], a.prod(d).prod(dd).prod(0), mode=self.mode)
            utt.assert_allclose(f(input), input.prod(d).prod(dd).prod(0))
            assert len(f.maker.fgraph.apply_nodes) == 1
        for d in [0, 1, 2]:
            f = function([a], a.prod(d).prod(None), mode=self.mode)
            utt.assert_allclose(f(input), input.prod(d).prod())
            assert len(f.maker.fgraph.apply_nodes) == 1
        f = function([a], a.prod(None).prod(), mode=self.mode)
        utt.assert_allclose(f(input), input.prod())
        assert len(f.maker.fgraph.apply_nodes) == 1

        # Test that sum prod didn't get rewritten: two nodes must remain.
        for d, dd in dims:
            expected = my_sum_prod(input, d, dd)
            f = function([a], a.sum(d).prod(dd), mode=self.mode)
            utt.assert_allclose(f(input), expected)
            assert len(f.maker.fgraph.apply_nodes) == 2
        for d, dd in dims[:6]:
            f = function([a], a.sum(d).prod(dd).prod(0), mode=self.mode)
            utt.assert_allclose(f(input), input.sum(d).prod(dd).prod(0))
            assert len(f.maker.fgraph.apply_nodes) == 2
        for d in [0, 1, 2]:
            f = function([a], a.sum(d).prod(None), mode=self.mode)
            utt.assert_allclose(f(input), input.sum(d).prod())
            assert len(f.maker.fgraph.apply_nodes) == 2
        # sum(None) yields a scalar, so the trailing prod is a no-op and
        # the whole graph still collapses to one node.
        f = function([a], a.sum(None).prod(), mode=self.mode)
        utt.assert_allclose(f(input), input.sum())
        assert len(f.maker.fgraph.apply_nodes) == 1

    def test_local_sum_sum_int8(self):
        """Test that `local_sum_sum` works when combining two sums on an int8 array.

        This is a regression test for ticket gh-356.
        """
        x = tensor3(dtype="int8")
        y = x.sum(axis=0).sum(axis=1)

        with config.change_flags(on_opt_error="raise"):
            # This compilation would fail prior to fix.
            function([x], y)

    def test_local_sum_sum_dtype(self):
        """Test that `local_sum_sum` works when specifying dtypes manually."""
        x = tensor3(dtype="int8")
        y = x.sum(axis=0, dtype="int32").sum(axis=1, dtype="int64")

        with config.change_flags(on_opt_error="raise"):
            # This compilation would fail prior to fix.
            function([x], y)

    def test_all(self):
        # The generalized rewrite also applies to boolean All reductions;
        # all(axis=-1) then all(axis=0) merges into all(axis=(0, 2)).
        x = tensor3(dtype=bool)
        out = x.all(axis=-1).all(axis=0)
        fg = FunctionGraph([x], [out], clone=False)
        [new_out] = local_reduce_chain.transform(fg, out.owner)
        assert equal_computations([new_out], [x.all(axis=(0, 2))])
class TestLocalSumProd:
"""Test sum/prod rewrites."""
......@@ -2813,133 +2976,6 @@ class TestLocalSumProd:
rewritten_out_fn(*test_vals),
)
def test_local_sum_prod_all_to_none(self):
    """Full reductions (axis=None or all axes listed) collapse to one node."""
    a = tensor3()
    input = np.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
    # test sum
    f = function([a], a.sum(), mode=self.mode)
    assert len(f.maker.fgraph.apply_nodes) == 1
    utt.assert_allclose(f(input), input.sum())
    # test prod
    f = function([a], a.prod(), mode=self.mode)
    assert len(f.maker.fgraph.apply_nodes) == 1
    utt.assert_allclose(f(input), input.prod())
    # test sum over an explicit list of all axes
    f = function([a], a.sum([0, 1, 2]), mode=self.mode)
    assert len(f.maker.fgraph.apply_nodes) == 1
    utt.assert_allclose(f(input), input.sum())
    # test prod over an explicit list of all axes
    f = function([a], a.prod([0, 1, 2]), mode=self.mode)
    assert len(f.maker.fgraph.apply_nodes) == 1
    utt.assert_allclose(f(input), input.prod())

    # A chain of three partial sums over everything also collapses to one.
    f = function([a], a.sum(0).sum(0).sum(0), mode=self.mode)
    assert len(f.maker.fgraph.apply_nodes) == 1
    utt.assert_allclose(f(input), input.sum())
def test_local_sum_sum_prod_prod(self):
    """Nested sum/sum and prod/prod merge; mixed sum/prod must not."""
    a = tensor3()
    input = np.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
    # Pairs of (outer axes, inner axes) to chain, including tuple axes.
    dims = [
        (0, 0),
        (1, 0),
        (2, 0),
        (0, 1),
        (1, 1),
        (2, 1),
        ((0, 1), 0),
        ((1, 2), 0),
        (0, (0, 1)),
        (1, (0, 1)),
        (2, (0, 1)),
    ]

    def my_prod(data, d, dd):
        # Reference prod; handles d or dd being a tuple of 2 dimensions.
        if not isinstance(d, tuple) and not isinstance(dd, tuple):
            return data.prod(d).prod(dd)
        if isinstance(d, tuple):
            d = sorted(d)
            return data.prod(d[1]).prod(d[0]).prod(dd)
        else:
            dd = sorted(dd)
            return data.prod(d).prod(dd[1]).prod(dd[0])

    def my_sum(data, d, dd):
        # Reference sum; handles d or dd being a tuple of 2 dimensions.
        if not isinstance(d, tuple) and not isinstance(dd, tuple):
            return data.sum(d).sum(dd)
        if isinstance(d, tuple):
            d = sorted(d)
            return data.sum(d[1]).sum(d[0]).sum(dd)
        else:
            dd = sorted(dd)
            return data.sum(d).sum(dd[1]).sum(dd[0])

    def my_sum_prod(data, d, dd):
        # Reference sum-then-prod; handles d or dd being a tuple of 2 dimensions.
        if not isinstance(d, tuple) and not isinstance(dd, tuple):
            return data.sum(d).prod(dd)
        if isinstance(d, tuple):
            d = sorted(d)
            return data.sum(d[1]).sum(d[0]).prod(dd)
        else:
            dd = sorted(dd)
            return data.sum(d).prod(dd[1]).prod(dd[0])

    for d, dd in dims:
        expected = my_sum(input, d, dd)
        f = function([a], a.sum(d).sum(dd), mode=self.mode)
        utt.assert_allclose(f(input), expected)
        assert len(f.maker.fgraph.apply_nodes) == 1
    for d, dd in dims[:6]:
        f = function([a], a.sum(d).sum(dd).sum(0), mode=self.mode)
        utt.assert_allclose(f(input), input.sum(d).sum(dd).sum(0))
        assert len(f.maker.fgraph.apply_nodes) == 1
    for d in [0, 1, 2]:
        f = function([a], a.sum(d).sum(None), mode=self.mode)
        utt.assert_allclose(f(input), input.sum(d).sum())
        assert len(f.maker.fgraph.apply_nodes) == 1
    f = function([a], a.sum(None).sum(), mode=self.mode)
    utt.assert_allclose(f(input), input.sum())
    assert len(f.maker.fgraph.apply_nodes) == 1

    # test prod
    for d, dd in dims:
        expected = my_prod(input, d, dd)
        f = function([a], a.prod(d).prod(dd), mode=self.mode)
        utt.assert_allclose(f(input), expected)
        assert len(f.maker.fgraph.apply_nodes) == 1
    for d, dd in dims[:6]:
        f = function([a], a.prod(d).prod(dd).prod(0), mode=self.mode)
        utt.assert_allclose(f(input), input.prod(d).prod(dd).prod(0))
        assert len(f.maker.fgraph.apply_nodes) == 1
    for d in [0, 1, 2]:
        f = function([a], a.prod(d).prod(None), mode=self.mode)
        utt.assert_allclose(f(input), input.prod(d).prod())
        assert len(f.maker.fgraph.apply_nodes) == 1
    f = function([a], a.prod(None).prod(), mode=self.mode)
    utt.assert_allclose(f(input), input.prod())
    assert len(f.maker.fgraph.apply_nodes) == 1

    # Test that sum prod didn't get rewritten: two nodes must remain.
    for d, dd in dims:
        expected = my_sum_prod(input, d, dd)
        f = function([a], a.sum(d).prod(dd), mode=self.mode)
        utt.assert_allclose(f(input), expected)
        assert len(f.maker.fgraph.apply_nodes) == 2
    for d, dd in dims[:6]:
        f = function([a], a.sum(d).prod(dd).prod(0), mode=self.mode)
        utt.assert_allclose(f(input), input.sum(d).prod(dd).prod(0))
        assert len(f.maker.fgraph.apply_nodes) == 2
    for d in [0, 1, 2]:
        f = function([a], a.sum(d).prod(None), mode=self.mode)
        utt.assert_allclose(f(input), input.sum(d).prod())
        assert len(f.maker.fgraph.apply_nodes) == 2
    # sum(None) yields a scalar, so the trailing prod is a no-op and the
    # whole graph still collapses to one node.
    f = function([a], a.sum(None).prod(), mode=self.mode)
    utt.assert_allclose(f(input), input.sum())
    assert len(f.maker.fgraph.apply_nodes) == 1
def test_local_sum_prod_alloc(self):
a = dtensor3()
input = np.asarray(np.arange(2 * 3 * 4).reshape(2, 3, 4), dtype="float64")
......@@ -3005,29 +3041,6 @@ class TestLocalSumProd:
assert topo[-1].op == pt.alloc
assert not any(isinstance(node.op, Sum) for node in topo)
def test_local_sum_sum_int8(self):
    """Test that `local_sum_sum` works when combining two sums on an int8 array.

    This is a regression test for ticket gh-356.
    """
    x = tensor3(dtype="int8")
    y = x.sum(axis=0).sum(axis=1)

    with config.change_flags(on_opt_error="raise"):
        # This compilation would fail prior to fix.
        function([x], y)
def test_local_sum_sum_dtype(self):
    """Test that `local_sum_sum` works when specifying dtypes manually."""
    x = tensor3(dtype="int8")
    y = x.sum(axis=0, dtype="int32").sum(axis=1, dtype="int64")

    with config.change_flags(on_opt_error="raise"):
        # This compilation would fail prior to fix.
        function([x], y)
def test_local_sum_prod_mul_by_scalar_stack_trace(self):
"""Test that stack trace is copied over correctly for `local_sum_prod_mul_by_scalar`."""
m0 = (
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论