Commit c49e395c authored by Ricardo Vieira, committed by Ricardo Vieira

Merge rewrite for sum/prod of div with that of mul

Parent: 3b0a97b7
......@@ -1190,7 +1190,7 @@ def local_neg_to_mul(fgraph, node):
@register_specialize
@node_rewriter([Sum, Prod])
def local_sum_prod_of_mul(fgraph, node):
def local_sum_prod_of_mul_or_div(fgraph, node):
"""
sum(a * X) -> a * sum(X), when a is broadcasted along the sum dimensions
......@@ -1198,15 +1198,20 @@ def local_sum_prod_of_mul(fgraph, node):
prod(a * X) -> (a ** size(X)) * prod(X)
It also applies to reduction of X / a,
but not a / X, as that would still require inverting every value in X before the reduction
TODO: In the case where not all axis overlap with broadcast dimensions,
consider introducing an outer reduction after factoring out the compatible reduced dimensions
E.g. sum(arange(5) * X, axis=(0, 2)) -> sum(sum(X, axis=0) * arange(5), axis=1)
"""
# TODO: if the the thing inside the Sum is a division,
# we should get at the numerator....
[node_inps] = node.inputs
if not (node_inps.owner and node_inps.owner.op == mul):
if not node_inps.owner:
return None
inner_op = node_inps.owner.op
if not (inner_op == mul or inner_op == true_div):
return None
reduced_axes = node.op.axis
......@@ -1214,28 +1219,40 @@ def local_sum_prod_of_mul(fgraph, node):
reduced_axes = tuple(range(node_inps.type.ndim))
# Separate terms that can be moved out of the Sum/Prod and those that cannot
outer_terms = []
inner_terms = []
for term in node_inps.owner.inputs:
term_bcast = term.type.broadcastable
if all(term_bcast[i] for i in reduced_axes):
outer_terms.append(term.squeeze(reduced_axes))
else:
inner_terms.append(term)
if inner_op == mul:
# Mul accepts arbitrary inputs, so we need to separate into two groups
outer_terms = []
inner_terms = []
for term in node_inps.owner.inputs:
term_bcast = term.type.broadcastable
if all(term_bcast[i] for i in reduced_axes):
outer_terms.append(term.squeeze(reduced_axes))
else:
inner_terms.append(term)
if not outer_terms:
return None
elif len(outer_terms) == 1:
[outer_term] = outer_terms
else:
outer_term = mul(*outer_terms)
if not outer_terms:
return None
elif len(outer_terms) == 1:
[outer_term] = outer_terms
else:
outer_term = mul(*outer_terms)
if not inner_terms:
inner_term = None
elif len(inner_terms) == 1:
[inner_term] = inner_terms
else:
inner_term = mul(*inner_terms)
if not inner_terms:
inner_term = None
elif len(inner_terms) == 1:
[inner_term] = inner_terms
else:
inner_term = mul(*inner_terms)
else: # true_div
# We only care about removing the denominator out of the reduction
numerator, denominator = node_inps.owner.inputs
denominator_bcast = denominator.type.broadcastable
if all(denominator_bcast[i] for i in reduced_axes):
outer_term = denominator.squeeze(reduced_axes)
inner_term = numerator
else:
return None
# If we have a `Prod`, then the outside terms need to be raised to the power of the number of elements
# that were contracted in the input
......@@ -1246,12 +1263,16 @@ def local_sum_prod_of_mul(fgraph, node):
)
outer_term = outer_term**n_reduced_elements
# Sum/Prod is useless, just return the outer_term
if not inner_term:
# Sum/Prod is useless, just return the outer_term
# (This can only happen for mul, not division)
new_out = outer_term
else:
reduced_inner_term = node.op(inner_term)
new_out = outer_term * reduced_inner_term
if inner_op == mul:
new_out = outer_term * reduced_inner_term
else:
new_out = reduced_inner_term / outer_term
copy_stack_trace(node.outputs, [inner_term, reduced_inner_term, outer_term])
copy_stack_trace(node.outputs, new_out)
......@@ -1510,99 +1531,6 @@ def local_useless_elemwise_comparison(fgraph, node):
return
@register_canonicalize
@register_specialize
@node_rewriter([Sum, Prod])
def local_sum_prod_div_dimshuffle(fgraph, node):
    """
    sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b,
    if dimension l of the DimShuffle is 'x'

    or

    prod(a / dimshuffle{...}(b), axis=l) ->
    prod(a, axis={...}) / b ** a.shape[l],
    if dimension l of the DimShuffle is 'x'
    """
    # It does not make much sense now to extend it to the case where the
    # dimshuffle is in the numerator, since elemwise inversion of the
    # denominator would still be needed before the summation or production.
    if isinstance(node.op, (Sum, Prod)):
        axis = node.op.axis
        if axis is None:
            # axis=None means "reduce over every axis"; enumerate them so the
            # compatibility split below can inspect each one.
            axis = list(range(node.inputs[0].ndim))
        node_input = node.inputs[0]
        # Only rewrite reductions of a true division whose denominator is a
        # DimShuffle, since the 'x' entries of its new_order tell us exactly
        # which reduced axes the denominator is broadcast along.
        if node_input.owner and node_input.owner.op == true_div:
            numerator, denominator = node_input.owner.inputs
            if denominator.owner and isinstance(denominator.owner.op, DimShuffle):
                dimshuffle_input = denominator.owner.inputs[0]
                dimshuffle_order = denominator.owner.op.new_order
                # Split the reduced axes into those where the denominator is a
                # broadcast 'x' dimension (safe to factor out) and the rest.
                compatible_dims = []
                incompatible_dims = []
                for ax in axis:
                    if ax < len(dimshuffle_order) and dimshuffle_order[ax] == "x":
                        compatible_dims.append(ax)
                    else:
                        incompatible_dims.append(ax)
                # Re-express the incompatible axes in the coordinate system of
                # the intermediate result, where the compatible axes have
                # already been reduced away (each earlier compatible axis
                # shifts the index left by one).
                reordered_incompatible_dims = []
                for ic_ax in incompatible_dims:
                    reordered_incompatible_dims.append(
                        ic_ax - sum(1 for c_ax in compatible_dims if c_ax < ic_ax)
                    )
                if len(compatible_dims) > 0:
                    # Drop the 'x' entries that correspond to reduced axes
                    # from the dimshuffle order.
                    optimized_dimshuffle_order = [
                        ax
                        for i, ax in enumerate(dimshuffle_order)
                        if (i not in axis) or (ax != "x")
                    ]
                    # Removing leading 'x' (since it will be done automatically)
                    while (
                        len(optimized_dimshuffle_order) > 0
                        and optimized_dimshuffle_order[0] == "x"
                    ):
                        del optimized_dimshuffle_order[0]
                    # if optimized_dimshuffle_order is sorted with
                    # not 'x', then dimshuffle is useless.
                    if all(i == e for i, e in enumerate(optimized_dimshuffle_order)):
                        optimized_dimshuffle = dimshuffle_input
                    else:
                        optimized_dimshuffle = DimShuffle(
                            dimshuffle_input.type.broadcastable,
                            optimized_dimshuffle_order,
                        )(dimshuffle_input)
                    if isinstance(node.op, Sum):
                        # sum(a / b) -> sum(a) / b over the broadcast axes,
                        # then finish the remaining axes with an outer sum.
                        op_on_compatible_dims = at_sum(numerator, axis=compatible_dims)
                        rval = true_div(op_on_compatible_dims, optimized_dimshuffle)
                        if len(reordered_incompatible_dims) > 0:
                            rval = at_sum(rval, axis=reordered_incompatible_dims)
                    elif isinstance(node.op, Prod):
                        # prod(a / b) -> prod(a) / b ** n, where n is the
                        # number of numerator elements contracted over the
                        # broadcast axes (cast to the numerator's dtype so the
                        # power does not upcast the result).
                        op_on_compatible_dims = prod(numerator, axis=compatible_dims)
                        dtype = numerator.dtype
                        rval = true_div(
                            op_on_compatible_dims,
                            (
                                optimized_dimshuffle
                                ** prod(
                                    [
                                        numerator.shape[ax].astype(dtype)
                                        for ax in compatible_dims
                                    ]
                                )
                            ),
                        )
                        if len(reordered_incompatible_dims) > 0:
                            rval = prod(rval, axis=reordered_incompatible_dims)
                    return [rval]
