提交 5521d824 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Render inner-graphs of Composite Ops in debugprint

上级 79d98f1e
...@@ -312,7 +312,11 @@ N.B.: ...@@ -312,7 +312,11 @@ N.B.:
): ):
if hasattr(var.owner, "op"): if hasattr(var.owner, "op"):
if isinstance(var.owner.op, HasInnerGraph) and var not in inner_graph_vars: if (
isinstance(var.owner.op, HasInnerGraph)
or hasattr(var.owner.op, "scalar_op")
and isinstance(var.owner.op.scalar_op, HasInnerGraph)
) and var not in inner_graph_vars:
inner_graph_vars.append(var) inner_graph_vars.append(var)
if print_op_info: if print_op_info:
op_information.update(op_debug_information(var.owner.op, var.owner)) op_information.update(op_debug_information(var.owner.op, var.owner))
...@@ -355,8 +359,12 @@ N.B.: ...@@ -355,8 +359,12 @@ N.B.:
inner_inputs = inner_fn.maker.fgraph.inputs inner_inputs = inner_fn.maker.fgraph.inputs
inner_outputs = inner_fn.maker.fgraph.outputs inner_outputs = inner_fn.maker.fgraph.outputs
else: else:
inner_inputs = ig_var.owner.op.inner_inputs if hasattr(ig_var.owner.op, "scalar_op"):
inner_outputs = ig_var.owner.op.inner_outputs inner_inputs = ig_var.owner.op.scalar_op.inner_inputs
inner_outputs = ig_var.owner.op.scalar_op.inner_outputs
else:
inner_inputs = ig_var.owner.op.inner_inputs
inner_outputs = ig_var.owner.op.inner_outputs
outer_inputs = ig_var.owner.inputs outer_inputs = ig_var.owner.inputs
...@@ -422,8 +430,9 @@ N.B.: ...@@ -422,8 +430,9 @@ N.B.:
if ( if (
isinstance(getattr(out.owner, "op", None), HasInnerGraph) isinstance(getattr(out.owner, "op", None), HasInnerGraph)
and out not in inner_graph_vars or hasattr(getattr(out.owner, "op", None), "scalar_op")
): and isinstance(out.owner.op.scalar_op, HasInnerGraph)
) and out not in inner_graph_vars:
inner_graph_vars.append(out) inner_graph_vars.append(out)
_debugprint( _debugprint(
...@@ -664,8 +673,9 @@ def _debugprint( ...@@ -664,8 +673,9 @@ def _debugprint(
if hasattr(in_var, "owner") and hasattr(in_var.owner, "op"): if hasattr(in_var, "owner") and hasattr(in_var.owner, "op"):
if ( if (
isinstance(in_var.owner.op, HasInnerGraph) isinstance(in_var.owner.op, HasInnerGraph)
and in_var not in inner_graph_ops or hasattr(in_var.owner.op, "scalar_op")
): and isinstance(in_var.owner.op.scalar_op, HasInnerGraph)
) and in_var not in inner_graph_ops:
inner_graph_ops.append(in_var) inner_graph_ops.append(in_var)
_debugprint( _debugprint(
......
...@@ -4000,7 +4000,8 @@ class Composite(ScalarOp, HasInnerGraph): ...@@ -4000,7 +4000,8 @@ class Composite(ScalarOp, HasInnerGraph):
init_param: Tuple[str, ...] = ("inputs", "outputs") init_param: Tuple[str, ...] = ("inputs", "outputs")
def __init__(self, inputs, outputs): def __init__(self, inputs, outputs, name="Composite"):
self.name = name
# We need to clone the graph as sometimes its nodes already # We need to clone the graph as sometimes its nodes already
# contain a reference to an fgraph. As we want the Composite # contain a reference to an fgraph. As we want the Composite
# to be pickable, we can't have reference to fgraph. # to be pickable, we can't have reference to fgraph.
...@@ -4106,30 +4107,6 @@ class Composite(ScalarOp, HasInnerGraph): ...@@ -4106,30 +4107,6 @@ class Composite(ScalarOp, HasInnerGraph):
self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert) self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert)
return self._py_perform_fn return self._py_perform_fn
@property
def name(self):
if hasattr(self, "_name"):
return self._name
# TODO FIXME: Just implement pretty printing for the `Op`; don't do
# this redundant, outside work in the `Op` itself.
for i, r in enumerate(self.fgraph.inputs):
r.name = f"i{int(i)}"
for i, r in enumerate(self.fgraph.outputs):
r.name = f"o{int(i)}"
io = set(self.fgraph.inputs + self.fgraph.outputs)
for i, r in enumerate(self.fgraph.variables):
if r not in io and len(self.fgraph.clients[r]) > 1:
r.name = f"t{int(i)}"
outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs])
rval = f"Composite{{{outputs_str}}}"
self._name = rval
return self._name
@name.setter
def name(self, name):
self._name = name
@property @property
def fgraph(self): def fgraph(self):
if hasattr(self, "_fgraph"): if hasattr(self, "_fgraph"):
......
...@@ -183,12 +183,7 @@ class TestComposite: ...@@ -183,12 +183,7 @@ class TestComposite:
make_function(DualLinker().accept(g)) make_function(DualLinker().accept(g))
assert str(g) == ( assert str(g) == (
"FunctionGraph(*1 -> Composite{((i0 + i1) + i2)," "FunctionGraph(*1 -> Composite(x, y, z), *1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
" (i0 + (i1 * i2)), (i0 / i1), "
"(i0 // 5), "
"(-i0), (i0 - i1), ((i0 ** i1) + (-i2)),"
" (i0 % 3)}(x, y, z), "
"*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
) )
def test_non_scalar_error(self): def test_non_scalar_error(self):
......
...@@ -604,31 +604,40 @@ def test_debugprint_compiled_fn(): ...@@ -604,31 +604,40 @@ def test_debugprint_compiled_fn():
out = pytensor.function([M], out, updates=updates, mode="FAST_RUN") out = pytensor.function([M], out, updates=updates, mode="FAST_RUN")
expected_output = """forall_inplace,cpu,scan_fn} [id A] 2 (outer_out_sit_sot-0) expected_output = """forall_inplace,cpu,scan_fn} [id A] 2 (outer_out_sit_sot-0)
|TensorConstant{20000} [id B] (n_steps) |TensorConstant{20000} [id B] (n_steps)
|TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0) |TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0)
|IncSubtensor{InplaceSet;:int64:} [id D] 1 (outer_in_sit_sot-0) |IncSubtensor{InplaceSet;:int64:} [id D] 1 (outer_in_sit_sot-0)
| |AllocEmpty{dtype='int64'} [id E] 0 | |AllocEmpty{dtype='int64'} [id E] 0
| | |TensorConstant{20000} [id B] | | |TensorConstant{20000} [id B]
| |TensorConstant{(1,) of 0} [id F] | |TensorConstant{(1,) of 0} [id F]
| |ScalarConstant{1} [id G] | |ScalarConstant{1} [id G]
|<TensorType(float64, (20000, 2, 2))> [id H] (outer_in_non_seqs-0) |<TensorType(float64, (20000, 2, 2))> [id H] (outer_in_non_seqs-0)
Inner graphs: Inner graphs:
forall_inplace,cpu,scan_fn} [id A] (outer_out_sit_sot-0) forall_inplace,cpu,scan_fn} [id A] (outer_out_sit_sot-0)
>Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I] (inner_out_sit_sot-0) >Elemwise{Composite} [id I] (inner_out_sit_sot-0)
> |TensorConstant{0} [id J] > |TensorConstant{0} [id J]
> |Subtensor{int64, int64, uint8} [id K] > |Subtensor{int64, int64, uint8} [id K]
> | |*2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0) > | |*2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
> | |ScalarFromTensor [id M] > | |ScalarFromTensor [id M]
> | | |*0-<TensorType(int64, ())> [id N] -> [id C] (inner_in_seqs-0) > | | |*0-<TensorType(int64, ())> [id N] -> [id C] (inner_in_seqs-0)
> | |ScalarFromTensor [id O] > | |ScalarFromTensor [id O]
> | | |*1-<TensorType(int64, ())> [id P] -> [id D] (inner_in_sit_sot-0) > | | |*1-<TensorType(int64, ())> [id P] -> [id D] (inner_in_sit_sot-0)
> | |ScalarConstant{0} [id Q] > | |ScalarConstant{0} [id Q]
> |TensorConstant{1} [id R] > |TensorConstant{1} [id R]
Elemwise{Composite} [id I]
>Switch [id S]
> |LT [id T]
> | |<int64> [id U]
> | |<float64> [id V]
> |<int64> [id W]
> |<int64> [id U]
""" """
output_str = debugprint(out, file="str", print_op_info=True) output_str = debugprint(out, file="str", print_op_info=True)
print(output_str)
lines = output_str.split("\n") lines = output_str.split("\n")
for truth, out in zip(expected_output.split("\n"), lines): for truth, out in zip(expected_output.split("\n"), lines):
......
...@@ -16,7 +16,7 @@ from pytensor.graph.fg import FunctionGraph ...@@ -16,7 +16,7 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import check_stack_trace, out2in from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.printing import pprint from pytensor.printing import debugprint, pprint
from pytensor.raise_op import Assert, CheckAndRaise from pytensor.raise_op import Assert, CheckAndRaise
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
...@@ -1105,7 +1105,7 @@ class TestLocalMergeSwitchSameCond: ...@@ -1105,7 +1105,7 @@ class TestLocalMergeSwitchSameCond:
s2 = at.switch(c, x, y) s2 = at.switch(c, x, y)
g = rewrite(FunctionGraph(mats, [op(s1, s2)])) g = rewrite(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count("Switch") == 1 assert debugprint(g, file="str").count("Switch") == 1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"op", "op",
...@@ -1122,7 +1122,7 @@ class TestLocalMergeSwitchSameCond: ...@@ -1122,7 +1122,7 @@ class TestLocalMergeSwitchSameCond:
s1 = at.switch(c, a, b) s1 = at.switch(c, a, b)
s2 = at.switch(c, x, y) s2 = at.switch(c, x, y)
g = rewrite(FunctionGraph(mats, [op(s1, s2)])) g = rewrite(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count("Switch") == 1 assert debugprint(g, file="str").count("Switch") == 1
@pytest.mark.parametrize("op", [add, mul]) @pytest.mark.parametrize("op", [add, mul])
def test_elemwise_multi_inputs(self, op): def test_elemwise_multi_inputs(self, op):
...@@ -1134,7 +1134,7 @@ class TestLocalMergeSwitchSameCond: ...@@ -1134,7 +1134,7 @@ class TestLocalMergeSwitchSameCond:
u, v = matrices("uv") u, v = matrices("uv")
s3 = at.switch(c, u, v) s3 = at.switch(c, u, v)
g = rewrite(FunctionGraph(mats + [u, v], [op(s1, s2, s3)])) g = rewrite(FunctionGraph(mats + [u, v], [op(s1, s2, s3)]))
assert str(g).count("Switch") == 1 assert debugprint(g, file="str").count("Switch") == 1
class TestLocalOptAlloc: class TestLocalOptAlloc:
......
...@@ -28,6 +28,7 @@ from pytensor.graph.rewriting.basic import ( ...@@ -28,6 +28,7 @@ from pytensor.graph.rewriting.basic import (
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import debugprint
from pytensor.tensor import inplace from pytensor.tensor import inplace
from pytensor.tensor.basic import Alloc, join, switch from pytensor.tensor.basic import Alloc, join, switch
from pytensor.tensor.blas import Dot22, Gemv from pytensor.tensor.blas import Dot22, Gemv
...@@ -2416,7 +2417,7 @@ class TestLocalMergeSwitchSameCond: ...@@ -2416,7 +2417,7 @@ class TestLocalMergeSwitchSameCond:
at_pow, at_pow,
): ):
g = rewrite(FunctionGraph(mats, [op(s1, s2)])) g = rewrite(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count("Switch") == 1 assert debugprint(g, file="str").count("Switch") == 1
# integer Ops # integer Ops
mats = imatrices("cabxy") mats = imatrices("cabxy")
c, a, b, x, y = mats c, a, b, x, y = mats
...@@ -2428,13 +2429,13 @@ class TestLocalMergeSwitchSameCond: ...@@ -2428,13 +2429,13 @@ class TestLocalMergeSwitchSameCond:
bitwise_xor, bitwise_xor,
): ):
g = rewrite(FunctionGraph(mats, [op(s1, s2)])) g = rewrite(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count("Switch") == 1 assert debugprint(g, file="str").count("Switch") == 1
# add/mul with more than two inputs # add/mul with more than two inputs
u, v = matrices("uv") u, v = matrices("uv")
s3 = at.switch(c, u, v) s3 = at.switch(c, u, v)
for op in (add, mul): for op in (add, mul):
g = rewrite(FunctionGraph(mats + [u, v], [op(s1, s2, s3)])) g = rewrite(FunctionGraph(mats + [u, v], [op(s1, s2, s3)]))
assert str(g).count("Switch") == 1 assert debugprint(g, file="str").count("Switch") == 1
class TestLocalSumProd: class TestLocalSumProd:
......
...@@ -273,7 +273,7 @@ def test_debugprint(): ...@@ -273,7 +273,7 @@ def test_debugprint():
s = s.getvalue() s = s.getvalue()
exp_res = dedent( exp_res = dedent(
r""" r"""
Elemwise{Composite{(i2 + (i0 - i1))}} 4 Elemwise{Composite} 4
|InplaceDimShuffle{x,0} v={0: [0]} 3 |InplaceDimShuffle{x,0} v={0: [0]} 3
| |CGemv{inplace} d={0: [0]} 2 | |CGemv{inplace} d={0: [0]} 2
| |AllocEmpty{dtype='float64'} 1 | |AllocEmpty{dtype='float64'} 1
...@@ -285,6 +285,15 @@ def test_debugprint(): ...@@ -285,6 +285,15 @@ def test_debugprint():
| |TensorConstant{0.0} | |TensorConstant{0.0}
|D |D
|A |A
Inner graphs:
Elemwise{Composite}
>add
> |<float64>
> |sub
> |<float64>
> |<float64>
""" """
).lstrip() ).lstrip()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论