Commit 724561b2, authored by Brandon T. Willard and committed by Brandon T. Willard

Make aesara.printing.debugprint handle arbitrary HasInnerGraph Ops

上级 e3886d2b
......@@ -21,7 +21,7 @@ from aesara.compile.profiling import ProfileStats
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable, graph_inputs, io_toposort
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op, StorageMapType
from aesara.graph.op import HasInnerGraph, Op, StorageMapType
from aesara.graph.utils import Scratchpad
......@@ -149,29 +149,32 @@ def debugprint(
A string representing the printed graph, if `file` is a string, else `file`.
"""
from aesara.scan.op import Scan
if not isinstance(depth, int):
raise Exception("depth parameter must be an int")
if file == "str":
_file = StringIO()
elif file is None:
_file = sys.stdout
else:
_file = file
if done is None:
done = dict()
if used_ids is None:
used_ids = dict()
used_ids = dict()
results_to_print = []
profile_list = []
order = [] # Toposort
smap = [] # storage_map
if isinstance(obj, (list, tuple, set)):
lobj = obj
else:
lobj = [obj]
for obj in lobj:
if isinstance(obj, Variable):
results_to_print.append(obj)
......@@ -206,9 +209,9 @@ def debugprint(
smap.append(None)
order.append(None)
else:
raise TypeError("debugprint cannot print an object of this type", obj)
raise TypeError(f"debugprint cannot print an object type {type(obj)}")
scan_ops = []
inner_graph_ops = []
if any([p for p in profile_list if p is not None and p.fct_callcount > 0]):
print(
"""
......@@ -232,9 +235,8 @@ N.B.:
)
for r, p, s, o in zip(results_to_print, profile_list, smap, order):
# Add the parent scan op to the list as well
if hasattr(r.owner, "op") and isinstance(r.owner.op, Scan):
scan_ops.append(r)
if hasattr(r.owner, "op") and isinstance(r.owner.op, HasInnerGraph):
inner_graph_ops.append(r)
_debugprint(
r,
......@@ -244,36 +246,48 @@ N.B.:
file=_file,
order=o,
ids=ids,
scan_ops=scan_ops,
inner_graph_ops=inner_graph_ops,
stop_on_name=stop_on_name,
profile=p,
smap=s,
used_ids=used_ids,
)
if len(scan_ops) > 0:
if len(inner_graph_ops) > 0:
print("", file=_file)
new_prefix = " >"
new_prefix_child = " >"
print("Inner graphs of the scan ops:", file=_file)
print("Inner graphs:", file=_file)
for s in inner_graph_ops:
for s in scan_ops:
# prepare a dict which maps the scan op's inner inputs
# to its outer inputs.
if hasattr(s.owner.op, "_fn"):
# This is a work-around to maintain backward compatibility
# (e.g. to only print inner graphs that have been compiled through
# a call to `Op.prepare_node`)
inner_fn = getattr(s.owner.op, "_fn", None)
if inner_fn:
# If the op was compiled, print the optimized version.
inner_inputs = s.owner.op.fn.maker.fgraph.inputs
inner_inputs = inner_fn.maker.fgraph.inputs
inner_outputs = inner_fn.maker.fgraph.outputs
else:
inner_inputs = s.owner.op.inputs
inner_inputs = s.owner.op.inner_inputs
inner_outputs = s.owner.op.inner_outputs
outer_inputs = s.owner.inputs
inner_to_outer_inputs = {
inner_inputs[i]: outer_inputs[o]
for i, o in s.owner.op.get_oinp_iinp_iout_oout_mappings()[
"outer_inp_from_inner_inp"
].items()
}
if hasattr(s.owner.op, "get_oinp_iinp_iout_oout_mappings"):
inner_to_outer_inputs = {
inner_inputs[i]: outer_inputs[o]
for i, o in s.owner.op.get_oinp_iinp_iout_oout_mappings()[
"outer_inp_from_inner_inp"
].items()
}
else:
inner_to_outer_inputs = None
print("", file=_file)
_debugprint(
s,
depth=depth,
......@@ -281,21 +295,16 @@ N.B.:
print_type=print_type,
file=_file,
ids=ids,
scan_ops=scan_ops,
inner_graph_ops=inner_graph_ops,
stop_on_name=stop_on_name,
scan_inner_to_outer_inputs=inner_to_outer_inputs,
inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids,
)
if hasattr(s.owner.op, "_fn"):
# If the op was compiled, print the optimized version.
outputs = s.owner.op.fn.maker.fgraph.outputs
else:
outputs = s.owner.op.outputs
for idx, i in enumerate(outputs):
if hasattr(i, "owner") and hasattr(i.owner, "op"):
if isinstance(i.owner.op, Scan):
scan_ops.append(i)
for idx, i in enumerate(inner_outputs):
if isinstance(getattr(i.owner, "op", None), HasInnerGraph):
inner_graph_ops.append(i)
_debugprint(
r=i,
......@@ -307,8 +316,8 @@ N.B.:
ids=ids,
stop_on_name=stop_on_name,
prefix_child=new_prefix_child,
scan_ops=scan_ops,
scan_inner_to_outer_inputs=inner_to_outer_inputs,
inner_graph_ops=inner_graph_ops,
inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids,
)
......@@ -333,9 +342,9 @@ def _debugprint(
ids: str = "CHAR",
stop_on_name: bool = False,
prefix_child: Optional[str] = None,
scan_ops: Optional[List[Variable]] = None,
inner_graph_ops: Optional[List[Variable]] = None,
profile: Optional[ProfileStats] = None,
scan_inner_to_outer_inputs: Optional[Dict[Variable, Variable]] = None,
inner_to_outer_inputs: Optional[Dict[Variable, Variable]] = None,
smap: Optional[StorageMapType] = None,
used_ids: Optional[Dict[Variable, str]] = None,
) -> IOBase:
......@@ -373,11 +382,10 @@ def _debugprint(
stop_on_name
When ``True``, if a node in the graph has a name, we don't print anything
below it.
scan_ops
`Scan` `Op`\s in the graph will be added inside this list for later
printing purposes.
scan_inner_to_outer_inputs
A dictionary mapping a `Scan` `Op`'s inner-inputs to its outer-inputs.
inner_graph_ops
A list of `Op`\s with inner graphs.
inner_to_outer_inputs
A dictionary mapping an `Op`'s inner-inputs to its outer-inputs.
smap
``None`` or the ``storage_map`` when printing an Aesara function.
used_ids
......@@ -394,8 +402,8 @@ def _debugprint(
if done is None:
done = dict()
if scan_ops is None:
scan_ops = []
if inner_graph_ops is None:
inner_graph_ops = []
if print_type:
type_str = f" <{r.type}>"
......@@ -518,10 +526,8 @@ def _debugprint(
new_prefix_child = prefix_child + " "
if hasattr(i, "owner") and hasattr(i.owner, "op"):
from aesara.scan.op import Scan
if isinstance(i.owner.op, Scan):
scan_ops.append(i)
if isinstance(i.owner.op, HasInnerGraph):
inner_graph_ops.append(i)
_debugprint(
i,
......@@ -534,17 +540,17 @@ def _debugprint(
ids=ids,
stop_on_name=stop_on_name,
prefix_child=new_prefix_child,
scan_ops=scan_ops,
inner_graph_ops=inner_graph_ops,
profile=profile,
scan_inner_to_outer_inputs=scan_inner_to_outer_inputs,
inner_to_outer_inputs=inner_to_outer_inputs,
smap=smap,
used_ids=used_ids,
)
else:
if scan_inner_to_outer_inputs is not None and r in scan_inner_to_outer_inputs:
if inner_to_outer_inputs is not None and r in inner_to_outer_inputs:
id_str = get_id_str(r)
outer_r = scan_inner_to_outer_inputs[r]
outer_r = inner_to_outer_inputs[r]
if hasattr(outer_r.owner, "op"):
outer_id_str = get_id_str(outer_r.owner)
......@@ -665,8 +671,8 @@ class OperatorPrinter:
node = output.owner
if node is None:
raise TypeError(
"operator %s cannot represent a variable that is "
"not the result of an operation" % self.operator
f"operator {self.operator} cannot represent a variable that is "
"not the result of an operation"
)
# Precedence seems to be buggy, see #249
......@@ -720,8 +726,8 @@ class PatternPrinter:
node = output.owner
if node is None:
raise TypeError(
"Patterns %s cannot represent a variable that is "
"not the result of an operation" % self.patterns
f"Patterns {self.patterns} cannot represent a variable that is "
"not the result of an operation"
)
idx = node.outputs.index(output)
pattern, precedences = self.patterns[idx]
......@@ -760,8 +766,8 @@ class FunctionPrinter:
node = output.owner
if node is None:
raise TypeError(
"function %s cannot represent a variable that is "
"not the result of an operation" % self.names
f"function {self.names} cannot represent a variable that is "
"not the result of an operation"
)
idx = node.outputs.index(output)
name = self.names[idx]
......@@ -1107,10 +1113,9 @@ def pydotprint(
fgraph = fct
if not pydot_imported:
raise RuntimeError(
"Failed to import pydot. You must install graphviz"
" and either pydot or pydot-ng for "
"`pydotprint` to work.",
pydot_imported_msg,
"Failed to import pydot. You must install graphviz "
"and either pydot or pydot-ng for "
f"`pydotprint` to work:\n {pydot_imported_msg}",
)
g = pd.Dot()
......
......@@ -11,6 +11,7 @@ from aesara.gradient import DisconnectedType, Rop, disconnected_type, grad
from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType
from aesara.graph.opt_utils import optimize_graph
from aesara.printing import debugprint
from aesara.tensor.basic import as_tensor
from aesara.tensor.basic_opt import ShapeOptimizer
from aesara.tensor.math import dot, exp
......@@ -433,3 +434,31 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
f = op(y)
grad_f = grad(f, y)
assert grad_f.tag.test_value is not None
def test_debugprint():
    """Check that `debugprint` renders an `OpFromGraph` together with its inner graph."""
    x, y, z = matrices("xyz")
    e = x + y * z
    op = OpFromGraph([x, y, z], [e])
    out = op(x, y, z)

    printed = debugprint(out, file="str")

    exp_res = """OpFromGraph{inline=False} [id A] ''
|x [id B]
|y [id C]
|z [id D]
Inner graphs:
OpFromGraph{inline=False} [id A] ''
>Elemwise{add,no_inplace} [id E] ''
> |x [id F]
> |Elemwise{mul,no_inplace} [id G] ''
> |y [id H]
> |z [id I]
"""

    # Compare line by line, ignoring surrounding whitespace (the printer's
    # indentation is not significant for this check).
    expected_lines = exp_res.split("\n")
    actual_lines = printed.split("\n")
    for want, got in zip(expected_lines, actual_lines):
        assert want.strip() == got.strip()
import numpy as np
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import Op
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.type import Type
......@@ -108,3 +108,33 @@ op_y = MyOp("OpY", x=1)
op_z = MyOp("OpZ", x=1)
op_cast_type2 = MyOpCastType2("OpCastType2")
class MyInnerGraphOp(Op, HasInnerGraph):
    """A minimal `Op` with an inner graph, for exercising `HasInnerGraph` handling in tests."""

    __props__ = ()

    def __init__(self, inner_inputs, inner_outputs):
        # The inner graph is stored verbatim; it is exposed read-only via the
        # `inner_inputs`/`inner_outputs` properties below.
        self._inner_inputs = inner_inputs
        self._inner_outputs = inner_outputs

    def make_node(self, *inputs):
        """Construct an `Apply` node; all inputs must be `MyType` `Variable`s."""
        for inp in inputs:
            assert isinstance(inp, Variable)
            assert isinstance(inp.type, MyType)
        # Single output whose type mirrors the first input's type.
        return Apply(self, list(inputs), [inputs[0].type()])

    def perform(self, *args, **kwargs):
        raise NotImplementedError("No Python implementation available.")

    @property
    def fn(self):
        # Deliberately not compiled: code paths that check for a compiled
        # inner function must handle this raising.
        raise NotImplementedError("No Python implementation available.")

    @property
    def inner_inputs(self):
        return self._inner_inputs

    @property
    def inner_outputs(self):
        return self._inner_outputs
......@@ -51,7 +51,7 @@ def test_scan_debugprint1():
| |ScalarConstant{1} [id U]
|ScalarConstant{-1} [id V]
Inner graphs of the scan ops:
Inner graphs:
for{cpu,scan_fn} [id C] ''
>Elemwise{mul,no_inplace} [id W] ''
......@@ -111,7 +111,7 @@ def test_scan_debugprint2():
|Elemwise{scalar_minimum,no_inplace} [id C] ''
|x [id W]
Inner graphs of the scan ops:
Inner graphs:
for{cpu,scan_fn} [id B] ''
>Elemwise{mul,no_inplace} [id X] ''
......@@ -124,7 +124,6 @@ def test_scan_debugprint2():
assert truth.strip() == out.strip()
@aesara.config.change_flags(optimizer_verbose=True)
def test_scan_debugprint3():
coefficients = dvector("coefficients")
max_coefficients_supported = 10
......@@ -192,7 +191,7 @@ def test_scan_debugprint3():
|A [id W]
|k [id X]
Inner graphs of the scan ops:
Inner graphs:
for{cpu,scan_fn} [id B] ''
>Elemwise{mul,no_inplace} [id Y] ''
......@@ -293,7 +292,7 @@ def test_scan_debugprint4():
|for{cpu,scan_fn}.1 [id C] ''
|ScalarConstant{2} [id BA]
Inner graphs of the scan ops:
Inner graphs:
for{cpu,scan_fn}.0 [id C] ''
>Elemwise{add,no_inplace} [id BB] ''
......@@ -412,7 +411,7 @@ def test_scan_debugprint5():
| |A [id P]
|ScalarConstant{-1} [id CL]
Inner graphs of the scan ops:
Inner graphs:
for{cpu,grad_of_scan_fn}.1 [id B] ''
>Elemwise{add,no_inplace} [id CM] ''
......
......@@ -2015,7 +2015,7 @@ class TestLocalUselessElemwiseComparison:
| |Subtensor{int64} [id C] ''
|Y [id K]
Inner graphs of the scan ops:
Inner graphs:
for{cpu,scan_fn} [id B] ''
>Sum{acc_dtype=float64} [id L] ''
......@@ -2050,7 +2050,7 @@ class TestLocalUselessElemwiseComparison:
| |Shape_i{0} [id C] <TensorType(int64, scalar)> '' 0
|Y [id M] <TensorType(float64, vector)>
Inner graphs of the scan ops:
Inner graphs:
for{cpu,scan_fn} [id B] <TensorType(float64, vector)> ''
>Sum{acc_dtype=float64} [id N] <TensorType(float64, scalar)> ''
......
......@@ -15,6 +15,7 @@ from aesara.printing import (
pydotprint,
)
from aesara.tensor.type import dmatrix, dvector, matrix
from tests.graph.utils import MyInnerGraphOp, MyOp, MyVariable
@pytest.mark.skipif(not pydot_imported, reason="pydot not available")
......@@ -109,6 +110,9 @@ def test_min_informative_str():
def test_debugprint():
with pytest.raises(TypeError):
debugprint("blah")
A = matrix(name="A")
B = matrix(name="B")
C = A + B
......@@ -277,3 +281,66 @@ def test_pprint():
x = dvector()
y = x[1]
assert pp(y) == "<TensorType(float64, vector)>[ScalarConstant{1}]"
def test_debugprint_inner_graph():
    """Check that `debugprint` recursively prints (nested) `HasInnerGraph` `Op`s."""

    def assert_lines_match(expected, printed):
        # Line-wise comparison ignoring surrounding whitespace.
        for want, got in zip(expected.split("\n"), printed.split("\n")):
            assert want.strip() == got.strip()

    r1, r2 = MyVariable("1"), MyVariable("2")
    o1 = MyOp("op1")(r1, r2)
    o1.name = "o1"

    # Inner graph
    igo_in_1 = MyVariable("4")
    igo_in_2 = MyVariable("5")
    igo_out_1 = MyOp("op2")(igo_in_1, igo_in_2)
    igo_out_1.name = "igo1"

    igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1])

    r3, r4 = MyVariable("3"), MyVariable("4")
    out = igo(r3, r4)

    exp_res = """MyInnerGraphOp [id A] ''
|3 [id B]
|4 [id C]
Inner graphs:
MyInnerGraphOp [id A] ''
>op2 [id D] 'igo1'
> |4 [id E]
> |5 [id F]
"""

    assert_lines_match(exp_res, debugprint(out, file="str"))

    # Test nested inner-graph `Op`s
    igo_2 = MyInnerGraphOp([r3, r4], [out])

    r5 = MyVariable("5")
    out_2 = igo_2(r5)

    exp_res = """MyInnerGraphOp [id A] ''
|5 [id B]
Inner graphs:
MyInnerGraphOp [id A] ''
>MyInnerGraphOp [id C] ''
> |3 [id D]
> |4 [id E]
MyInnerGraphOp [id C] ''
>op2 [id F] 'igo1'
> |4 [id G]
> |5 [id H]
"""

    assert_lines_match(exp_res, debugprint(out_2, file="str"))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论