提交 df769f6c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add dprint shortcut to FunctionGraph and Function

上级 448e5582
...@@ -1097,6 +1097,18 @@ class Function: ...@@ -1097,6 +1097,18 @@ class Function:
# NOTE: sync was needed on old gpu backend # NOTE: sync was needed on old gpu backend
pass pass
def dprint(self, **kwargs):
"""Debug print itself
Parameters
----------
kwargs:
Optional keyword arguments to pass to debugprint function.
"""
from pytensor.printing import debugprint
return debugprint(self, **kwargs)
# pickling/deepcopy support for Function # pickling/deepcopy support for Function
def _pickle_Function(f): def _pickle_Function(f):
......
...@@ -927,3 +927,15 @@ class FunctionGraph(MetaObject): ...@@ -927,3 +927,15 @@ class FunctionGraph(MetaObject):
return item in self.apply_nodes return item in self.apply_nodes
else: else:
raise TypeError() raise TypeError()
def dprint(self, **kwargs):
"""Debug print itself
Parameters
----------
kwargs:
Optional keyword arguments to pass to debugprint function.
"""
from pytensor.printing import debugprint
return debugprint(self, **kwargs)
...@@ -16,6 +16,7 @@ from pytensor.graph.basic import Constant ...@@ -16,6 +16,7 @@ from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import OpKeyGraphRewriter, PatternNodeRewriter from pytensor.graph.rewriting.basic import OpKeyGraphRewriter, PatternNodeRewriter
from pytensor.graph.utils import MissingInputError from pytensor.graph.utils import MissingInputError
from pytensor.link.vm import VMLinker from pytensor.link.vm import VMLinker
from pytensor.printing import debugprint
from pytensor.tensor.math import dot, tanh from pytensor.tensor.math import dot, tanh
from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.type import ( from pytensor.tensor.type import (
...@@ -862,6 +863,12 @@ class TestFunction: ...@@ -862,6 +863,12 @@ class TestFunction:
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
function([x], outputs={(1, "b"): x, 1.0: x**2}) function([x], outputs={(1, "b"): x, 1.0: x**2})
def test_dprint(self):
x = pt.scalar("x")
out = x + 1
f = function([x], out)
assert f.dprint(file="str") == debugprint(f, file="str")
class TestPicklefunction: class TestPicklefunction:
def test_deepcopy(self): def test_deepcopy(self):
......
...@@ -8,6 +8,7 @@ from pytensor.configdefaults import config ...@@ -8,6 +8,7 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import NominalVariable from pytensor.graph.basic import NominalVariable
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.utils import MissingInputError from pytensor.graph.utils import MissingInputError
from pytensor.printing import debugprint
from tests.graph.utils import ( from tests.graph.utils import (
MyConstant, MyConstant,
MyOp, MyOp,
...@@ -706,3 +707,9 @@ class TestFunctionGraph: ...@@ -706,3 +707,9 @@ class TestFunctionGraph:
assert nm2 not in fg.inputs assert nm2 not in fg.inputs
assert nm in fg.variables assert nm in fg.variables
assert nm2 in fg.variables assert nm2 in fg.variables
def test_dprint(self):
r1, r2 = MyVariable("x"), MyVariable("y")
o1 = op1(r1, r2)
fg = FunctionGraph([r1, r2], [o1], clone=False)
assert fg.dprint(file="str") == debugprint(fg, file="str")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论