提交 b2c62589 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Generalize and rename `local_reduce_chain`

上级 5b9c07ec
...@@ -100,7 +100,11 @@ from pytensor.tensor.type import ( ...@@ -100,7 +100,11 @@ from pytensor.tensor.type import (
values_eq_approx_remove_inf_nan, values_eq_approx_remove_inf_nan,
values_eq_approx_remove_nan, values_eq_approx_remove_nan,
) )
from pytensor.tensor.variable import TensorConstant, get_unique_constant_value from pytensor.tensor.variable import (
TensorConstant,
TensorVariable,
get_unique_constant_value,
)
def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False): def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
...@@ -1575,42 +1579,48 @@ def local_sum_prod_all_to_none(fgraph, node): ...@@ -1575,42 +1579,48 @@ def local_sum_prod_all_to_none(fgraph, node):
@register_canonicalize @register_canonicalize
@node_rewriter([Sum, Prod]) @node_rewriter([CAReduce])
def local_op_of_op(fgraph, node): def local_reduce_chain(fgraph, node) -> list[TensorVariable] | None:
""" """
Prod(Prod()) -> single Prod()
or
Sum(Sum()) -> single Sum() Sum(Sum()) -> single Sum()
or any CAReduce(Careduce(x)) of the same type
""" """
op_type = Sum if isinstance(node.op, Sum) else Prod [inner_reduce] = node.inputs
(node_inps,) = node.inputs if not (inner_reduce.owner and isinstance(inner_reduce.owner.op, CAReduce)):
out_dtype = node.op.dtype return None
# This is done to make sure the rewrite doesn't affect other
# computations. # Don't apply rewrite if inner_reduce is used elsewhere
if len(fgraph.clients[node_inps]) == 1: if len(fgraph.clients[inner_reduce]) > 1:
if node_inps.owner and (isinstance(node_inps.owner.op, node.op.__class__)): return None
# check to see either the inner or outer prod is doing a
# product over all axis, in which case we can remove it # Check if CAReduces have the same scalar op
if node_inps.owner.op.axis is None or node.op.axis is None: outer_op: CAReduce = node.op
return [op_type(None, dtype=out_dtype)(node_inps.owner.inputs[0])] inner_op = inner_reduce.owner.op
# figure out which axes were in the original sum if outer_op.scalar_op != inner_op.scalar_op:
newaxis = list(node_inps.owner.op.axis) return None
for i in node.op.axis:
new_i = i
for ii in node_inps.owner.op.axis:
if new_i >= ii:
new_i += 1
assert new_i not in newaxis
newaxis.append(new_i)
assert len(newaxis) == len(
list(node_inps.owner.op.axis) + list(node.op.axis)
)
combined = op_type(newaxis, dtype=out_dtype) outer_axis = outer_op.axis
return [combined(node_inps.owner.inputs[0])] inner_axis = inner_op.axis
[x] = inner_reduce.owner.inputs
# check to see either the inner or outer prod is doing a
# product over all axis, in which case we can remove it
if outer_axis is None or inner_axis is None:
return [outer_op.clone(axis=None)(x)]
# Merge axis
newaxis = list(inner_axis)
for i in outer_axis:
new_i = i
for ii in inner_axis:
if new_i >= ii:
new_i += 1
assert new_i not in newaxis
newaxis.append(new_i)
assert len(newaxis) == len(inner_axis) + len(outer_axis)
return [outer_op.clone(axis=sorted(newaxis))(x)]
@register_canonicalize @register_canonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论