提交 d34760d7 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add dprint method to Nodes

上级 1117ea5e
...@@ -75,6 +75,18 @@ class Node(MetaObject): ...@@ -75,6 +75,18 @@ class Node(MetaObject):
""" """
raise NotImplementedError() raise NotImplementedError()
def dprint(self, **kwargs):
"""Debug print itself
Parameters
----------
kwargs:
Optional keyword arguments to pass to debugprint function.
"""
from pytensor.printing import debugprint
return debugprint(self, **kwargs)
class Apply(Node, Generic[OpType]): class Apply(Node, Generic[OpType]):
"""A `Node` representing the application of an operation to inputs. """A `Node` representing the application of an operation to inputs.
......
...@@ -31,6 +31,7 @@ from pytensor.graph.basic import ( ...@@ -31,6 +31,7 @@ from pytensor.graph.basic import (
) )
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.printing import debugprint
from pytensor.tensor import constant from pytensor.tensor import constant
from pytensor.tensor.math import max_and_argmax from pytensor.tensor.math import max_and_argmax
from pytensor.tensor.type import TensorType, iscalars, matrix, scalars, vector from pytensor.tensor.type import TensorType, iscalars, matrix, scalars, vector
...@@ -869,3 +870,10 @@ class TestTruncatedGraphInputs: ...@@ -869,3 +870,10 @@ class TestTruncatedGraphInputs:
assert len(inspect.call_args_list) == len( assert len(inspect.call_args_list) == len(
{a for ((a, b), kw) in inspect.call_args_list} {a for ((a, b), kw) in inspect.call_args_list}
) )
def test_dprint():
r1, r2 = MyVariable(1), MyVariable(2)
o1 = MyOp(r1, r2)
assert o1.dprint(file="str") == debugprint(o1, file="str")
assert o1.owner.dprint(file="str") == debugprint(o1.owner, file="str")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论