提交 086323fa authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Add some types to printing.py

上级 1a92165c
...@@ -1200,18 +1200,18 @@ default_colorCodes = { ...@@ -1200,18 +1200,18 @@ default_colorCodes = {
def pydotprint( def pydotprint(
fct, fct,
outfile=None, outfile: str | None = None,
compact=True, compact: bool = True,
format="png", format: str = "png",
with_ids=False, with_ids: bool = False,
high_contrast=True, high_contrast: bool = True,
cond_highlight=None, cond_highlight=None,
colorCodes=None, colorCodes: dict | None = None,
max_label_size=70, max_label_size: int = 70,
scan_graphs=False, scan_graphs: bool = False,
var_with_name_simple=False, var_with_name_simple: bool = False,
print_output_file=True, print_output_file: bool = True,
return_image=False, return_image: bool = False,
): ):
"""Print to a file the graph of a compiled pytensor function's ops. Supports """Print to a file the graph of a compiled pytensor function's ops. Supports
all pydot output formats, including png and svg. all pydot output formats, including png and svg.
...@@ -1676,7 +1676,9 @@ class _TagGenerator: ...@@ -1676,7 +1676,9 @@ class _TagGenerator:
return rval return rval
def min_informative_str(obj, indent_level=0, _prev_obs=None, _tag_generator=None): def min_informative_str(
obj, indent_level: int = 0, _prev_obs: dict | None = None, _tag_generator=None
) -> str:
""" """
Returns a string specifying to the user what obj is Returns a string specifying to the user what obj is
The string will print out as much of the graph as is needed The string will print out as much of the graph as is needed
...@@ -1776,7 +1778,7 @@ def min_informative_str(obj, indent_level=0, _prev_obs=None, _tag_generator=None ...@@ -1776,7 +1778,7 @@ def min_informative_str(obj, indent_level=0, _prev_obs=None, _tag_generator=None
return rval return rval
def var_descriptor(obj, _prev_obs=None, _tag_generator=None): def var_descriptor(obj, _prev_obs: dict | None = None, _tag_generator=None) -> str:
""" """
Returns a string, with no endlines, fully specifying Returns a string, with no endlines, fully specifying
how a variable is computed. Does not include any memory how a variable is computed. Does not include any memory
...@@ -1832,7 +1834,7 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None): ...@@ -1832,7 +1834,7 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
return rval return rval
def position_independent_str(obj): def position_independent_str(obj) -> str:
if isinstance(obj, Variable): if isinstance(obj, Variable):
rval = "pytensor_var" rval = "pytensor_var"
rval += "{type=" + str(obj.type) + "}" rval += "{type=" + str(obj.type) + "}"
...@@ -1842,7 +1844,7 @@ def position_independent_str(obj): ...@@ -1842,7 +1844,7 @@ def position_independent_str(obj):
return rval return rval
def hex_digest(x): def hex_digest(x: np.ndarray) -> str:
""" """
Returns a short, mostly hexadecimal hash of a numpy ndarray Returns a short, mostly hexadecimal hash of a numpy ndarray
""" """
...@@ -1852,8 +1854,8 @@ def hex_digest(x): ...@@ -1852,8 +1854,8 @@ def hex_digest(x):
# because the buffer interface only exposes the raw data, not # because the buffer interface only exposes the raw data, not
# any info about the semantics of how that data should be arranged # any info about the semantics of how that data should be arranged
# into a tensor # into a tensor
rval = rval + "|strides=[" + ",".join(str(stride) for stride in x.strides) + "]" rval += "|strides=[" + ",".join(str(stride) for stride in x.strides) + "]"
rval = rval + "|shape=[" + ",".join(str(s) for s in x.shape) + "]" rval += "|shape=[" + ",".join(str(s) for s in x.shape) + "]"
return rval return rval
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论