Commit 3b0a97b7, authored by Ricardo Vieira, committed by Ricardo Vieira

Extend local_sum_prod_of_mul rewrite to non-scalar terms

Also: * Separates the sum-of-negation rewrite * Fixes a bug in partial prod reduction
Parent 2a21580d
......@@ -1190,78 +1190,80 @@ def local_neg_to_mul(fgraph, node):
@register_specialize
@node_rewriter([Sum, Prod])
def local_sum_prod_of_mul(fgraph, node):
    """Factor multiplicative terms out of a `Sum` or `Prod`.

    sum(a * X) -> a * sum(X), when a is broadcasted along the sum dimensions

    or

    prod(a * X) -> (a ** size(X)) * prod(X)

    TODO: In the case where not all axis overlap with broadcast dimensions,
    consider introducing an outer reduction after factoring out the compatible reduced dimensions
    E.g. sum(arange(5) * X, axis=(0, 2)) -> sum(sum(X, axis=0) * arange(5), axis=1)
    """
    # TODO: if the thing inside the Sum is a division,
    # we should get at the numerator....

    [node_inps] = node.inputs
    if not (node_inps.owner and node_inps.owner.op == mul):
        return None

    # A reduction with `axis=None` contracts every dimension of the input.
    reduced_axes = node.op.axis
    if reduced_axes is None:
        reduced_axes = tuple(range(node_inps.type.ndim))

    # Separate terms that can be moved out of the Sum/Prod and those that cannot
    outer_terms = []
    inner_terms = []
    for term in node_inps.owner.inputs:
        term_bcast = term.type.broadcastable
        if all(term_bcast[i] for i in reduced_axes):
            # Term is broadcastable along all reduced axes, so it is a
            # constant factor with respect to the reduction and can move out.
            outer_terms.append(term.squeeze(reduced_axes))
        else:
            inner_terms.append(term)

    if not outer_terms:
        return None
    elif len(outer_terms) == 1:
        [outer_term] = outer_terms
    else:
        outer_term = mul(*outer_terms)

    if not inner_terms:
        inner_term = None
    elif len(inner_terms) == 1:
        [inner_term] = inner_terms
    else:
        inner_term = mul(*inner_terms)

    # If we have a `Prod`, then the outside terms need to be raised to the power of the number of elements
    # that were contracted in the input
    if isinstance(node.op, Prod) and inner_term:
        dtype = inner_term.dtype
        n_reduced_elements = prod(
            [inner_term.shape[i].astype(dtype) for i in reduced_axes]
        )
        outer_term = outer_term**n_reduced_elements

    if not inner_term:
        # Sum/Prod is useless, just return the outer_term
        new_out = outer_term
    else:
        reduced_inner_term = node.op(inner_term)
        new_out = outer_term * reduced_inner_term
        copy_stack_trace(node.outputs, [inner_term, reduced_inner_term, outer_term])

    copy_stack_trace(node.outputs, new_out)
    return [new_out]
@register_specialize
@node_rewriter([Sum])
def local_sum_of_neg_to_neg_of_sum(fgraph, node):
"""Rewrite sum(-X) -> -sum(X)."""
[node_inps] = node.inputs
if node_inps.owner and node_inps.owner.op == neg:
s = node.op(node_inps.owner.inputs[0])
ret = neg(s)
# There are never errors in the negative op, thus
......
......@@ -92,6 +92,7 @@ from pytensor.tensor.rewriting.math import (
local_grad_log_erfc_neg,
local_greedy_distributor,
local_mul_canonizer,
local_sum_prod_of_mul,
mul_canonizer,
parse_mul_tree,
perform_sigm_times_exp,
......@@ -2503,7 +2504,7 @@ class TestLocalSumProd:
def setup_method(self):
    # Compile every test in this class with both canonicalization and
    # specialization rewrites enabled on top of the default mode.
    self.mode = get_default_mode().including("canonicalize", "specialize")
def test_local_sum_prod_mul_by_scalar(self):
def test_local_sum_prod_of_scalar_mul(self):
# Test the rewrite `local_sum_prod_mul_by_scalar` for both Sum and
# Prod ops in six cases each :
# 1-the inputs to the mul contain a scalar and no non-scalar
......@@ -2653,6 +2654,157 @@ class TestLocalSumProd:
axis=(0,),
)
def test_sum_of_non_scalar_mul(self):
    """`local_sum_prod_of_mul` factors broadcasted terms out of a `Sum`.

    Each case pairs a `sum(mul(...))` graph with the graph expected after
    the rewrite; the rewritten graph is also checked numerically against
    the original.
    """
    mode = Mode("vm", optimizer="None")
    rewrite = out2in(local_sum_prod_of_mul)

    row1 = matrix(shape=(1, None), dtype="float64")
    row2 = matrix(shape=(1, None), dtype="float64")
    col1 = matrix(shape=(None, 1), dtype="float64")
    col2 = matrix(shape=(None, 1), dtype="float64")
    mat1 = matrix(shape=(None, None), dtype="float64")
    mat2 = matrix(shape=(None, None), dtype="float64")

    inputs = [row1, row2, col1, col2, mat1, mat2]
    test_vals = [
        np.random.random((1, 2)),
        np.random.random((1, 2)),
        np.random.random((2, 1)),
        np.random.random((2, 1)),
        np.random.random((2, 2)),
        np.random.random((2, 2)),
    ]

    for out, expected_out in [
        # No term is broadcastable along every reduced axis: no change
        (
            mul(row1, row2, mat1, mat2, col1, col2).sum(axis=None),
            mul(row1, row2, mat1, mat2, col1, col2).sum(axis=None),
        ),
        # Rows are broadcastable along axis 0 and can be factored out
        (
            mul(row1, row2, mat1, mat2, col1, col2).sum(axis=0),
            mul(row1.squeeze(), row2.squeeze())
            * mul(mat1, mat2, col1, col2).sum(axis=0),
        ),
        (
            mul(row1, mat1, mat2, col1, col2).sum(axis=0),
            row1.squeeze() * mul(mat1, mat2, col1, col2).sum(axis=0),
        ),
        # Columns are broadcastable along axis 1 and can be factored out
        (
            mul(row1, row2, mat1, mat2, col1, col2).sum(axis=1),
            mul(col1.squeeze(), col2.squeeze())
            * mul(row1, row2, mat1, mat2).sum(axis=1),
        ),
        (
            mul(row1, row2, mat1, mat2, col2).sum(axis=1),
            col2.squeeze() * mul(row1, row2, mat1, mat2).sum(axis=1),
        ),
        (
            mul(row1, row2).sum(axis=1),
            mul(row1, row2).sum(axis=1),
        ),
        # Every term factors out: the reduction disappears entirely
        (
            mul(row1, row2).sum(axis=0),
            mul(row1.squeeze(), row2.squeeze()),
        ),
        (
            mul(row1, col1).sum(axis=0),
            row1.squeeze() * col1.sum(axis=0),
        ),
    ]:
        out_fn = pytensor.function(inputs, out, mode=mode, on_unused_input="ignore")

        rewritten_out = rewrite_graph(out, custom_rewrite=rewrite)
        assert equal_computations([rewritten_out], [expected_out])

        rewritten_out_fn = pytensor.function(
            inputs, rewritten_out, mode=mode, on_unused_input="ignore"
        )
        np.testing.assert_allclose(
            out_fn(*test_vals),
            rewritten_out_fn(*test_vals),
        )
