提交 d34760d7 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add dprint method to Nodes

上级 1117ea5e
......@@ -75,6 +75,18 @@ class Node(MetaObject):
"""
raise NotImplementedError()
def dprint(self, **kwargs):
"""Debug print itself
Parameters
----------
kwargs:
Optional keyword arguments to pass to debugprint function.
"""
from pytensor.printing import debugprint
return debugprint(self, **kwargs)
class Apply(Node, Generic[OpType]):
"""A `Node` representing the application of an operation to inputs.
......
......@@ -31,6 +31,7 @@ from pytensor.graph.basic import (
)
from pytensor.graph.op import Op
from pytensor.graph.type import Type
from pytensor.printing import debugprint
from pytensor.tensor import constant
from pytensor.tensor.math import max_and_argmax
from pytensor.tensor.type import TensorType, iscalars, matrix, scalars, vector
......@@ -869,3 +870,10 @@ class TestTruncatedGraphInputs:
assert len(inspect.call_args_list) == len(
{a for ((a, b), kw) in inspect.call_args_list}
)
def test_dprint():
r1, r2 = MyVariable(1), MyVariable(2)
o1 = MyOp(r1, r2)
assert o1.dprint(file="str") == debugprint(o1, file="str")
assert o1.owner.dprint(file="str") == debugprint(o1.owner, file="str")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论