Unverified 提交 4ee35881 authored 作者: Trey Wenger's avatar Trey Wenger 提交者: GitHub

Prevent `local_sum_make_vector` from introducing internal `float64` (#659)

上级 d175203b
......@@ -28,7 +28,7 @@ from typing import Union
import numpy as np
import pytensor.scalar.basic as ps
from pytensor import compile
from pytensor import compile, config
from pytensor.compile.ops import ViewOp
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Constant, Variable
......@@ -941,6 +941,11 @@ def local_sum_make_vector(fgraph, node):
elements = array.owner.inputs
acc_dtype = node.op.acc_dtype
out_dtype = node.op.dtype
# Skip rewrite if it would add unnecessary float64 to the graph
if acc_dtype == "float64" and out_dtype != "float64" and config.floatX != "float64":
return
if len(elements) == 0:
element_sum = zeros(dtype=out_dtype, shape=())
elif len(elements) == 1:
......
......@@ -12,7 +12,7 @@ from pytensor.compile.function import function
from pytensor.compile.mode import get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp, deep_copy_op
from pytensor.configdefaults import config
from pytensor.graph.basic import equal_computations, vars_between
from pytensor.graph.basic import equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
......@@ -26,12 +26,12 @@ from pytensor.tensor.basic import (
ScalarFromTensor,
Split,
TensorFromScalar,
cast,
join,
tile,
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import (
Sum,
add,
bitwise_and,
bitwise_or,
......@@ -1298,41 +1298,48 @@ def test_local_join_make_vector():
def test_local_sum_make_vector():
# To check that rewrite is applied, we must enforce dtype to
# allow rewrite to occur even if floatX != "float64"
a, b, c = scalars("abc")
mv = MakeVector(config.floatX)
output = mv(a, b, c).sum()
output = rewrite_graph(output)
between = vars_between([a, b, c], [output])
for var in between:
assert (var.owner is None) or (not isinstance(var.owner.op, MakeVector))
output = mv(a, b, c).sum(dtype="float64")
rewrite_output = rewrite_graph(output)
expected_output = cast(
add(*[cast(value, "float64") for value in [a, b, c]]), dtype="float64"
)
assert equal_computations([expected_output], [rewrite_output])
# Check for empty sum
# Empty axes should return input vector since no sum is applied
a, b, c = scalars("abc")
mv = MakeVector(config.floatX)
output = mv(a, b, c).sum(axis=[])
rewrite_output = rewrite_graph(output)
expected_output = mv(a, b, c)
assert equal_computations([expected_output], [rewrite_output])
output = rewrite_graph(output)
between = vars_between([a, b, c], [output])
for var in between:
assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
# Check empty MakeVector
# Empty input should return 0
mv = MakeVector(config.floatX)
output = mv().sum()
rewrite_output = rewrite_graph(output)
expected_output = pt.as_tensor(0, dtype=config.floatX)
assert equal_computations([expected_output], [rewrite_output])
output = rewrite_graph(output)
between = vars_between([a, b, c], [output])
for var in between:
assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
# Single element input should return element value
a = scalars("a")
mv = MakeVector(config.floatX)
output = mv(a).sum()
rewrite_output = rewrite_graph(output)
expected_output = cast(a, config.floatX)
assert equal_computations([expected_output], [rewrite_output])
output = rewrite_graph(output)
between = vars_between([a, b, c], [output])
for var in between:
assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
# This is a regression test for #653. Ensure that rewrite is NOT
# applied when user requests float32
with config.change_flags(floatX="float32", warn_float64="raise"):
a, b, c = scalars("abc")
mv = MakeVector(config.floatX)
output = mv(a, b, c).sum()
rewrite_output = rewrite_graph(output)
assert equal_computations([output], [rewrite_output])
@pytest.mark.parametrize(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论