Commit f2ad711f authored by ricardoV94, committed by Ricardo Vieira

Implement unconditional constant_folding rewrite

Parent a570dbfd
......@@ -32,6 +32,7 @@ from pytensor.compile.ops import ViewOp
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Constant, Variable
from pytensor.graph.rewriting.basic import (
NodeProcessingGraphRewriter,
NodeRewriter,
RemovalNodeRewriter,
Rewriter,
......@@ -1101,10 +1102,7 @@ def local_useless_split(fgraph, node):
@node_rewriter(None)
def constant_folding(fgraph, node):
if not node.op.do_constant_folding(fgraph, node):
return False
def unconditional_constant_folding(fgraph, node):
if not all(isinstance(inp, Constant) for inp in node.inputs):
return False
......@@ -1151,6 +1149,23 @@ def constant_folding(fgraph, node):
return rval
# Graph-wide (topological) variant of `unconditional_constant_folding`.
topo_unconditional_constant_folding = in2out(
    unconditional_constant_folding,
    name="topo_unconditional_constant_folding",
    ignore_newtrees=True,
    # Some Ops lack a usable `perform` implementation; warn about nodes that
    # fail to fold instead of aborting the whole rewrite pass.
    failure_callback=NodeProcessingGraphRewriter.warn_ignore,
)
@node_rewriter(None)
def constant_folding(fgraph, node):
    """Fold a node with all-constant inputs into a `Constant`.

    Unlike `unconditional_constant_folding`, this first asks the Op (via
    `do_constant_folding`) whether folding is worthwhile for this node.
    """
    if node.op.do_constant_folding(fgraph, node):
        # Delegate the actual evaluation to the unconditional rewriter.
        return unconditional_constant_folding.transform(fgraph, node)
    return False
# Graph-wide (topological) application of the conditional `constant_folding`.
topo_constant_folding = in2out(
    constant_folding,
    ignore_newtrees=True,
    name="topo_constant_folding",
)
......
......@@ -12,7 +12,8 @@ from pytensor.compile.function import function
from pytensor.compile.mode import get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp, deep_copy_op
from pytensor.configdefaults import config
from pytensor.graph.basic import equal_computations
from pytensor.graph import Op
from pytensor.graph.basic import Constant, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
......@@ -29,6 +30,7 @@ from pytensor.tensor.basic import (
TensorFromScalar,
as_tensor,
cast,
constant,
join,
tile,
)
......@@ -65,6 +67,8 @@ from pytensor.tensor.rewriting.basic import (
local_merge_alloc,
local_useless_alloc,
local_useless_elemwise,
topo_constant_folding,
topo_unconditional_constant_folding,
topological_fill_sink,
)
from pytensor.tensor.rewriting.math import local_lift_transpose_through_dot
......@@ -742,7 +746,8 @@ class TestCastCast:
) or (len(topo) > 1)
def test_constant_folding():
class TestConstantFolding:
def test_constant_folding(self):
# Test that constant folding get registered at fast_compile
# An error removed that registration during the registration.
x = dvector()
......@@ -762,13 +767,12 @@ def test_constant_folding():
assert len(topo) == 2
assert all(isinstance(n.op, DeepCopyOp) for n in topo)
@pytest.mark.xfail(
@pytest.mark.xfail(
reason="PyTensor rewrites constants before stabilization. "
"This breaks stabilization rewrites in some cases. See #504.",
raises=AssertionError,
)
def test_constant_get_stabilized():
)
def test_constant_get_stabilized(self):
# Currently PyTensor enables the `constant_folding` rewrite before stabilization rewrites.
# This caused some stabilization rewrites to not be activated and that
# caused inf values to appear when they should not.
......@@ -793,6 +797,42 @@ def test_constant_get_stabilized():
# When this error is fixed, the following line should be ok.
assert f() == 800, f()
def test_unconditional(self):
    """Unconditional folding evaluates nodes the conditional variant skips."""
    graph_out = pt.alloc(np.e, 3, 5)
    fg = FunctionGraph(outputs=[graph_out], clone=False)

    # The conditional rewrite refuses to fold an Alloc that is a graph output,
    # so the output must still be a non-constant variable afterwards.
    topo_constant_folding.apply(fg)
    assert not isinstance(fg.outputs[0], Constant)

    # The unconditional rewrite folds it regardless.
    topo_unconditional_constant_folding.apply(fg)
    folded = fg.outputs[0]
    assert isinstance(folded, Constant)
    np.testing.assert_allclose(folded.data, np.full((3, 5), np.e))
def test_unconditional_no_perform_method(self):
    """Folding failures from an un-evaluable Op are swallowed, not raised.

    The unconditional rewriter is configured with a warn-and-ignore failure
    callback, so an Op whose `perform` raises must be left in the graph.
    """

    class OpNoPerform(Op):
        itypes = [scalar(dtype="float64").type]
        otypes = [scalar(dtype="float64").type]

        def perform(self, *args, **kwargs):
            raise NotImplementedError("This Op cannot be evaluated")

    x = constant(np.array(5.0))
    out = OpNoPerform()(x)
    fg = FunctionGraph(outputs=[out], clone=False)

    # The conditional rewrite propagates the evaluation failure.
    with pytest.raises(NotImplementedError):
        topo_constant_folding.apply(fg)

    # The unconditional rewrite warns and leaves the node untouched.
    topo_unconditional_constant_folding.apply(fg)
    result = fg.outputs[0]
    assert not isinstance(result, Constant)
    assert isinstance(result.owner.op, OpNoPerform)
class TestLocalSwitchSink:
def setup_method(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论