提交 c49e395c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Merge rewrite for sum/prod of div with that of mul

上级 3b0a97b7
......@@ -1190,7 +1190,7 @@ def local_neg_to_mul(fgraph, node):
@register_specialize
@node_rewriter([Sum, Prod])
def local_sum_prod_of_mul(fgraph, node):
def local_sum_prod_of_mul_or_div(fgraph, node):
"""
sum(a * X) -> a * sum(X), when a is broadcasted along the sum dimensions
......@@ -1198,15 +1198,20 @@ def local_sum_prod_of_mul(fgraph, node):
prod(a * X) -> (a ** size(X)) * prod(X)
It also applies to reduction of X / a,
but not a / X, as that would still require inverting every value in X before the reduction
TODO: In the case where not all axis overlap with broadcast dimensions,
consider introducing an outer reduction after factoring out the compatible reduced dimensions
E.g. sum(arange(5) * X, axis=(0, 2)) -> sum(sum(X, axis=0) * arange(5), axis=1)
"""
# TODO: if the the thing inside the Sum is a division,
# we should get at the numerator....
[node_inps] = node.inputs
if not (node_inps.owner and node_inps.owner.op == mul):
if not node_inps.owner:
return None
inner_op = node_inps.owner.op
if not (inner_op == mul or inner_op == true_div):
return None
reduced_axes = node.op.axis
......@@ -1214,6 +1219,8 @@ def local_sum_prod_of_mul(fgraph, node):
reduced_axes = tuple(range(node_inps.type.ndim))
# Separate terms that can be moved out of the Sum/Prod and those that cannot
if inner_op == mul:
# Mul accepts arbitrary inputs, so we need to separate into two groups
outer_terms = []
inner_terms = []
for term in node_inps.owner.inputs:
......@@ -1237,6 +1244,16 @@ def local_sum_prod_of_mul(fgraph, node):
else:
inner_term = mul(*inner_terms)
else: # true_div
# We only care about removing the denominator out of the reduction
numerator, denominator = node_inps.owner.inputs
denominator_bcast = denominator.type.broadcastable
if all(denominator_bcast[i] for i in reduced_axes):
outer_term = denominator.squeeze(reduced_axes)
inner_term = numerator
else:
return None
# If we have a `Prod`, then the outside terms need to be raised to the power of the number of elements
# that were contracted in the input
if isinstance(node.op, Prod) and inner_term:
......@@ -1246,12 +1263,16 @@ def local_sum_prod_of_mul(fgraph, node):
)
outer_term = outer_term**n_reduced_elements
# Sum/Prod is useless, just return the outer_term
if not inner_term:
# Sum/Prod is useless, just return the outer_term
# (This can only happen for mul, not division)
new_out = outer_term
else:
reduced_inner_term = node.op(inner_term)
if inner_op == mul:
new_out = outer_term * reduced_inner_term
else:
new_out = reduced_inner_term / outer_term
copy_stack_trace(node.outputs, [inner_term, reduced_inner_term, outer_term])
copy_stack_trace(node.outputs, new_out)
......@@ -1510,99 +1531,6 @@ def local_useless_elemwise_comparison(fgraph, node):
return
@register_canonicalize
@register_specialize
@node_rewriter([Sum, Prod])
def local_sum_prod_div_dimshuffle(fgraph, node):
"""
sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b,
if dimension l of the DimShuffle is 'x'
or
prod(a / dimshuffle{...}(b), axis=l) ->
prod(a, axis={...}) / b ** a.shape[l],
if dimension l of the DimShuffle is 'x'
"""
# It does not make much sense now to extend it to the case where the
# dimshuffle is in the numerator, since elemwise inversion of the
# denominator would still be needed before the summation or production.
if isinstance(node.op, (Sum, Prod)):
axis = node.op.axis
if axis is None:
axis = list(range(node.inputs[0].ndim))
node_input = node.inputs[0]
if node_input.owner and node_input.owner.op == true_div:
numerator, denominator = node_input.owner.inputs
if denominator.owner and isinstance(denominator.owner.op, DimShuffle):
dimshuffle_input = denominator.owner.inputs[0]
dimshuffle_order = denominator.owner.op.new_order
compatible_dims = []
incompatible_dims = []
for ax in axis:
if ax < len(dimshuffle_order) and dimshuffle_order[ax] == "x":
compatible_dims.append(ax)
else:
incompatible_dims.append(ax)
reordered_incompatible_dims = []
for ic_ax in incompatible_dims:
reordered_incompatible_dims.append(
ic_ax - sum(1 for c_ax in compatible_dims if c_ax < ic_ax)
)
if len(compatible_dims) > 0:
optimized_dimshuffle_order = [
ax
for i, ax in enumerate(dimshuffle_order)
if (i not in axis) or (ax != "x")
]
# Removing leading 'x' (since it will be done automatically)
while (
len(optimized_dimshuffle_order) > 0
and optimized_dimshuffle_order[0] == "x"
):
del optimized_dimshuffle_order[0]
# if optimized_dimshuffle_order is sorted with
# not 'x', then dimshuffle is useless.
if all(i == e for i, e in enumerate(optimized_dimshuffle_order)):
optimized_dimshuffle = dimshuffle_input
else:
optimized_dimshuffle = DimShuffle(
dimshuffle_input.type.broadcastable,
optimized_dimshuffle_order,
)(dimshuffle_input)
if isinstance(node.op, Sum):
op_on_compatible_dims = at_sum(numerator, axis=compatible_dims)
rval = true_div(op_on_compatible_dims, optimized_dimshuffle)
if len(reordered_incompatible_dims) > 0:
rval = at_sum(rval, axis=reordered_incompatible_dims)
elif isinstance(node.op, Prod):
op_on_compatible_dims = prod(numerator, axis=compatible_dims)
dtype = numerator.dtype
rval = true_div(
op_on_compatible_dims,
(
optimized_dimshuffle
** prod(
[
numerator.shape[ax].astype(dtype)
for ax in compatible_dims
]
)
),
)
if len(reordered_incompatible_dims) > 0:
rval = prod(rval, axis=reordered_incompatible_dims)
return [rval]
@register_canonicalize
@node_rewriter([Sum, Prod])
def local_sum_prod_all_to_none(fgraph, node):
......
......@@ -899,7 +899,7 @@ class TestFusion:
),
(fx, fy),
(fxv, fyv),
3,
2,
(
np.sum(-((fxv - fyv) ** 2) / 2),
-(fxv - fyv),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论