提交 df769f6c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add dprint shortcut to FunctionGraph and Function

上级 448e5582
......@@ -1097,6 +1097,18 @@ class Function:
# NOTE: sync was needed on old gpu backend
pass
def dprint(self, **kwargs):
"""Debug print itself
Parameters
----------
kwargs:
Optional keyword arguments to pass to debugprint function.
"""
from pytensor.printing import debugprint
return debugprint(self, **kwargs)
# pickling/deepcopy support for Function
def _pickle_Function(f):
......
......@@ -927,3 +927,15 @@ class FunctionGraph(MetaObject):
return item in self.apply_nodes
else:
raise TypeError()
def dprint(self, **kwargs):
"""Debug print itself
Parameters
----------
kwargs:
Optional keyword arguments to pass to debugprint function.
"""
from pytensor.printing import debugprint
return debugprint(self, **kwargs)
......@@ -16,6 +16,7 @@ from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import OpKeyGraphRewriter, PatternNodeRewriter
from pytensor.graph.utils import MissingInputError
from pytensor.link.vm import VMLinker
from pytensor.printing import debugprint
from pytensor.tensor.math import dot, tanh
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.type import (
......@@ -862,6 +863,12 @@ class TestFunction:
with pytest.raises(AssertionError):
function([x], outputs={(1, "b"): x, 1.0: x**2})
def test_dprint(self):
x = pt.scalar("x")
out = x + 1
f = function([x], out)
assert f.dprint(file="str") == debugprint(f, file="str")
class TestPicklefunction:
def test_deepcopy(self):
......
......@@ -8,6 +8,7 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import NominalVariable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.utils import MissingInputError
from pytensor.printing import debugprint
from tests.graph.utils import (
MyConstant,
MyOp,
......@@ -706,3 +707,9 @@ class TestFunctionGraph:
assert nm2 not in fg.inputs
assert nm in fg.variables
assert nm2 in fg.variables
def test_dprint(self):
r1, r2 = MyVariable("x"), MyVariable("y")
o1 = op1(r1, r2)
fg = FunctionGraph([r1, r2], [o1], clone=False)
assert fg.dprint(file="str") == debugprint(fg, file="str")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论