@register_canonicalize
@node_rewriter([Sum, Prod])
def local_sum_prod_all_to_none(fgraph, node):
......
......@@ -899,7 +899,7 @@ class TestFusion:
),
(fx, fy),
(fxv, fyv),
3,
2,
(
np.sum(-((fxv - fyv) ** 2) / 2),
-(fxv - fyv),
......
......@@ -92,7 +92,7 @@ from pytensor.tensor.rewriting.math import (
local_grad_log_erfc_neg,
local_greedy_distributor,
local_mul_canonizer,
local_sum_prod_of_mul,
local_sum_prod_of_mul_or_div,
mul_canonizer,
parse_mul_tree,
perform_sigm_times_exp,
......@@ -2656,7 +2656,7 @@ class TestLocalSumProd:
def test_sum_of_non_scalar_mul(self):
mode = Mode("vm", optimizer="None")
rewrite = out2in(local_sum_prod_of_mul)
rewrite = out2in(local_sum_prod_of_mul_or_div)
row1 = matrix(shape=(1, None), dtype="float64")
row2 = matrix(shape=(1, None), dtype="float64")
......@@ -2726,7 +2726,7 @@ class TestLocalSumProd:
def test_prod_of_non_scalar_mul(self):
mode = Mode("vm", optimizer="None")
rewrite = out2in(local_sum_prod_of_mul)
rewrite = out2in(local_sum_prod_of_mul_or_div)
scl1 = matrix(shape=(1, 1), dtype="float64")
row1 = matrix(shape=(1, None), dtype="float64")
......@@ -2756,14 +2756,15 @@ class TestLocalSumProd:
mul(row1, row2, mat1, mat2, col1, col2).prod(axis=0),
(
mul(row1.squeeze(), row2.squeeze())
** prod([mul(mat1, mat2, col1, col2).shape[0]])
** prod([mul(mat1, mat2, col1, col2).shape[0].astype("float64")])
* mul(mat1, mat2, col1, col2).prod(axis=0)
),
),
(
mul(row1, mat1, mat2, col1, col2).prod(axis=0),
(
row1.squeeze() ** prod([mul(mat1, mat2, col1, col2).shape[0]])
row1.squeeze()
** prod([mul(mat1, mat2, col1, col2).shape[0].astype("float64")])
* mul(mat1, mat2, col1, col2).prod(axis=0)
),
),
......@@ -2771,7 +2772,7 @@ class TestLocalSumProd:
mul(row1, row2, mat1, mat2, col1, col2).prod(axis=1),
(
mul(col1.squeeze(), col2.squeeze())
** prod([mul(row1, row2, mat1, mat2).shape[1]])
** prod([mul(row1, row2, mat1, mat2).shape[1].astype("float64")])
* mul(row1, row2, mat1, mat2).prod(axis=1)
),
),
......@@ -2781,13 +2782,21 @@ class TestLocalSumProd:
),
(
mul(row1, col1).prod(axis=0),
(row1.squeeze() ** prod([col1.shape[0]]) * col1.prod(axis=0)),
(
row1.squeeze() ** prod([col1.shape[0].astype("float64")])
* col1.prod(axis=0)
),
),
(
mul(scl1, mat1, row1).prod(axis=None),
(
scl1.squeeze()
** prod([mul(mat1, row1).shape[0], mul(mat1, row1).shape[1]])
** prod(
[
mul(mat1, row1).shape[0].astype("float64"),
mul(mat1, row1).shape[1].astype("float64"),
]
)
* mul(mat1, row1).prod(axis=None)
),
),
......@@ -3050,146 +3059,7 @@ class TestLocalSumProd:
f = function([mat], at_sum(-mat), mode=m0)
assert check_stack_trace(f, ops_to_check=[Sum])
class TestLocalReduce:
    # Tests for the rewrites that simplify or remove CAReduce nodes
    # (reductions over broadcastable axes, and reductions of joined tensors).

    def setup_method(self):
        self.mode = get_default_mode().including(
            "canonicalize", "specialize", "uncanonicalize", "local_max_and_argmax"
        )

    def test_local_reduce_broadcast_all_0(self):
        # Reducing a fully broadcastable (all-size-1) input over all axes
        # should leave no CAReduce node in the compiled graph.
        for fct in [
            at_sum,
            at_all,
            at_any,
            prod,
            at_max,
            at_min,
        ]:
            x = TensorType("int64", shape=(1, 1, 1))()
            f = function([x], [fct(x)], mode=self.mode)
            assert not any(
                isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort()
            )

    def test_local_reduce_broadcast_all_1(self):
        # Same as above, but with the reduced axes given explicitly.
        for fct in [
            at_sum,
            at_all,
            at_any,
            prod,
            at_max,
            at_min,
        ]:
            x = TensorType("int64", shape=(1, 1))()
            f = function([x], [fct(x, axis=[0, 1])], mode=self.mode)
            assert not any(
                isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort()
            )

    def test_local_reduce_broadcast_some_0(self):
        # Only axis 0 is broadcastable, so a single CAReduce over the
        # remaining (squeezed) axis must survive in the graph.
        for fct in [
            at_sum,
            at_all,
            at_any,
            prod,
            at_max,
            at_min,
        ]:
            x = TensorType("int64", shape=(1, None, 1))()
            f = function([x], [fct(x, axis=[0, 1])], mode=self.mode)

            order = f.maker.fgraph.toposort()
            assert 1 == sum(isinstance(node.op, CAReduce) for node in order)

            node = [node for node in order if isinstance(node.op, CAReduce)][0]

            op = node.op
            assert isinstance(op, CAReduce)
            # The leading broadcastable dimension has been dropped by the
            # `local_reduce_broadcastable` rewrite. Now, summation is over
            # the original `x`'s dimension 1.
            assert node.inputs[0].ndim == 2, node
            assert op.axis == (0,), op.axis

    def test_local_reduce_broadcast_some_1(self):
        # All reduced axes are broadcastable, so no reduction should remain.
        for fct in [
            at_sum,
            at_all,
            at_any,
            prod,
            at_max,
            at_min,
        ]:
            x = TensorType("int64", shape=(1, 1, 1))()
            f = function([x], [fct(x, axis=[0, 2])], mode=self.mode)
            assert not any(
                isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort()
            )

    def test_local_reduce_join(self):
        vx = matrix()
        vy = matrix()
        vz = matrix()
        x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
        y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
        z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)

        # Test different reduction scalar operation
        for out, res in [
            (at_max((vx, vy), 0), np.max((x, y), 0)),
            (at_min((vx, vy), 0), np.min((x, y), 0)),
            (at_sum((vx, vy, vz), 0), np.sum((x, y, z), 0)),
            (prod((vx, vy, vz), 0), np.prod((x, y, z), 0)),
            (prod((vx, vy.T, vz), 0), np.prod((x, y.T, z), 0)),
        ]:
            f = function([vx, vy, vz], out, on_unused_input="ignore", mode=self.mode)
            assert (f(x, y, z) == res).all(), out
            topo = f.maker.fgraph.toposort()
            # The join+reduce should have collapsed into (at most) an
            # Elemwise-terminated two-node graph.
            assert len(topo) <= 2, out
            assert isinstance(topo[-1].op, Elemwise), out

        # Test different axis for the join and the reduction
        # We must force the dtype, of otherwise, this tests will fail
        # on 32 bit systems
        A = shared(np.array([1, 2, 3, 4, 5], dtype="int64"))

        f = function([], at_sum(at.stack([A, A]), axis=0), mode=self.mode)
        utt.assert_allclose(f(), [2, 4, 6, 8, 10])
        topo = f.maker.fgraph.toposort()
        assert isinstance(topo[-1].op, Elemwise)

        # Test a case that was bugged in a old PyTensor bug
        f = function([], at_sum(at.stack([A, A]), axis=1), mode=self.mode)

        utt.assert_allclose(f(), [15, 15])
        topo = f.maker.fgraph.toposort()
        # Reduction axis differs from the join axis, so the rewrite must not
        # fire here (the graph should not end in an Elemwise).
        assert not isinstance(topo[-1].op, Elemwise)

        # This case could be rewritten
        A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))

        f = function([], at_sum(at.concatenate((A, A), axis=1), axis=1), mode=self.mode)
        utt.assert_allclose(f(), [2, 4, 6, 8, 10])
        topo = f.maker.fgraph.toposort()
        assert not isinstance(topo[-1].op, Elemwise)

        A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
        f = function([], at_sum(at.concatenate((A, A), axis=1), axis=0), mode=self.mode)
        utt.assert_allclose(f(), [15, 15])
        topo = f.maker.fgraph.toposort()
        assert not isinstance(topo[-1].op, Elemwise)

        # Test that the rewrite does not crash in one case where it
        # is not applied.  Reported at
        # https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion
        out = at_sum([vx, vy, vz], axis=None)
        f = function([vx, vy, vz], out)
