Commit f2ad711f authored by ricardoV94, committed by Ricardo Vieira

Implement unconditional constant_folding rewrite

Parent a570dbfd
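A minimal sketch of the behavioral difference, mirroring the new `test_unconditional` test added below (the `Alloc` example and rewrite names are taken directly from this diff; the snippet is illustrative, not part of the change):

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.rewriting.basic import (
    topo_constant_folding,
    topo_unconditional_constant_folding,
)

# An Alloc whose inputs are all constants, used directly as a graph output.
x = pt.alloc(np.e, 3, 5)
fg = FunctionGraph(outputs=[x], clone=False)

# The default rewrite defers to Op.do_constant_folding, which declines to fold
# an Alloc that is a graph output, so nothing changes here.
topo_constant_folding.apply(fg)
assert not isinstance(fg.outputs[0], Constant)

# The unconditional variant only requires that every input be a Constant, so
# the Alloc is evaluated and replaced; Ops whose perform call fails are skipped
# with a warning (failure_callback=NodeProcessingGraphRewriter.warn_ignore).
topo_unconditional_constant_folding.apply(fg)
assert isinstance(fg.outputs[0], Constant)
np.testing.assert_allclose(fg.outputs[0].data, np.full((3, 5), np.e))
```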
@@ -32,6 +32,7 @@ from pytensor.compile.ops import ViewOp
 from pytensor.graph import FunctionGraph
 from pytensor.graph.basic import Constant, Variable
 from pytensor.graph.rewriting.basic import (
+    NodeProcessingGraphRewriter,
     NodeRewriter,
     RemovalNodeRewriter,
     Rewriter,
@@ -1101,10 +1102,7 @@ def local_useless_split(fgraph, node):
 @node_rewriter(None)
-def constant_folding(fgraph, node):
-    if not node.op.do_constant_folding(fgraph, node):
-        return False
-
+def unconditional_constant_folding(fgraph, node):
     if not all(isinstance(inp, Constant) for inp in node.inputs):
         return False
@@ -1151,6 +1149,23 @@ def constant_folding(fgraph, node):
     return rval


+topo_unconditional_constant_folding = in2out(
+    unconditional_constant_folding,
+    ignore_newtrees=True,
+    name="topo_unconditional_constant_folding",
+    # Not all Ops have a perform method, so we ignore failures to constant_fold
+    failure_callback=NodeProcessingGraphRewriter.warn_ignore,
+)
+
+
+@node_rewriter(None)
+def constant_folding(fgraph, node):
+    if not node.op.do_constant_folding(fgraph, node):
+        return False
+
+    return unconditional_constant_folding.transform(fgraph, node)
+
+
 topo_constant_folding = in2out(
     constant_folding, ignore_newtrees=True, name="topo_constant_folding"
 )
......
@@ -12,7 +12,8 @@ from pytensor.compile.function import function
 from pytensor.compile.mode import get_default_mode, get_mode
 from pytensor.compile.ops import DeepCopyOp, deep_copy_op
 from pytensor.configdefaults import config
-from pytensor.graph.basic import equal_computations
+from pytensor.graph import Op
+from pytensor.graph.basic import Constant, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import check_stack_trace, out2in
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -29,6 +30,7 @@ from pytensor.tensor.basic import (
     TensorFromScalar,
     as_tensor,
     cast,
+    constant,
     join,
     tile,
 )
@@ -65,6 +67,8 @@ from pytensor.tensor.rewriting.basic import (
     local_merge_alloc,
     local_useless_alloc,
     local_useless_elemwise,
+    topo_constant_folding,
+    topo_unconditional_constant_folding,
     topological_fill_sink,
 )
 from pytensor.tensor.rewriting.math import local_lift_transpose_through_dot
@@ -742,56 +746,92 @@ class TestCastCast:
     ) or (len(topo) > 1)


-def test_constant_folding():
-    # Test that constant folding get registered at fast_compile
-    # An error removed that registration during the registration.
-    x = dvector()
-    mode = get_mode("FAST_COMPILE").excluding("fusion")
-    f = function([x], [x * 2, x + x], mode=mode)
-    topo = f.maker.fgraph.toposort()
-    assert len(topo) == 2
-
-    # Test that we do not crash when constant folding elemwise scalar
-    # as they should not generate c code.
-    x = pt.constant(3)
-    assert x.ndim == 0
-    mode = get_mode("FAST_COMPILE").excluding("fusion")
-    f = function([], [x * 2, x + x], mode=mode)
-    topo = f.maker.fgraph.toposort()
-    assert len(topo) == 2
-    assert all(isinstance(n.op, DeepCopyOp) for n in topo)
-
-
-@pytest.mark.xfail(
-    reason="PyTensor rewrites constants before stabilization. "
-    "This breaks stabilization rewrites in some cases. See #504.",
-    raises=AssertionError,
-)
-def test_constant_get_stabilized():
-    # Currently PyTensor enables the `constant_folding` rewrite before stabilization rewrites.
-    # This caused some stabilization rewrites to not be activated and that
-    # caused inf values to appear when they should not.
-
-    # We can't simply move the `constant_folding` rewrite to
-    # specialize since this will break other rewrites. We will need to
-    # partially duplicate some canonicalize rewrites to fix this issue.
-
-    x2 = scalar()
-    y2 = log(1 + exp(x2))
-    mode = get_default_mode()
-    mode.check_isfinite = False
-    f2 = function([x2], y2, mode=mode)
-    assert len(f2.maker.fgraph.toposort()) == 1
-    assert f2.maker.fgraph.toposort()[0].op == softplus
-    assert f2(800) == 800
-    x = pt.as_tensor_variable(800)
-    y = log(1 + exp(x))
-    f = function([], y, mode=mode)
-    # When this error is fixed, the following line should be ok.
-    assert f() == 800, f()
+class TestConstantFolding:
+    def test_constant_folding(self):
+        # Test that constant folding get registered at fast_compile
+        # An error removed that registration during the registration.
+        x = dvector()
+        mode = get_mode("FAST_COMPILE").excluding("fusion")
+        f = function([x], [x * 2, x + x], mode=mode)
+        topo = f.maker.fgraph.toposort()
+        assert len(topo) == 2
+
+        # Test that we do not crash when constant folding elemwise scalar
+        # as they should not generate c code.
+        x = pt.constant(3)
+        assert x.ndim == 0
+        mode = get_mode("FAST_COMPILE").excluding("fusion")
+        f = function([], [x * 2, x + x], mode=mode)
+        topo = f.maker.fgraph.toposort()
+        assert len(topo) == 2
+        assert all(isinstance(n.op, DeepCopyOp) for n in topo)
+
+    @pytest.mark.xfail(
+        reason="PyTensor rewrites constants before stabilization. "
+        "This breaks stabilization rewrites in some cases. See #504.",
+        raises=AssertionError,
+    )
+    def test_constant_get_stabilized(self):
+        # Currently PyTensor enables the `constant_folding` rewrite before stabilization rewrites.
+        # This caused some stabilization rewrites to not be activated and that
+        # caused inf values to appear when they should not.
+
+        # We can't simply move the `constant_folding` rewrite to
+        # specialize since this will break other rewrites. We will need to
+        # partially duplicate some canonicalize rewrites to fix this issue.
+
+        x2 = scalar()
+        y2 = log(1 + exp(x2))
+        mode = get_default_mode()
+        mode.check_isfinite = False
+        f2 = function([x2], y2, mode=mode)
+        assert len(f2.maker.fgraph.toposort()) == 1
+        assert f2.maker.fgraph.toposort()[0].op == softplus
+        assert f2(800) == 800
+        x = pt.as_tensor_variable(800)
+        y = log(1 + exp(x))
+        f = function([], y, mode=mode)
+        # When this error is fixed, the following line should be ok.
+        assert f() == 800, f()
+
+    def test_unconditional(self):
+        x = pt.alloc(np.e, *(3, 5))
+        fg = FunctionGraph(outputs=[x], clone=False)
+
+        # Default constant folding doesn't apply to Alloc used as outputs
+        topo_constant_folding.apply(fg)
+        assert not isinstance(fg.outputs[0], Constant)
+
+        # Unconditional constant folding does apply
+        topo_unconditional_constant_folding.apply(fg)
+        assert isinstance(fg.outputs[0], Constant)
+        np.testing.assert_allclose(fg.outputs[0].data, np.full((3, 5), np.e))
+
+    def test_unconditional_no_perform_method(self):
+        """Test that errors are caught when the Op does not have a perform method."""
+
+        class OpNoPerform(Op):
+            itypes = [scalar(dtype="float64").type]
+            otypes = [scalar(dtype="float64").type]
+
+            def perform(self, *args, **kwargs):
+                raise NotImplementedError("This Op cannot be evaluated")
+
+        x = constant(np.array(5.0))
+        out = OpNoPerform()(x)
+
+        fg = FunctionGraph(outputs=[out], clone=False)
+        # Default constant_folding will raise
+        with pytest.raises(NotImplementedError):
+            topo_constant_folding.apply(fg)
+
+        # Unconditional constant folding will be silent
+        topo_unconditional_constant_folding.apply(fg)
+        assert not isinstance(fg.outputs[0], Constant)
+        assert isinstance(fg.outputs[0].owner.op, OpNoPerform)
class TestLocalSwitchSink:
......