提交 086323fa authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Add some types to printing.py

上级 1a92165c
......@@ -1200,18 +1200,18 @@ default_colorCodes = {
def pydotprint(
fct,
outfile=None,
compact=True,
format="png",
with_ids=False,
high_contrast=True,
outfile: str | None = None,
compact: bool = True,
format: str = "png",
with_ids: bool = False,
high_contrast: bool = True,
cond_highlight=None,
colorCodes=None,
max_label_size=70,
scan_graphs=False,
var_with_name_simple=False,
print_output_file=True,
return_image=False,
colorCodes: dict | None = None,
max_label_size: int = 70,
scan_graphs: bool = False,
var_with_name_simple: bool = False,
print_output_file: bool = True,
return_image: bool = False,
):
"""Print to a file the graph of a compiled pytensor function's ops. Supports
all pydot output formats, including png and svg.
......@@ -1676,7 +1676,9 @@ class _TagGenerator:
return rval
def min_informative_str(obj, indent_level=0, _prev_obs=None, _tag_generator=None):
def min_informative_str(
obj, indent_level: int = 0, _prev_obs: dict | None = None, _tag_generator=None
) -> str:
"""
Returns a string specifying to the user what obj is
The string will print out as much of the graph as is needed
......@@ -1776,7 +1778,7 @@ def min_informative_str(obj, indent_level=0, _prev_obs=None, _tag_generator=None
return rval
def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
def var_descriptor(obj, _prev_obs: dict | None = None, _tag_generator=None) -> str:
"""
Returns a string, with no endlines, fully specifying
how a variable is computed. Does not include any memory
......@@ -1832,7 +1834,7 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
return rval
def position_independent_str(obj):
def position_independent_str(obj) -> str:
if isinstance(obj, Variable):
rval = "pytensor_var"
rval += "{type=" + str(obj.type) + "}"
......@@ -1842,7 +1844,7 @@ def position_independent_str(obj):
return rval
def hex_digest(x):
def hex_digest(x: np.ndarray) -> str:
"""
Returns a short, mostly hexadecimal hash of a numpy ndarray
"""
......@@ -1852,8 +1854,8 @@ def hex_digest(x):
# because the buffer interface only exposes the raw data, not
# any info about the semantics of how that data should be arranged
# into a tensor
rval = rval + "|strides=[" + ",".join(str(stride) for stride in x.strides) + "]"
rval = rval + "|shape=[" + ",".join(str(s) for s in x.shape) + "]"
rval += "|strides=[" + ",".join(str(stride) for stride in x.strides) + "]"
rval += "|shape=[" + ",".join(str(s) for s in x.shape) + "]"
return rval
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论