class TestLocalSumProdDimshuffle:
def setup_method(self):
self.mode = get_default_mode().including("canonicalize")
def test_local_sum_div_dimshuffle(self):
def test_local_sum_of_div(self):
a = matrix("a")
b = vector("b")
c = tensor3("c")
......@@ -3242,7 +3112,7 @@ class TestLocalSumProdDimshuffle:
assert isinstance(g[-1].op.scalar_op, aes.basic.TrueDiv)
f(a_val, b_val, c_val, d_val)
def test_local_prod_div_dimshuffle(self):
def test_local_prod_of_div(self):
a = matrix("a")
b = vector("b")
c = tensor3("c")
......@@ -3295,9 +3165,9 @@ class TestLocalSumProdDimshuffle:
# `FusionOptimizer` is included to make sure that `expected_outer_operator`
# remains the same for all rewrite modes.
mode_with_rewrite = default_mode.including(
"local_sum_prod_div_dimshuffle", "FusionOptimizer"
"local_sum_prod_of_mul_or_div", "FusionOptimizer"
)
mode_without_rewrite = default_mode.excluding("local_sum_prod_div_dimshuffle")
mode_without_rewrite = default_mode.excluding("local_sum_prod_of_mul_or_div")
# Numerical tests: tests whether the numerical values with and without
# rewrites are equal or not.
......@@ -3345,9 +3215,139 @@ class TestLocalSumProdDimshuffle:
g.maker.fgraph.toposort()[-1].op.scalar_op, expected_outer_operator[i]
)
# TODO:
# test_local_sum_prod_dimshuffle (a * b * c)
# test_local_sum_divprod_dimshuffle ((a * b) / (c * d))
class TestLocalReduce:
    # Tests for the rewrites that simplify or remove CAReduce nodes
    # (reductions over broadcastable axes, and reductions of joined tensors).

    def setup_method(self):
        self.mode = get_default_mode().including(
            "canonicalize", "specialize", "uncanonicalize", "local_max_and_argmax"
        )

    def test_local_reduce_broadcast_all_0(self):
        # Reducing a fully broadcastable (all-size-1) input over all axes
        # should leave no CAReduce node in the compiled graph.
        for fct in [
            at_sum,
            at_all,
            at_any,
            prod,
            at_max,
            at_min,
        ]:
            x = TensorType("int64", shape=(1, 1, 1))()
            f = function([x], [fct(x)], mode=self.mode)
            assert not any(
                isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort()
            )

    def test_local_reduce_broadcast_all_1(self):
        # Same as above, but with the reduced axes given explicitly.
        for fct in [
            at_sum,
            at_all,
            at_any,
            prod,
            at_max,
            at_min,
        ]:
            x = TensorType("int64", shape=(1, 1))()
            f = function([x], [fct(x, axis=[0, 1])], mode=self.mode)
            assert not any(
                isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort()
            )

    def test_local_reduce_broadcast_some_0(self):
        # Only axis 0 is broadcastable, so a single CAReduce over the
        # remaining (squeezed) axis must survive in the graph.
        for fct in [
            at_sum,
            at_all,
            at_any,
            prod,
            at_max,
            at_min,
        ]:
            x = TensorType("int64", shape=(1, None, 1))()
            f = function([x], [fct(x, axis=[0, 1])], mode=self.mode)

            order = f.maker.fgraph.toposort()
            assert 1 == sum(isinstance(node.op, CAReduce) for node in order)

            node = [node for node in order if isinstance(node.op, CAReduce)][0]

            op = node.op
            assert isinstance(op, CAReduce)
            # The leading broadcastable dimension has been dropped by the
            # `local_reduce_broadcastable` rewrite. Now, summation is over
            # the original `x`'s dimension 1.
            assert node.inputs[0].ndim == 2, node
            assert op.axis == (0,), op.axis

    def test_local_reduce_broadcast_some_1(self):
        # All reduced axes are broadcastable, so no reduction should remain.
        for fct in [
            at_sum,
            at_all,
            at_any,
            prod,
            at_max,
            at_min,
        ]:
            x = TensorType("int64", shape=(1, 1, 1))()
            f = function([x], [fct(x, axis=[0, 2])], mode=self.mode)
            assert not any(
                isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort()
            )

    def test_local_reduce_join(self):
        vx = matrix()
        vy = matrix()
        vz = matrix()
        x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
        y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
        z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)

        # Test different reduction scalar operation
        for out, res in [
            (at_max((vx, vy), 0), np.max((x, y), 0)),
            (at_min((vx, vy), 0), np.min((x, y), 0)),
            (at_sum((vx, vy, vz), 0), np.sum((x, y, z), 0)),
            (prod((vx, vy, vz), 0), np.prod((x, y, z), 0)),
            (prod((vx, vy.T, vz), 0), np.prod((x, y.T, z), 0)),
        ]:
            f = function([vx, vy, vz], out, on_unused_input="ignore", mode=self.mode)
            assert (f(x, y, z) == res).all(), out
            topo = f.maker.fgraph.toposort()
            # The join+reduce should have collapsed into (at most) an
            # Elemwise-terminated two-node graph.
            assert len(topo) <= 2, out
            assert isinstance(topo[-1].op, Elemwise), out

        # Test different axis for the join and the reduction
        # We must force the dtype, of otherwise, this tests will fail
        # on 32 bit systems
        A = shared(np.array([1, 2, 3, 4, 5], dtype="int64"))

        f = function([], at_sum(at.stack([A, A]), axis=0), mode=self.mode)
        utt.assert_allclose(f(), [2, 4, 6, 8, 10])
        topo = f.maker.fgraph.toposort()
        assert isinstance(topo[-1].op, Elemwise)

        # Test a case that was bugged in a old PyTensor bug
        f = function([], at_sum(at.stack([A, A]), axis=1), mode=self.mode)

        utt.assert_allclose(f(), [15, 15])
        topo = f.maker.fgraph.toposort()
        # Reduction axis differs from the join axis, so the rewrite must not
        # fire here (the graph should not end in an Elemwise).
        assert not isinstance(topo[-1].op, Elemwise)

        # This case could be rewritten
        A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))

        f = function([], at_sum(at.concatenate((A, A), axis=1), axis=1), mode=self.mode)
        utt.assert_allclose(f(), [2, 4, 6, 8, 10])
        topo = f.maker.fgraph.toposort()
        assert not isinstance(topo[-1].op, Elemwise)

        A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
        f = function([], at_sum(at.concatenate((A, A), axis=1), axis=0), mode=self.mode)
        utt.assert_allclose(f(), [15, 15])
        topo = f.maker.fgraph.toposort()
        assert not isinstance(topo[-1].op, Elemwise)

        # Test that the rewrite does not crash in one case where it
        # is not applied.  Reported at
        # https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion
        out = at_sum([vx, vy, vz], axis=None)
        f = function([vx, vy, vz], out)
def test_local_useless_adds():
......@@ -3534,7 +3534,6 @@ def test_local_mul_exp_to_exp_add():
# e^x * e^y * e^z * e^w = e^(x+y+z+w)
op = expx * expy * expz * expw
f = function([x, y, z, w], op, mode)
pytensor.dprint(f)
utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 + 6))
graph = f.maker.fgraph.toposort()
assert all(isinstance(n.op, Elemwise) for n in graph)
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment