提交 c49e395c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Merge rewrite for sum/prod of div with that of mul

上级 3b0a97b7
...@@ -1190,7 +1190,7 @@ def local_neg_to_mul(fgraph, node): ...@@ -1190,7 +1190,7 @@ def local_neg_to_mul(fgraph, node):
@register_specialize @register_specialize
@node_rewriter([Sum, Prod]) @node_rewriter([Sum, Prod])
def local_sum_prod_of_mul(fgraph, node): def local_sum_prod_of_mul_or_div(fgraph, node):
""" """
sum(a * X) -> a * sum(X), when a is broadcasted along the sum dimensions sum(a * X) -> a * sum(X), when a is broadcasted along the sum dimensions
...@@ -1198,15 +1198,20 @@ def local_sum_prod_of_mul(fgraph, node): ...@@ -1198,15 +1198,20 @@ def local_sum_prod_of_mul(fgraph, node):
prod(a * X) -> (a ** size(X)) * prod(X) prod(a * X) -> (a ** size(X)) * prod(X)
It also applies to reduction of X / a,
but not a / X, as that would still require inverting every value in X before the reduction
TODO: In the case where not all axis overlap with broadcast dimensions, TODO: In the case where not all axis overlap with broadcast dimensions,
consider introducing an outer reduction after factoring out the compatible reduced dimensions consider introducing an outer reduction after factoring out the compatible reduced dimensions
E.g. sum(arange(5) * X, axis=(0, 2)) -> sum(sum(X, axis=0) * arange(5), axis=1) E.g. sum(arange(5) * X, axis=(0, 2)) -> sum(sum(X, axis=0) * arange(5), axis=1)
""" """
# TODO: if the the thing inside the Sum is a division,
# we should get at the numerator....
[node_inps] = node.inputs [node_inps] = node.inputs
if not (node_inps.owner and node_inps.owner.op == mul): if not node_inps.owner:
return None
inner_op = node_inps.owner.op
if not (inner_op == mul or inner_op == true_div):
return None return None
reduced_axes = node.op.axis reduced_axes = node.op.axis
...@@ -1214,6 +1219,8 @@ def local_sum_prod_of_mul(fgraph, node): ...@@ -1214,6 +1219,8 @@ def local_sum_prod_of_mul(fgraph, node):
reduced_axes = tuple(range(node_inps.type.ndim)) reduced_axes = tuple(range(node_inps.type.ndim))
# Separate terms that can be moved out of the Sum/Prod and those that cannot # Separate terms that can be moved out of the Sum/Prod and those that cannot
if inner_op == mul:
# Mul accepts arbitrary inputs, so we need to separate into two groups
outer_terms = [] outer_terms = []
inner_terms = [] inner_terms = []
for term in node_inps.owner.inputs: for term in node_inps.owner.inputs:
...@@ -1237,6 +1244,16 @@ def local_sum_prod_of_mul(fgraph, node): ...@@ -1237,6 +1244,16 @@ def local_sum_prod_of_mul(fgraph, node):
else: else:
inner_term = mul(*inner_terms) inner_term = mul(*inner_terms)
else: # true_div
# We only care about removing the denominator out of the reduction
numerator, denominator = node_inps.owner.inputs
denominator_bcast = denominator.type.broadcastable
if all(denominator_bcast[i] for i in reduced_axes):
outer_term = denominator.squeeze(reduced_axes)
inner_term = numerator
else:
return None
# If we have a `Prod`, then the outside terms need to be raised to the power of the number of elements # If we have a `Prod`, then the outside terms need to be raised to the power of the number of elements
# that were contracted in the input # that were contracted in the input
if isinstance(node.op, Prod) and inner_term: if isinstance(node.op, Prod) and inner_term:
...@@ -1246,12 +1263,16 @@ def local_sum_prod_of_mul(fgraph, node): ...@@ -1246,12 +1263,16 @@ def local_sum_prod_of_mul(fgraph, node):
) )
outer_term = outer_term**n_reduced_elements outer_term = outer_term**n_reduced_elements
# Sum/Prod is useless, just return the outer_term
if not inner_term: if not inner_term:
# Sum/Prod is useless, just return the outer_term
# (This can only happen for mul, not division)
new_out = outer_term new_out = outer_term
else: else:
reduced_inner_term = node.op(inner_term) reduced_inner_term = node.op(inner_term)
if inner_op == mul:
new_out = outer_term * reduced_inner_term new_out = outer_term * reduced_inner_term
else:
new_out = reduced_inner_term / outer_term
copy_stack_trace(node.outputs, [inner_term, reduced_inner_term, outer_term]) copy_stack_trace(node.outputs, [inner_term, reduced_inner_term, outer_term])
copy_stack_trace(node.outputs, new_out) copy_stack_trace(node.outputs, new_out)
...@@ -1510,99 +1531,6 @@ def local_useless_elemwise_comparison(fgraph, node): ...@@ -1510,99 +1531,6 @@ def local_useless_elemwise_comparison(fgraph, node):
return return
@register_canonicalize
@register_specialize
@node_rewriter([Sum, Prod])
def local_sum_prod_div_dimshuffle(fgraph, node):
"""
sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b,
if dimension l of the DimShuffle is 'x'
or
prod(a / dimshuffle{...}(b), axis=l) ->
prod(a, axis={...}) / b ** a.shape[l],
if dimension l of the DimShuffle is 'x'
"""
# It does not make much sense now to extend it to the case where the
# dimshuffle is in the numerator, since elemwise inversion of the
# denominator would still be needed before the summation or production.
if isinstance(node.op, (Sum, Prod)):
axis = node.op.axis
if axis is None:
axis = list(range(node.inputs[0].ndim))
node_input = node.inputs[0]
if node_input.owner and node_input.owner.op == true_div:
numerator, denominator = node_input.owner.inputs
if denominator.owner and isinstance(denominator.owner.op, DimShuffle):
dimshuffle_input = denominator.owner.inputs[0]
dimshuffle_order = denominator.owner.op.new_order
compatible_dims = []
incompatible_dims = []
for ax in axis:
if ax < len(dimshuffle_order) and dimshuffle_order[ax] == "x":
compatible_dims.append(ax)
else:
incompatible_dims.append(ax)
reordered_incompatible_dims = []
for ic_ax in incompatible_dims:
reordered_incompatible_dims.append(
ic_ax - sum(1 for c_ax in compatible_dims if c_ax < ic_ax)
)
if len(compatible_dims) > 0:
optimized_dimshuffle_order = [
ax
for i, ax in enumerate(dimshuffle_order)
if (i not in axis) or (ax != "x")
]
# Removing leading 'x' (since it will be done automatically)
while (
len(optimized_dimshuffle_order) > 0
and optimized_dimshuffle_order[0] == "x"
):
del optimized_dimshuffle_order[0]
# if optimized_dimshuffle_order is sorted with
# not 'x', then dimshuffle is useless.
if all(i == e for i, e in enumerate(optimized_dimshuffle_order)):
optimized_dimshuffle = dimshuffle_input
else:
optimized_dimshuffle = DimShuffle(
dimshuffle_input.type.broadcastable,
optimized_dimshuffle_order,
)(dimshuffle_input)
if isinstance(node.op, Sum):
op_on_compatible_dims = at_sum(numerator, axis=compatible_dims)
rval = true_div(op_on_compatible_dims, optimized_dimshuffle)
if len(reordered_incompatible_dims) > 0:
rval = at_sum(rval, axis=reordered_incompatible_dims)
elif isinstance(node.op, Prod):
op_on_compatible_dims = prod(numerator, axis=compatible_dims)
dtype = numerator.dtype
rval = true_div(
op_on_compatible_dims,
(
optimized_dimshuffle
** prod(
[
numerator.shape[ax].astype(dtype)
for ax in compatible_dims
]
)
),
)
if len(reordered_incompatible_dims) > 0:
rval = prod(rval, axis=reordered_incompatible_dims)
return [rval]
@register_canonicalize @register_canonicalize
@node_rewriter([Sum, Prod]) @node_rewriter([Sum, Prod])
def local_sum_prod_all_to_none(fgraph, node): def local_sum_prod_all_to_none(fgraph, node):
......
...@@ -899,7 +899,7 @@ class TestFusion: ...@@ -899,7 +899,7 @@ class TestFusion:
), ),
(fx, fy), (fx, fy),
(fxv, fyv), (fxv, fyv),
3, 2,
( (
np.sum(-((fxv - fyv) ** 2) / 2), np.sum(-((fxv - fyv) ** 2) / 2),
-(fxv - fyv), -(fxv - fyv),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论