提交 f2ad711f authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Implement unconditional constant_folding rewrite

上级 a570dbfd
...@@ -32,6 +32,7 @@ from pytensor.compile.ops import ViewOp ...@@ -32,6 +32,7 @@ from pytensor.compile.ops import ViewOp
from pytensor.graph import FunctionGraph from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Constant, Variable from pytensor.graph.basic import Constant, Variable
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
NodeProcessingGraphRewriter,
NodeRewriter, NodeRewriter,
RemovalNodeRewriter, RemovalNodeRewriter,
Rewriter, Rewriter,
...@@ -1101,10 +1102,7 @@ def local_useless_split(fgraph, node): ...@@ -1101,10 +1102,7 @@ def local_useless_split(fgraph, node):
@node_rewriter(None) @node_rewriter(None)
def constant_folding(fgraph, node): def unconditional_constant_folding(fgraph, node):
if not node.op.do_constant_folding(fgraph, node):
return False
if not all(isinstance(inp, Constant) for inp in node.inputs): if not all(isinstance(inp, Constant) for inp in node.inputs):
return False return False
...@@ -1151,6 +1149,23 @@ def constant_folding(fgraph, node): ...@@ -1151,6 +1149,23 @@ def constant_folding(fgraph, node):
return rval return rval
topo_unconditional_constant_folding = in2out(
unconditional_constant_folding,
ignore_newtrees=True,
name="topo_unconditional_constant_folding",
# Not all Ops have a perform method, so we ignore failures to constant_fold
failure_callback=NodeProcessingGraphRewriter.warn_ignore,
)
@node_rewriter(None)
def constant_folding(fgraph, node):
if not node.op.do_constant_folding(fgraph, node):
return False
return unconditional_constant_folding.transform(fgraph, node)
topo_constant_folding = in2out( topo_constant_folding = in2out(
constant_folding, ignore_newtrees=True, name="topo_constant_folding" constant_folding, ignore_newtrees=True, name="topo_constant_folding"
) )
......
...@@ -12,7 +12,8 @@ from pytensor.compile.function import function ...@@ -12,7 +12,8 @@ from pytensor.compile.function import function
from pytensor.compile.mode import get_default_mode, get_mode from pytensor.compile.mode import get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp, deep_copy_op from pytensor.compile.ops import DeepCopyOp, deep_copy_op
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import equal_computations from pytensor.graph import Op
from pytensor.graph.basic import Constant, equal_computations
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import check_stack_trace, out2in from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
...@@ -29,6 +30,7 @@ from pytensor.tensor.basic import ( ...@@ -29,6 +30,7 @@ from pytensor.tensor.basic import (
TensorFromScalar, TensorFromScalar,
as_tensor, as_tensor,
cast, cast,
constant,
join, join,
tile, tile,
) )
...@@ -65,6 +67,8 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -65,6 +67,8 @@ from pytensor.tensor.rewriting.basic import (
local_merge_alloc, local_merge_alloc,
local_useless_alloc, local_useless_alloc,
local_useless_elemwise, local_useless_elemwise,
topo_constant_folding,
topo_unconditional_constant_folding,
topological_fill_sink, topological_fill_sink,
) )
from pytensor.tensor.rewriting.math import local_lift_transpose_through_dot from pytensor.tensor.rewriting.math import local_lift_transpose_through_dot
...@@ -742,7 +746,8 @@ class TestCastCast: ...@@ -742,7 +746,8 @@ class TestCastCast:
) or (len(topo) > 1) ) or (len(topo) > 1)
def test_constant_folding(): class TestConstantFolding:
def test_constant_folding(self):
# Test that constant folding get registered at fast_compile # Test that constant folding get registered at fast_compile
# An error removed that registration during the registration. # An error removed that registration during the registration.
x = dvector() x = dvector()
...@@ -762,13 +767,12 @@ def test_constant_folding(): ...@@ -762,13 +767,12 @@ def test_constant_folding():
assert len(topo) == 2 assert len(topo) == 2
assert all(isinstance(n.op, DeepCopyOp) for n in topo) assert all(isinstance(n.op, DeepCopyOp) for n in topo)
@pytest.mark.xfail(
@pytest.mark.xfail(
reason="PyTensor rewrites constants before stabilization. " reason="PyTensor rewrites constants before stabilization. "
"This breaks stabilization rewrites in some cases. See #504.", "This breaks stabilization rewrites in some cases. See #504.",
raises=AssertionError, raises=AssertionError,
) )
def test_constant_get_stabilized(): def test_constant_get_stabilized(self):
# Currently PyTensor enables the `constant_folding` rewrite before stabilization rewrites. # Currently PyTensor enables the `constant_folding` rewrite before stabilization rewrites.
# This caused some stabilization rewrites to not be activated and that # This caused some stabilization rewrites to not be activated and that
# caused inf values to appear when they should not. # caused inf values to appear when they should not.
...@@ -793,6 +797,42 @@ def test_constant_get_stabilized(): ...@@ -793,6 +797,42 @@ def test_constant_get_stabilized():
# When this error is fixed, the following line should be ok. # When this error is fixed, the following line should be ok.
assert f() == 800, f() assert f() == 800, f()
def test_unconditional(self):
x = pt.alloc(np.e, *(3, 5))
fg = FunctionGraph(outputs=[x], clone=False)
# Default constant folding doesn't apply to Alloc used as outputs
topo_constant_folding.apply(fg)
assert not isinstance(fg.outputs[0], Constant)
# Unconditional constant folding does apply
topo_unconditional_constant_folding.apply(fg)
assert isinstance(fg.outputs[0], Constant)
np.testing.assert_allclose(fg.outputs[0].data, np.full((3, 5), np.e))
def test_unconditional_no_perform_method(self):
"""Test that errors are caught when the Op does not have a perform method."""
class OpNoPerform(Op):
itypes = [scalar(dtype="float64").type]
otypes = [scalar(dtype="float64").type]
def perform(self, *args, **kwargs):
raise NotImplementedError("This Op cannot be evaluated")
x = constant(np.array(5.0))
out = OpNoPerform()(x)
fg = FunctionGraph(outputs=[out], clone=False)
# Default constant_folding will raise
with pytest.raises(NotImplementedError):
topo_constant_folding.apply(fg)
# Unconditional constant folding will be silent
topo_unconditional_constant_folding.apply(fg)
assert not isinstance(fg.outputs[0], Constant)
assert isinstance(fg.outputs[0].owner.op, OpNoPerform)
class TestLocalSwitchSink: class TestLocalSwitchSink:
def setup_method(self): def setup_method(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论