提交 724561b2 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make aesara.printing.debugprint handle arbitrary HasInnerGraph Ops

上级 e3886d2b
...@@ -21,7 +21,7 @@ from aesara.compile.profiling import ProfileStats ...@@ -21,7 +21,7 @@ from aesara.compile.profiling import ProfileStats
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable, graph_inputs, io_toposort from aesara.graph.basic import Apply, Constant, Variable, graph_inputs, io_toposort
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op, StorageMapType from aesara.graph.op import HasInnerGraph, Op, StorageMapType
from aesara.graph.utils import Scratchpad from aesara.graph.utils import Scratchpad
...@@ -149,29 +149,32 @@ def debugprint( ...@@ -149,29 +149,32 @@ def debugprint(
A string representing the printed graph, if `file` is a string, else `file`. A string representing the printed graph, if `file` is a string, else `file`.
""" """
from aesara.scan.op import Scan
if not isinstance(depth, int): if not isinstance(depth, int):
raise Exception("depth parameter must be an int") raise Exception("depth parameter must be an int")
if file == "str": if file == "str":
_file = StringIO() _file = StringIO()
elif file is None: elif file is None:
_file = sys.stdout _file = sys.stdout
else: else:
_file = file _file = file
if done is None: if done is None:
done = dict() done = dict()
if used_ids is None: if used_ids is None:
used_ids = dict() used_ids = dict()
used_ids = dict()
results_to_print = [] results_to_print = []
profile_list = [] profile_list = []
order = [] # Toposort order = [] # Toposort
smap = [] # storage_map smap = [] # storage_map
if isinstance(obj, (list, tuple, set)): if isinstance(obj, (list, tuple, set)):
lobj = obj lobj = obj
else: else:
lobj = [obj] lobj = [obj]
for obj in lobj: for obj in lobj:
if isinstance(obj, Variable): if isinstance(obj, Variable):
results_to_print.append(obj) results_to_print.append(obj)
...@@ -206,9 +209,9 @@ def debugprint( ...@@ -206,9 +209,9 @@ def debugprint(
smap.append(None) smap.append(None)
order.append(None) order.append(None)
else: else:
raise TypeError("debugprint cannot print an object of this type", obj) raise TypeError(f"debugprint cannot print an object type {type(obj)}")
scan_ops = [] inner_graph_ops = []
if any([p for p in profile_list if p is not None and p.fct_callcount > 0]): if any([p for p in profile_list if p is not None and p.fct_callcount > 0]):
print( print(
""" """
...@@ -232,9 +235,8 @@ N.B.: ...@@ -232,9 +235,8 @@ N.B.:
) )
for r, p, s, o in zip(results_to_print, profile_list, smap, order): for r, p, s, o in zip(results_to_print, profile_list, smap, order):
# Add the parent scan op to the list as well if hasattr(r.owner, "op") and isinstance(r.owner.op, HasInnerGraph):
if hasattr(r.owner, "op") and isinstance(r.owner.op, Scan): inner_graph_ops.append(r)
scan_ops.append(r)
_debugprint( _debugprint(
r, r,
...@@ -244,36 +246,48 @@ N.B.: ...@@ -244,36 +246,48 @@ N.B.:
file=_file, file=_file,
order=o, order=o,
ids=ids, ids=ids,
scan_ops=scan_ops, inner_graph_ops=inner_graph_ops,
stop_on_name=stop_on_name, stop_on_name=stop_on_name,
profile=p, profile=p,
smap=s, smap=s,
used_ids=used_ids, used_ids=used_ids,
) )
if len(scan_ops) > 0: if len(inner_graph_ops) > 0:
print("", file=_file) print("", file=_file)
new_prefix = " >" new_prefix = " >"
new_prefix_child = " >" new_prefix_child = " >"
print("Inner graphs of the scan ops:", file=_file) print("Inner graphs:", file=_file)
for s in inner_graph_ops:
for s in scan_ops: # This is a work-around to maintain backward compatibility
# prepare a dict which maps the scan op's inner inputs # (e.g. to only print inner graphs that have been compiled through
# to its outer inputs. # a call to `Op.prepare_node`)
if hasattr(s.owner.op, "_fn"): inner_fn = getattr(s.owner.op, "_fn", None)
if inner_fn:
# If the op was compiled, print the optimized version. # If the op was compiled, print the optimized version.
inner_inputs = s.owner.op.fn.maker.fgraph.inputs inner_inputs = inner_fn.maker.fgraph.inputs
inner_outputs = inner_fn.maker.fgraph.outputs
else: else:
inner_inputs = s.owner.op.inputs inner_inputs = s.owner.op.inner_inputs
inner_outputs = s.owner.op.inner_outputs
outer_inputs = s.owner.inputs outer_inputs = s.owner.inputs
inner_to_outer_inputs = {
inner_inputs[i]: outer_inputs[o] if hasattr(s.owner.op, "get_oinp_iinp_iout_oout_mappings"):
for i, o in s.owner.op.get_oinp_iinp_iout_oout_mappings()[ inner_to_outer_inputs = {
"outer_inp_from_inner_inp" inner_inputs[i]: outer_inputs[o]
].items() for i, o in s.owner.op.get_oinp_iinp_iout_oout_mappings()[
} "outer_inp_from_inner_inp"
].items()
}
else:
inner_to_outer_inputs = None
print("", file=_file) print("", file=_file)
_debugprint( _debugprint(
s, s,
depth=depth, depth=depth,
...@@ -281,21 +295,16 @@ N.B.: ...@@ -281,21 +295,16 @@ N.B.:
print_type=print_type, print_type=print_type,
file=_file, file=_file,
ids=ids, ids=ids,
scan_ops=scan_ops, inner_graph_ops=inner_graph_ops,
stop_on_name=stop_on_name, stop_on_name=stop_on_name,
scan_inner_to_outer_inputs=inner_to_outer_inputs, inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids, used_ids=used_ids,
) )
if hasattr(s.owner.op, "_fn"):
# If the op was compiled, print the optimized version.
outputs = s.owner.op.fn.maker.fgraph.outputs
else:
outputs = s.owner.op.outputs
for idx, i in enumerate(outputs):
if hasattr(i, "owner") and hasattr(i.owner, "op"): for idx, i in enumerate(inner_outputs):
if isinstance(i.owner.op, Scan):
scan_ops.append(i) if isinstance(getattr(i.owner, "op", None), HasInnerGraph):
inner_graph_ops.append(i)
_debugprint( _debugprint(
r=i, r=i,
...@@ -307,8 +316,8 @@ N.B.: ...@@ -307,8 +316,8 @@ N.B.:
ids=ids, ids=ids,
stop_on_name=stop_on_name, stop_on_name=stop_on_name,
prefix_child=new_prefix_child, prefix_child=new_prefix_child,
scan_ops=scan_ops, inner_graph_ops=inner_graph_ops,
scan_inner_to_outer_inputs=inner_to_outer_inputs, inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids, used_ids=used_ids,
) )
...@@ -333,9 +342,9 @@ def _debugprint( ...@@ -333,9 +342,9 @@ def _debugprint(
ids: str = "CHAR", ids: str = "CHAR",
stop_on_name: bool = False, stop_on_name: bool = False,
prefix_child: Optional[str] = None, prefix_child: Optional[str] = None,
scan_ops: Optional[List[Variable]] = None, inner_graph_ops: Optional[List[Variable]] = None,
profile: Optional[ProfileStats] = None, profile: Optional[ProfileStats] = None,
scan_inner_to_outer_inputs: Optional[Dict[Variable, Variable]] = None, inner_to_outer_inputs: Optional[Dict[Variable, Variable]] = None,
smap: Optional[StorageMapType] = None, smap: Optional[StorageMapType] = None,
used_ids: Optional[Dict[Variable, str]] = None, used_ids: Optional[Dict[Variable, str]] = None,
) -> IOBase: ) -> IOBase:
...@@ -373,11 +382,10 @@ def _debugprint( ...@@ -373,11 +382,10 @@ def _debugprint(
stop_on_name stop_on_name
When ``True``, if a node in the graph has a name, we don't print anything When ``True``, if a node in the graph has a name, we don't print anything
below it. below it.
scan_ops inner_graph_ops
`Scan` `Op`\s in the graph will be added inside this list for later A list of `Op`\s with inner graphs.
printing purposes. inner_to_outer_inputs
scan_inner_to_outer_inputs A dictionary mapping an `Op`'s inner-inputs to its outer-inputs.
A dictionary mapping a `Scan` `Op`'s inner-inputs to its outer-inputs.
smap smap
``None`` or the ``storage_map`` when printing an Aesara function. ``None`` or the ``storage_map`` when printing an Aesara function.
used_ids used_ids
...@@ -394,8 +402,8 @@ def _debugprint( ...@@ -394,8 +402,8 @@ def _debugprint(
if done is None: if done is None:
done = dict() done = dict()
if scan_ops is None: if inner_graph_ops is None:
scan_ops = [] inner_graph_ops = []
if print_type: if print_type:
type_str = f" <{r.type}>" type_str = f" <{r.type}>"
...@@ -518,10 +526,8 @@ def _debugprint( ...@@ -518,10 +526,8 @@ def _debugprint(
new_prefix_child = prefix_child + " " new_prefix_child = prefix_child + " "
if hasattr(i, "owner") and hasattr(i.owner, "op"): if hasattr(i, "owner") and hasattr(i.owner, "op"):
from aesara.scan.op import Scan if isinstance(i.owner.op, HasInnerGraph):
inner_graph_ops.append(i)
if isinstance(i.owner.op, Scan):
scan_ops.append(i)
_debugprint( _debugprint(
i, i,
...@@ -534,17 +540,17 @@ def _debugprint( ...@@ -534,17 +540,17 @@ def _debugprint(
ids=ids, ids=ids,
stop_on_name=stop_on_name, stop_on_name=stop_on_name,
prefix_child=new_prefix_child, prefix_child=new_prefix_child,
scan_ops=scan_ops, inner_graph_ops=inner_graph_ops,
profile=profile, profile=profile,
scan_inner_to_outer_inputs=scan_inner_to_outer_inputs, inner_to_outer_inputs=inner_to_outer_inputs,
smap=smap, smap=smap,
used_ids=used_ids, used_ids=used_ids,
) )
else: else:
if scan_inner_to_outer_inputs is not None and r in scan_inner_to_outer_inputs: if inner_to_outer_inputs is not None and r in inner_to_outer_inputs:
id_str = get_id_str(r) id_str = get_id_str(r)
outer_r = scan_inner_to_outer_inputs[r] outer_r = inner_to_outer_inputs[r]
if hasattr(outer_r.owner, "op"): if hasattr(outer_r.owner, "op"):
outer_id_str = get_id_str(outer_r.owner) outer_id_str = get_id_str(outer_r.owner)
...@@ -665,8 +671,8 @@ class OperatorPrinter: ...@@ -665,8 +671,8 @@ class OperatorPrinter:
node = output.owner node = output.owner
if node is None: if node is None:
raise TypeError( raise TypeError(
"operator %s cannot represent a variable that is " f"operator {self.operator} cannot represent a variable that is "
"not the result of an operation" % self.operator "not the result of an operation"
) )
# Precedence seems to be buggy, see #249 # Precedence seems to be buggy, see #249
...@@ -720,8 +726,8 @@ class PatternPrinter: ...@@ -720,8 +726,8 @@ class PatternPrinter:
node = output.owner node = output.owner
if node is None: if node is None:
raise TypeError( raise TypeError(
"Patterns %s cannot represent a variable that is " f"Patterns {self.patterns} cannot represent a variable that is "
"not the result of an operation" % self.patterns "not the result of an operation"
) )
idx = node.outputs.index(output) idx = node.outputs.index(output)
pattern, precedences = self.patterns[idx] pattern, precedences = self.patterns[idx]
...@@ -760,8 +766,8 @@ class FunctionPrinter: ...@@ -760,8 +766,8 @@ class FunctionPrinter:
node = output.owner node = output.owner
if node is None: if node is None:
raise TypeError( raise TypeError(
"function %s cannot represent a variable that is " f"function {self.names} cannot represent a variable that is "
"not the result of an operation" % self.names "not the result of an operation"
) )
idx = node.outputs.index(output) idx = node.outputs.index(output)
name = self.names[idx] name = self.names[idx]
...@@ -1107,10 +1113,9 @@ def pydotprint( ...@@ -1107,10 +1113,9 @@ def pydotprint(
fgraph = fct fgraph = fct
if not pydot_imported: if not pydot_imported:
raise RuntimeError( raise RuntimeError(
"Failed to import pydot. You must install graphviz" "Failed to import pydot. You must install graphviz "
" and either pydot or pydot-ng for " "and either pydot or pydot-ng for "
"`pydotprint` to work.", f"`pydotprint` to work:\n {pydot_imported_msg}",
pydot_imported_msg,
) )
g = pd.Dot() g = pd.Dot()
......
...@@ -11,6 +11,7 @@ from aesara.gradient import DisconnectedType, Rop, disconnected_type, grad ...@@ -11,6 +11,7 @@ from aesara.gradient import DisconnectedType, Rop, disconnected_type, grad
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType from aesara.graph.null_type import NullType
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import optimize_graph
from aesara.printing import debugprint
from aesara.tensor.basic import as_tensor from aesara.tensor.basic import as_tensor
from aesara.tensor.basic_opt import ShapeOptimizer from aesara.tensor.basic_opt import ShapeOptimizer
from aesara.tensor.math import dot, exp from aesara.tensor.math import dot, exp
...@@ -433,3 +434,31 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -433,3 +434,31 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
f = op(y) f = op(y)
grad_f = grad(f, y) grad_f = grad(f, y)
assert grad_f.tag.test_value is not None assert grad_f.tag.test_value is not None
def test_debugprint():
x, y, z = matrices("xyz")
e = x + y * z
op = OpFromGraph([x, y, z], [e])
out = op(x, y, z)
output_str = debugprint(out, file="str")
lines = output_str.split("\n")
exp_res = """OpFromGraph{inline=False} [id A] ''
|x [id B]
|y [id C]
|z [id D]
Inner graphs:
OpFromGraph{inline=False} [id A] ''
>Elemwise{add,no_inplace} [id E] ''
> |x [id F]
> |Elemwise{mul,no_inplace} [id G] ''
> |y [id H]
> |z [id I]
"""
for truth, out in zip(exp_res.split("\n"), lines):
assert truth.strip() == out.strip()
import numpy as np import numpy as np
from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import Op from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.type import Type from aesara.graph.type import Type
...@@ -108,3 +108,33 @@ op_y = MyOp("OpY", x=1) ...@@ -108,3 +108,33 @@ op_y = MyOp("OpY", x=1)
op_z = MyOp("OpZ", x=1) op_z = MyOp("OpZ", x=1)
op_cast_type2 = MyOpCastType2("OpCastType2") op_cast_type2 = MyOpCastType2("OpCastType2")
class MyInnerGraphOp(Op, HasInnerGraph):
__props__ = ()
def __init__(self, inner_inputs, inner_outputs):
self._inner_inputs = inner_inputs
self._inner_outputs = inner_outputs
def make_node(self, *inputs):
for input in inputs:
assert isinstance(input, Variable)
assert isinstance(input.type, MyType)
outputs = [inputs[0].type()]
return Apply(self, list(inputs), outputs)
def perform(self, *args, **kwargs):
raise NotImplementedError("No Python implementation available.")
@property
def fn(self):
raise NotImplementedError("No Python implementation available.")
@property
def inner_inputs(self):
return self._inner_inputs
@property
def inner_outputs(self):
return self._inner_outputs
...@@ -51,7 +51,7 @@ def test_scan_debugprint1(): ...@@ -51,7 +51,7 @@ def test_scan_debugprint1():
| |ScalarConstant{1} [id U] | |ScalarConstant{1} [id U]
|ScalarConstant{-1} [id V] |ScalarConstant{-1} [id V]
Inner graphs of the scan ops: Inner graphs:
for{cpu,scan_fn} [id C] '' for{cpu,scan_fn} [id C] ''
>Elemwise{mul,no_inplace} [id W] '' >Elemwise{mul,no_inplace} [id W] ''
...@@ -111,7 +111,7 @@ def test_scan_debugprint2(): ...@@ -111,7 +111,7 @@ def test_scan_debugprint2():
|Elemwise{scalar_minimum,no_inplace} [id C] '' |Elemwise{scalar_minimum,no_inplace} [id C] ''
|x [id W] |x [id W]
Inner graphs of the scan ops: Inner graphs:
for{cpu,scan_fn} [id B] '' for{cpu,scan_fn} [id B] ''
>Elemwise{mul,no_inplace} [id X] '' >Elemwise{mul,no_inplace} [id X] ''
...@@ -124,7 +124,6 @@ def test_scan_debugprint2(): ...@@ -124,7 +124,6 @@ def test_scan_debugprint2():
assert truth.strip() == out.strip() assert truth.strip() == out.strip()
@aesara.config.change_flags(optimizer_verbose=True)
def test_scan_debugprint3(): def test_scan_debugprint3():
coefficients = dvector("coefficients") coefficients = dvector("coefficients")
max_coefficients_supported = 10 max_coefficients_supported = 10
...@@ -192,7 +191,7 @@ def test_scan_debugprint3(): ...@@ -192,7 +191,7 @@ def test_scan_debugprint3():
|A [id W] |A [id W]
|k [id X] |k [id X]
Inner graphs of the scan ops: Inner graphs:
for{cpu,scan_fn} [id B] '' for{cpu,scan_fn} [id B] ''
>Elemwise{mul,no_inplace} [id Y] '' >Elemwise{mul,no_inplace} [id Y] ''
...@@ -293,7 +292,7 @@ def test_scan_debugprint4(): ...@@ -293,7 +292,7 @@ def test_scan_debugprint4():
|for{cpu,scan_fn}.1 [id C] '' |for{cpu,scan_fn}.1 [id C] ''
|ScalarConstant{2} [id BA] |ScalarConstant{2} [id BA]
Inner graphs of the scan ops: Inner graphs:
for{cpu,scan_fn}.0 [id C] '' for{cpu,scan_fn}.0 [id C] ''
>Elemwise{add,no_inplace} [id BB] '' >Elemwise{add,no_inplace} [id BB] ''
...@@ -412,7 +411,7 @@ def test_scan_debugprint5(): ...@@ -412,7 +411,7 @@ def test_scan_debugprint5():
| |A [id P] | |A [id P]
|ScalarConstant{-1} [id CL] |ScalarConstant{-1} [id CL]
Inner graphs of the scan ops: Inner graphs:
for{cpu,grad_of_scan_fn}.1 [id B] '' for{cpu,grad_of_scan_fn}.1 [id B] ''
>Elemwise{add,no_inplace} [id CM] '' >Elemwise{add,no_inplace} [id CM] ''
......
...@@ -2015,7 +2015,7 @@ class TestLocalUselessElemwiseComparison: ...@@ -2015,7 +2015,7 @@ class TestLocalUselessElemwiseComparison:
| |Subtensor{int64} [id C] '' | |Subtensor{int64} [id C] ''
|Y [id K] |Y [id K]
Inner graphs of the scan ops: Inner graphs:
for{cpu,scan_fn} [id B] '' for{cpu,scan_fn} [id B] ''
>Sum{acc_dtype=float64} [id L] '' >Sum{acc_dtype=float64} [id L] ''
...@@ -2050,7 +2050,7 @@ class TestLocalUselessElemwiseComparison: ...@@ -2050,7 +2050,7 @@ class TestLocalUselessElemwiseComparison:
| |Shape_i{0} [id C] <TensorType(int64, scalar)> '' 0 | |Shape_i{0} [id C] <TensorType(int64, scalar)> '' 0
|Y [id M] <TensorType(float64, vector)> |Y [id M] <TensorType(float64, vector)>
Inner graphs of the scan ops: Inner graphs:
for{cpu,scan_fn} [id B] <TensorType(float64, vector)> '' for{cpu,scan_fn} [id B] <TensorType(float64, vector)> ''
>Sum{acc_dtype=float64} [id N] <TensorType(float64, scalar)> '' >Sum{acc_dtype=float64} [id N] <TensorType(float64, scalar)> ''
......
...@@ -15,6 +15,7 @@ from aesara.printing import ( ...@@ -15,6 +15,7 @@ from aesara.printing import (
pydotprint, pydotprint,
) )
from aesara.tensor.type import dmatrix, dvector, matrix from aesara.tensor.type import dmatrix, dvector, matrix
from tests.graph.utils import MyInnerGraphOp, MyOp, MyVariable
@pytest.mark.skipif(not pydot_imported, reason="pydot not available") @pytest.mark.skipif(not pydot_imported, reason="pydot not available")
...@@ -109,6 +110,9 @@ def test_min_informative_str(): ...@@ -109,6 +110,9 @@ def test_min_informative_str():
def test_debugprint(): def test_debugprint():
with pytest.raises(TypeError):
debugprint("blah")
A = matrix(name="A") A = matrix(name="A")
B = matrix(name="B") B = matrix(name="B")
C = A + B C = A + B
...@@ -277,3 +281,66 @@ def test_pprint(): ...@@ -277,3 +281,66 @@ def test_pprint():
x = dvector() x = dvector()
y = x[1] y = x[1]
assert pp(y) == "<TensorType(float64, vector)>[ScalarConstant{1}]" assert pp(y) == "<TensorType(float64, vector)>[ScalarConstant{1}]"
def test_debugprint_inner_graph():
r1, r2 = MyVariable("1"), MyVariable("2")
o1 = MyOp("op1")(r1, r2)
o1.name = "o1"
# Inner graph
igo_in_1 = MyVariable("4")
igo_in_2 = MyVariable("5")
igo_out_1 = MyOp("op2")(igo_in_1, igo_in_2)
igo_out_1.name = "igo1"
igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1])
r3, r4 = MyVariable("3"), MyVariable("4")
out = igo(r3, r4)
output_str = debugprint(out, file="str")
lines = output_str.split("\n")
exp_res = """MyInnerGraphOp [id A] ''
|3 [id B]
|4 [id C]
Inner graphs:
MyInnerGraphOp [id A] ''
>op2 [id D] 'igo1'
> |4 [id E]
> |5 [id F]
"""
for exp_line, res_line in zip(exp_res.split("\n"), lines):
assert exp_line.strip() == res_line.strip()
# Test nested inner-graph `Op`s
igo_2 = MyInnerGraphOp([r3, r4], [out])
r5 = MyVariable("5")
out_2 = igo_2(r5)
output_str = debugprint(out_2, file="str")
lines = output_str.split("\n")
exp_res = """MyInnerGraphOp [id A] ''
|5 [id B]
Inner graphs:
MyInnerGraphOp [id A] ''
>MyInnerGraphOp [id C] ''
> |3 [id D]
> |4 [id E]
MyInnerGraphOp [id C] ''
>op2 [id F] 'igo1'
> |4 [id G]
> |5 [id H]
"""
for exp_line, res_line in zip(exp_res.split("\n"), lines):
assert exp_line.strip() == res_line.strip()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论