Commit 5521d824, authored by Ricardo Vieira, committed by Ricardo Vieira

Render inner-graphs of Composite Ops in debugprint

Parent: 79d98f1e
......@@ -312,7 +312,11 @@ N.B.:
):
if hasattr(var.owner, "op"):
if isinstance(var.owner.op, HasInnerGraph) and var not in inner_graph_vars:
if (
isinstance(var.owner.op, HasInnerGraph)
or hasattr(var.owner.op, "scalar_op")
and isinstance(var.owner.op.scalar_op, HasInnerGraph)
) and var not in inner_graph_vars:
inner_graph_vars.append(var)
if print_op_info:
op_information.update(op_debug_information(var.owner.op, var.owner))
......@@ -354,6 +358,10 @@ N.B.:
# If the op was compiled, print the optimized version.
inner_inputs = inner_fn.maker.fgraph.inputs
inner_outputs = inner_fn.maker.fgraph.outputs
else:
if hasattr(ig_var.owner.op, "scalar_op"):
inner_inputs = ig_var.owner.op.scalar_op.inner_inputs
inner_outputs = ig_var.owner.op.scalar_op.inner_outputs
else:
inner_inputs = ig_var.owner.op.inner_inputs
inner_outputs = ig_var.owner.op.inner_outputs
......@@ -422,8 +430,9 @@ N.B.:
if (
isinstance(getattr(out.owner, "op", None), HasInnerGraph)
and out not in inner_graph_vars
):
or hasattr(getattr(out.owner, "op", None), "scalar_op")
and isinstance(out.owner.op.scalar_op, HasInnerGraph)
) and out not in inner_graph_vars:
inner_graph_vars.append(out)
_debugprint(
......@@ -664,8 +673,9 @@ def _debugprint(
if hasattr(in_var, "owner") and hasattr(in_var.owner, "op"):
if (
isinstance(in_var.owner.op, HasInnerGraph)
and in_var not in inner_graph_ops
):
or hasattr(in_var.owner.op, "scalar_op")
and isinstance(in_var.owner.op.scalar_op, HasInnerGraph)
) and in_var not in inner_graph_ops:
inner_graph_ops.append(in_var)
_debugprint(
......
......@@ -4000,7 +4000,8 @@ class Composite(ScalarOp, HasInnerGraph):
init_param: Tuple[str, ...] = ("inputs", "outputs")
def __init__(self, inputs, outputs):
def __init__(self, inputs, outputs, name="Composite"):
self.name = name
# We need to clone the graph as sometimes its nodes already
# contain a reference to an fgraph. As we want the Composite
# to be pickable, we can't have reference to fgraph.
......@@ -4106,30 +4107,6 @@ class Composite(ScalarOp, HasInnerGraph):
self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert)
return self._py_perform_fn
@property
def name(self):
if hasattr(self, "_name"):
return self._name
# TODO FIXME: Just implement pretty printing for the `Op`; don't do
# this redundant, outside work in the `Op` itself.
for i, r in enumerate(self.fgraph.inputs):
r.name = f"i{int(i)}"
for i, r in enumerate(self.fgraph.outputs):
r.name = f"o{int(i)}"
io = set(self.fgraph.inputs + self.fgraph.outputs)
for i, r in enumerate(self.fgraph.variables):
if r not in io and len(self.fgraph.clients[r]) > 1:
r.name = f"t{int(i)}"
outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs])
rval = f"Composite{{{outputs_str}}}"
self._name = rval
return self._name
@name.setter
def name(self, name):
self._name = name
@property
def fgraph(self):
if hasattr(self, "_fgraph"):
......
......@@ -183,12 +183,7 @@ class TestComposite:
make_function(DualLinker().accept(g))
assert str(g) == (
"FunctionGraph(*1 -> Composite{((i0 + i1) + i2),"
" (i0 + (i1 * i2)), (i0 / i1), "
"(i0 // 5), "
"(-i0), (i0 - i1), ((i0 ** i1) + (-i2)),"
" (i0 % 3)}(x, y, z), "
"*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
"FunctionGraph(*1 -> Composite(x, y, z), *1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
)
def test_non_scalar_error(self):
......
......@@ -616,7 +616,7 @@ def test_debugprint_compiled_fn():
Inner graphs:
forall_inplace,cpu,scan_fn} [id A] (outer_out_sit_sot-0)
>Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I] (inner_out_sit_sot-0)
>Elemwise{Composite} [id I] (inner_out_sit_sot-0)
> |TensorConstant{0} [id J]
> |Subtensor{int64, int64, uint8} [id K]
> | |*2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
......@@ -626,9 +626,18 @@ def test_debugprint_compiled_fn():
> | | |*1-<TensorType(int64, ())> [id P] -> [id D] (inner_in_sit_sot-0)
> | |ScalarConstant{0} [id Q]
> |TensorConstant{1} [id R]
Elemwise{Composite} [id I]
>Switch [id S]
> |LT [id T]
> | |<int64> [id U]
> | |<float64> [id V]
> |<int64> [id W]
> |<int64> [id U]
"""
output_str = debugprint(out, file="str", print_op_info=True)
print(output_str)
lines = output_str.split("\n")
for truth, out in zip(expected_output.split("\n"), lines):
......
......@@ -16,7 +16,7 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.printing import pprint
from pytensor.printing import debugprint, pprint
from pytensor.raise_op import Assert, CheckAndRaise
from pytensor.tensor.basic import (
Alloc,
......@@ -1105,7 +1105,7 @@ class TestLocalMergeSwitchSameCond:
s2 = at.switch(c, x, y)
g = rewrite(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count("Switch") == 1
assert debugprint(g, file="str").count("Switch") == 1
@pytest.mark.parametrize(
"op",
......@@ -1122,7 +1122,7 @@ class TestLocalMergeSwitchSameCond:
s1 = at.switch(c, a, b)
s2 = at.switch(c, x, y)
g = rewrite(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count("Switch") == 1
assert debugprint(g, file="str").count("Switch") == 1
@pytest.mark.parametrize("op", [add, mul])
def test_elemwise_multi_inputs(self, op):
......@@ -1134,7 +1134,7 @@ class TestLocalMergeSwitchSameCond:
u, v = matrices("uv")
s3 = at.switch(c, u, v)
g = rewrite(FunctionGraph(mats + [u, v], [op(s1, s2, s3)]))
assert str(g).count("Switch") == 1
assert debugprint(g, file="str").count("Switch") == 1
class TestLocalOptAlloc:
......
......@@ -28,6 +28,7 @@ from pytensor.graph.rewriting.basic import (
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import debugprint
from pytensor.tensor import inplace
from pytensor.tensor.basic import Alloc, join, switch
from pytensor.tensor.blas import Dot22, Gemv
......@@ -2416,7 +2417,7 @@ class TestLocalMergeSwitchSameCond:
at_pow,
):
g = rewrite(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count("Switch") == 1
assert debugprint(g, file="str").count("Switch") == 1
# integer Ops
mats = imatrices("cabxy")
c, a, b, x, y = mats
......@@ -2428,13 +2429,13 @@ class TestLocalMergeSwitchSameCond:
bitwise_xor,
):
g = rewrite(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count("Switch") == 1
assert debugprint(g, file="str").count("Switch") == 1
# add/mul with more than two inputs
u, v = matrices("uv")
s3 = at.switch(c, u, v)
for op in (add, mul):
g = rewrite(FunctionGraph(mats + [u, v], [op(s1, s2, s3)]))
assert str(g).count("Switch") == 1
assert debugprint(g, file="str").count("Switch") == 1
class TestLocalSumProd:
......
......@@ -273,7 +273,7 @@ def test_debugprint():
s = s.getvalue()
exp_res = dedent(
r"""
Elemwise{Composite{(i2 + (i0 - i1))}} 4
Elemwise{Composite} 4
|InplaceDimShuffle{x,0} v={0: [0]} 3
| |CGemv{inplace} d={0: [0]} 2
| |AllocEmpty{dtype='float64'} 1
......@@ -285,6 +285,15 @@ def test_debugprint():
| |TensorConstant{0.0}
|D
|A
Inner graphs:
Elemwise{Composite}
>add
> |<float64>
> |sub
> |<float64>
> |<float64>
"""
).lstrip()
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment