提交 f2ad711f authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Implement unconditional constant_folding rewrite

上级 a570dbfd
...@@ -32,6 +32,7 @@ from pytensor.compile.ops import ViewOp ...@@ -32,6 +32,7 @@ from pytensor.compile.ops import ViewOp
from pytensor.graph import FunctionGraph from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Constant, Variable from pytensor.graph.basic import Constant, Variable
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
NodeProcessingGraphRewriter,
NodeRewriter, NodeRewriter,
RemovalNodeRewriter, RemovalNodeRewriter,
Rewriter, Rewriter,
...@@ -1101,10 +1102,7 @@ def local_useless_split(fgraph, node): ...@@ -1101,10 +1102,7 @@ def local_useless_split(fgraph, node):
@node_rewriter(None) @node_rewriter(None)
def constant_folding(fgraph, node): def unconditional_constant_folding(fgraph, node):
if not node.op.do_constant_folding(fgraph, node):
return False
if not all(isinstance(inp, Constant) for inp in node.inputs): if not all(isinstance(inp, Constant) for inp in node.inputs):
return False return False
...@@ -1151,6 +1149,23 @@ def constant_folding(fgraph, node): ...@@ -1151,6 +1149,23 @@ def constant_folding(fgraph, node):
return rval return rval
topo_unconditional_constant_folding = in2out(
unconditional_constant_folding,
ignore_newtrees=True,
name="topo_unconditional_constant_folding",
# Not all Ops have a perform method, so we ignore failures to constant_fold
failure_callback=NodeProcessingGraphRewriter.warn_ignore,
)
@node_rewriter(None)
def constant_folding(fgraph, node):
if not node.op.do_constant_folding(fgraph, node):
return False
return unconditional_constant_folding.transform(fgraph, node)
topo_constant_folding = in2out( topo_constant_folding = in2out(
constant_folding, ignore_newtrees=True, name="topo_constant_folding" constant_folding, ignore_newtrees=True, name="topo_constant_folding"
) )
......
...@@ -12,7 +12,8 @@ from pytensor.compile.function import function ...@@ -12,7 +12,8 @@ from pytensor.compile.function import function
from pytensor.compile.mode import get_default_mode, get_mode from pytensor.compile.mode import get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp, deep_copy_op from pytensor.compile.ops import DeepCopyOp, deep_copy_op
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import equal_computations from pytensor.graph import Op
from pytensor.graph.basic import Constant, equal_computations
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import check_stack_trace, out2in from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
...@@ -29,6 +30,7 @@ from pytensor.tensor.basic import ( ...@@ -29,6 +30,7 @@ from pytensor.tensor.basic import (
TensorFromScalar, TensorFromScalar,
as_tensor, as_tensor,
cast, cast,
constant,
join, join,
tile, tile,
) )
...@@ -65,6 +67,8 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -65,6 +67,8 @@ from pytensor.tensor.rewriting.basic import (
local_merge_alloc, local_merge_alloc,
local_useless_alloc, local_useless_alloc,
local_useless_elemwise, local_useless_elemwise,
topo_constant_folding,
topo_unconditional_constant_folding,
topological_fill_sink, topological_fill_sink,
) )
from pytensor.tensor.rewriting.math import local_lift_transpose_through_dot from pytensor.tensor.rewriting.math import local_lift_transpose_through_dot
...@@ -742,56 +746,92 @@ class TestCastCast: ...@@ -742,56 +746,92 @@ class TestCastCast:
) or (len(topo) > 1) ) or (len(topo) > 1)
def test_constant_folding(): class TestConstantFolding:
# Test that constant folding get registered at fast_compile def test_constant_folding(self):
# An error removed that registration during the registration. # Test that constant folding get registered at fast_compile
x = dvector() # An error removed that registration during the registration.
mode = get_mode("FAST_COMPILE").excluding("fusion") x = dvector()
f = function([x], [x * 2, x + x], mode=mode) mode = get_mode("FAST_COMPILE").excluding("fusion")
topo = f.maker.fgraph.toposort() f = function([x], [x * 2, x + x], mode=mode)
assert len(topo) == 2 topo = f.maker.fgraph.toposort()
assert len(topo) == 2
# Test that we do not crash when constant folding elemwise scalar
# as they should not generate c code.
x = pt.constant(3) # Test that we do not crash when constant folding elemwise scalar
assert x.ndim == 0 # as they should not generate c code.
mode = get_mode("FAST_COMPILE").excluding("fusion")
f = function([], [x * 2, x + x], mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
assert all(isinstance(n.op, DeepCopyOp) for n in topo)
x = pt.constant(3)
assert x.ndim == 0
mode = get_mode("FAST_COMPILE").excluding("fusion")
f = function([], [x * 2, x + x], mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
assert all(isinstance(n.op, DeepCopyOp) for n in topo)
@pytest.mark.xfail( @pytest.mark.xfail(
reason="PyTensor rewrites constants before stabilization. " reason="PyTensor rewrites constants before stabilization. "
"This breaks stabilization rewrites in some cases. See #504.", "This breaks stabilization rewrites in some cases. See #504.",
raises=AssertionError, raises=AssertionError,
) )
def test_constant_get_stabilized(): def test_constant_get_stabilized(self):
# Currently PyTensor enables the `constant_folding` rewrite before stabilization rewrites. # Currently PyTensor enables the `constant_folding` rewrite before stabilization rewrites.
# This caused some stabilization rewrites to not be activated and that # This caused some stabilization rewrites to not be activated and that
# caused inf values to appear when they should not. # caused inf values to appear when they should not.
# We can't simply move the `constant_folding` rewrite to # We can't simply move the `constant_folding` rewrite to
# specialize since this will break other rewrites. We will need to # specialize since this will break other rewrites. We will need to
# partially duplicate some canonicalize rewrites to fix this issue. # partially duplicate some canonicalize rewrites to fix this issue.
x2 = scalar() x2 = scalar()
y2 = log(1 + exp(x2)) y2 = log(1 + exp(x2))
mode = get_default_mode() mode = get_default_mode()
mode.check_isfinite = False mode.check_isfinite = False
f2 = function([x2], y2, mode=mode) f2 = function([x2], y2, mode=mode)
assert len(f2.maker.fgraph.toposort()) == 1 assert len(f2.maker.fgraph.toposort()) == 1
assert f2.maker.fgraph.toposort()[0].op == softplus assert f2.maker.fgraph.toposort()[0].op == softplus
assert f2(800) == 800 assert f2(800) == 800
x = pt.as_tensor_variable(800) x = pt.as_tensor_variable(800)
y = log(1 + exp(x)) y = log(1 + exp(x))
f = function([], y, mode=mode) f = function([], y, mode=mode)
# When this error is fixed, the following line should be ok. # When this error is fixed, the following line should be ok.
assert f() == 800, f() assert f() == 800, f()
def test_unconditional(self):
x = pt.alloc(np.e, *(3, 5))
fg = FunctionGraph(outputs=[x], clone=False)
# Default constant folding doesn't apply to Alloc used as outputs
topo_constant_folding.apply(fg)
assert not isinstance(fg.outputs[0], Constant)
# Unconditional constant folding does apply
topo_unconditional_constant_folding.apply(fg)
assert isinstance(fg.outputs[0], Constant)
np.testing.assert_allclose(fg.outputs[0].data, np.full((3, 5), np.e))
def test_unconditional_no_perform_method(self):
"""Test that errors are caught when the Op does not have a perform method."""
class OpNoPerform(Op):
itypes = [scalar(dtype="float64").type]
otypes = [scalar(dtype="float64").type]
def perform(self, *args, **kwargs):
raise NotImplementedError("This Op cannot be evaluated")
x = constant(np.array(5.0))
out = OpNoPerform()(x)
fg = FunctionGraph(outputs=[out], clone=False)
# Default constant_folding will raise
with pytest.raises(NotImplementedError):
topo_constant_folding.apply(fg)
# Unconditional constant folding will be silent
topo_unconditional_constant_folding.apply(fg)
assert not isinstance(fg.outputs[0], Constant)
assert isinstance(fg.outputs[0].owner.op, OpNoPerform)
class TestLocalSwitchSink: class TestLocalSwitchSink:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论