Unverified 提交 72184319 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: GitHub

Add rewrite for Sum(MakeVector) (#346)

上级 981be2a6
......@@ -43,6 +43,7 @@ from pytensor.tensor.basic import (
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_shape, broadcast_to
from pytensor.tensor.math import Sum, add
from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import eq
from pytensor.tensor.shape import Shape_i
......@@ -956,6 +957,41 @@ def local_join_make_vector(fgraph, node):
return [ret]
@register_specialize
@register_canonicalize
@register_useless
@node_rewriter([Sum])
def local_sum_make_vector(fgraph, node):
"""A sum of a MakeVector node is just the sum of the elements."""
(array,) = node.inputs
if array.owner is None:
return
if not isinstance(array.owner.op, MakeVector):
return
if node.op.axis == ():
return [array]
# If this is not the case the sum is invalid
assert node.op.axis is None or node.op.axis == (0,) or node.op.axis == (-1,)
elements = array.owner.inputs
acc_dtype = node.op.acc_dtype
out_dtype = node.op.dtype
if len(elements) == 0:
element_sum = zeros(dtype=out_dtype, shape=())
elif len(elements) == 1:
element_sum = cast(elements[0], out_dtype)
else:
element_sum = cast(
add(*[cast(value, acc_dtype) for value in elements]), out_dtype
)
return [element_sum]
@register_useless("local_remove_switch_const_cond")
@register_canonicalize("fast_compile", "local_remove_switch_const_cond")
@register_specialize
......
......@@ -12,7 +12,7 @@ from pytensor.compile.function import function
from pytensor.compile.mode import get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp, deep_copy_op
from pytensor.configdefaults import config
from pytensor.graph.basic import equal_computations
from pytensor.graph.basic import equal_computations, vars_between
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
......@@ -31,6 +31,7 @@ from pytensor.tensor.basic import (
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import (
Sum,
add,
bitwise_and,
bitwise_or,
......@@ -1300,6 +1301,44 @@ def test_local_join_make_vector():
assert check_stack_trace(f, ops_to_check="all")
def test_local_sum_make_vector():
a, b, c = scalars("abc")
mv = MakeVector(config.floatX)
output = mv(a, b, c).sum()
output = rewrite_graph(output)
between = vars_between([a, b, c], [output])
for var in between:
assert (var.owner is None) or (not isinstance(var.owner.op, MakeVector))
# Check for empty sum
a, b, c = scalars("abc")
mv = MakeVector(config.floatX)
output = mv(a, b, c).sum(axis=[])
output = rewrite_graph(output)
between = vars_between([a, b, c], [output])
for var in between:
assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
# Check empty MakeVector
mv = MakeVector(config.floatX)
output = mv().sum()
output = rewrite_graph(output)
between = vars_between([a, b, c], [output])
for var in between:
assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
mv = MakeVector(config.floatX)
output = mv(a).sum()
output = rewrite_graph(output)
between = vars_between([a, b, c], [output])
for var in between:
assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
@pytest.mark.parametrize(
"dtype",
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论