提交 c8908e5a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Prevent Print from being constant folded

上级 58046078
...@@ -802,6 +802,9 @@ class Print(Op): ...@@ -802,6 +802,9 @@ class Print(Op):
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1,)
def do_constant_folding(self, fgraph, node):
return False
class PrinterState(Scratchpad): class PrinterState(Scratchpad):
def __init__(self, props=None, **more_props): def __init__(self, props=None, **more_props):
......
...@@ -7,9 +7,12 @@ from io import StringIO ...@@ -7,9 +7,12 @@ from io import StringIO
import pytest import pytest
import aesara import aesara
from aesara.compile.mode import get_mode
from aesara.compile.ops import deep_copy_op
from aesara.printing import ( from aesara.printing import (
PatternPrinter, PatternPrinter,
PPrinter, PPrinter,
Print,
debugprint, debugprint,
default_printer, default_printer,
get_node_by_id, get_node_by_id,
...@@ -18,6 +21,7 @@ from aesara.printing import ( ...@@ -18,6 +21,7 @@ from aesara.printing import (
pydot_imported, pydot_imported,
pydotprint, pydotprint,
) )
from aesara.tensor import as_tensor_variable
from aesara.tensor.type import dmatrix, dvector, matrix from aesara.tensor.type import dmatrix, dvector, matrix
from tests.graph.utils import MyInnerGraphOp, MyOp, MyVariable from tests.graph.utils import MyInnerGraphOp, MyOp, MyVariable
...@@ -401,3 +405,25 @@ def test_PatternPrinter(): ...@@ -401,3 +405,25 @@ def test_PatternPrinter():
res = pprint(o1) res = pprint(o1)
assert res == "|1 - 2|" assert res == "|1 - 2|"
def test_Print(capsys):
r"""Make sure that `Print` `Op`\s are present in compiled graphs with constant folding."""
x = as_tensor_variable(1.0) * as_tensor_variable(3.0)
print_op = Print("hello")
x_print = print_op(x)
# Just to be more sure that we'll have constant folding...
mode = get_mode("FAST_RUN").including("topo_constant_folding")
fn = aesara.function([], x_print, mode=mode)
nodes = fn.maker.fgraph.toposort()
assert len(nodes) == 2
assert nodes[0].op == print_op
assert nodes[1].op == deep_copy_op
fn()
stdout, stderr = capsys.readouterr()
assert "hello" in stdout
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论