def test_prod_of_non_scalar_mul(self):
    """`local_sum_prod_of_mul` factors broadcasted terms out of a `Prod`.

    Unlike the `Sum` case, factored-out terms must be raised to the power
    of the number of elements contracted by the reduction.
    """
    mode = Mode("vm", optimizer="None")
    rewrite = out2in(local_sum_prod_of_mul)

    scl1 = matrix(shape=(1, 1), dtype="float64")
    row1 = matrix(shape=(1, None), dtype="float64")
    row2 = matrix(shape=(1, None), dtype="float64")
    col1 = matrix(shape=(None, 1), dtype="float64")
    col2 = matrix(shape=(None, 1), dtype="float64")
    mat1 = matrix(shape=(None, None), dtype="float64")
    mat2 = matrix(shape=(None, None), dtype="float64")

    inputs = [scl1, row1, row2, col1, col2, mat1, mat2]
    test_vals = [
        np.random.random((1, 1)),
        np.random.random((1, 2)),
        np.random.random((1, 2)),
        np.random.random((2, 1)),
        np.random.random((2, 1)),
        np.random.random((2, 2)),
        np.random.random((2, 2)),
    ]

    for out, expected_out in [
        # No term is broadcastable along every reduced axis: no change
        (
            mul(row1, row2, mat1, mat2, col1, col2).prod(axis=None),
            mul(row1, row2, mat1, mat2, col1, col2).prod(axis=None),
        ),
        # Rows factor out of axis 0, raised to the reduced length
        (
            mul(row1, row2, mat1, mat2, col1, col2).prod(axis=0),
            (
                mul(row1.squeeze(), row2.squeeze())
                ** prod([mul(mat1, mat2, col1, col2).shape[0]])
                * mul(mat1, mat2, col1, col2).prod(axis=0)
            ),
        ),
        (
            mul(row1, mat1, mat2, col1, col2).prod(axis=0),
            (
                row1.squeeze() ** prod([mul(mat1, mat2, col1, col2).shape[0]])
                * mul(mat1, mat2, col1, col2).prod(axis=0)
            ),
        ),
        # Columns factor out of axis 1, raised to the reduced length
        (
            mul(row1, row2, mat1, mat2, col1, col2).prod(axis=1),
            (
                mul(col1.squeeze(), col2.squeeze())
                ** prod([mul(row1, row2, mat1, mat2).shape[1]])
                * mul(row1, row2, mat1, mat2).prod(axis=1)
            ),
        ),
        # Every term factors out: no power needed, reduction disappears
        (
            mul(row1, row2).prod(axis=0),
            mul(row1.squeeze(), row2.squeeze()),
        ),
        (
            mul(row1, col1).prod(axis=0),
            (row1.squeeze() ** prod([col1.shape[0]]) * col1.prod(axis=0)),
        ),
        # Full reduction: power is the product of all reduced dim lengths
        (
            mul(scl1, mat1, row1).prod(axis=None),
            (
                scl1.squeeze()
                ** prod([mul(mat1, row1).shape[0], mul(mat1, row1).shape[1]])
                * mul(mat1, row1).prod(axis=None)
            ),
        ),
    ]:
        out_fn = pytensor.function(inputs, out, mode=mode, on_unused_input="ignore")

        rewritten_out = rewrite_graph(out, custom_rewrite=rewrite)
        assert equal_computations([rewritten_out], [expected_out])

        rewritten_out_fn = pytensor.function(
            inputs, rewritten_out, mode=mode, on_unused_input="ignore"
        )
        np.testing.assert_allclose(
            out_fn(*test_vals),
            rewritten_out_fn(*test_vals),
        )
def test_local_sum_prod_all_to_none(self):
a = tensor3()
input = np.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
......
Markdown format supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment