提交 724561b2 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make aesara.printing.debugprint handle arbitrary HasInnerGraph Ops

上级 e3886d2b
差异被折叠。
...@@ -11,6 +11,7 @@ from aesara.gradient import DisconnectedType, Rop, disconnected_type, grad ...@@ -11,6 +11,7 @@ from aesara.gradient import DisconnectedType, Rop, disconnected_type, grad
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType from aesara.graph.null_type import NullType
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import optimize_graph
from aesara.printing import debugprint
from aesara.tensor.basic import as_tensor from aesara.tensor.basic import as_tensor
from aesara.tensor.basic_opt import ShapeOptimizer from aesara.tensor.basic_opt import ShapeOptimizer
from aesara.tensor.math import dot, exp from aesara.tensor.math import dot, exp
...@@ -433,3 +434,31 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -433,3 +434,31 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
f = op(y) f = op(y)
grad_f = grad(f, y) grad_f = grad(f, y)
assert grad_f.tag.test_value is not None assert grad_f.tag.test_value is not None
def test_debugprint():
x, y, z = matrices("xyz")
e = x + y * z
op = OpFromGraph([x, y, z], [e])
out = op(x, y, z)
output_str = debugprint(out, file="str")
lines = output_str.split("\n")
exp_res = """OpFromGraph{inline=False} [id A] ''
|x [id B]
|y [id C]
|z [id D]
Inner graphs:
OpFromGraph{inline=False} [id A] ''
>Elemwise{add,no_inplace} [id E] ''
> |x [id F]
> |Elemwise{mul,no_inplace} [id G] ''
> |y [id H]
> |z [id I]
"""
for truth, out in zip(exp_res.split("\n"), lines):
assert truth.strip() == out.strip()
import numpy as np import numpy as np
from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import Op from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.type import Type from aesara.graph.type import Type
...@@ -108,3 +108,33 @@ op_y = MyOp("OpY", x=1) ...@@ -108,3 +108,33 @@ op_y = MyOp("OpY", x=1)
op_z = MyOp("OpZ", x=1) op_z = MyOp("OpZ", x=1)
op_cast_type2 = MyOpCastType2("OpCastType2") op_cast_type2 = MyOpCastType2("OpCastType2")
class MyInnerGraphOp(Op, HasInnerGraph):
__props__ = ()
def __init__(self, inner_inputs, inner_outputs):
self._inner_inputs = inner_inputs
self._inner_outputs = inner_outputs
def make_node(self, *inputs):
for input in inputs:
assert isinstance(input, Variable)
assert isinstance(input.type, MyType)
outputs = [inputs[0].type()]
return Apply(self, list(inputs), outputs)
def perform(self, *args, **kwargs):
raise NotImplementedError("No Python implementation available.")
@property
def fn(self):
raise NotImplementedError("No Python implementation available.")
@property
def inner_inputs(self):
return self._inner_inputs
@property
def inner_outputs(self):
return self._inner_outputs
...@@ -51,7 +51,7 @@ def test_scan_debugprint1(): ...@@ -51,7 +51,7 @@ def test_scan_debugprint1():
| |ScalarConstant{1} [id U] | |ScalarConstant{1} [id U]
|ScalarConstant{-1} [id V] |ScalarConstant{-1} [id V]
Inner graphs of the scan ops: Inner graphs:
for{cpu,scan_fn} [id C] '' for{cpu,scan_fn} [id C] ''
>Elemwise{mul,no_inplace} [id W] '' >Elemwise{mul,no_inplace} [id W] ''
...@@ -111,7 +111,7 @@ def test_scan_debugprint2(): ...@@ -111,7 +111,7 @@ def test_scan_debugprint2():
|Elemwise{scalar_minimum,no_inplace} [id C] '' |Elemwise{scalar_minimum,no_inplace} [id C] ''
|x [id W] |x [id W]
Inner graphs of the scan ops: Inner graphs:
for{cpu,scan_fn} [id B] '' for{cpu,scan_fn} [id B] ''
>Elemwise{mul,no_inplace} [id X] '' >Elemwise{mul,no_inplace} [id X] ''
...@@ -124,7 +124,6 @@ def test_scan_debugprint2(): ...@@ -124,7 +124,6 @@ def test_scan_debugprint2():
assert truth.strip() == out.strip() assert truth.strip() == out.strip()
@aesara.config.change_flags(optimizer_verbose=True)
def test_scan_debugprint3(): def test_scan_debugprint3():
coefficients = dvector("coefficients") coefficients = dvector("coefficients")
max_coefficients_supported = 10 max_coefficients_supported = 10
...@@ -192,7 +191,7 @@ def test_scan_debugprint3(): ...@@ -192,7 +191,7 @@ def test_scan_debugprint3():
|A [id W] |A [id W]
|k [id X] |k [id X]
Inner graphs of the scan ops: Inner graphs:
for{cpu,scan_fn} [id B] '' for{cpu,scan_fn} [id B] ''
>Elemwise{mul,no_inplace} [id Y] '' >Elemwise{mul,no_inplace} [id Y] ''
...@@ -293,7 +292,7 @@ def test_scan_debugprint4(): ...@@ -293,7 +292,7 @@ def test_scan_debugprint4():
|for{cpu,scan_fn}.1 [id C] '' |for{cpu,scan_fn}.1 [id C] ''
|ScalarConstant{2} [id BA] |ScalarConstant{2} [id BA]
Inner graphs of the scan ops: Inner graphs:
for{cpu,scan_fn}.0 [id C] '' for{cpu,scan_fn}.0 [id C] ''
>Elemwise{add,no_inplace} [id BB] '' >Elemwise{add,no_inplace} [id BB] ''
...@@ -412,7 +411,7 @@ def test_scan_debugprint5(): ...@@ -412,7 +411,7 @@ def test_scan_debugprint5():
| |A [id P] | |A [id P]
|ScalarConstant{-1} [id CL] |ScalarConstant{-1} [id CL]
Inner graphs of the scan ops: Inner graphs:
for{cpu,grad_of_scan_fn}.1 [id B] '' for{cpu,grad_of_scan_fn}.1 [id B] ''
>Elemwise{add,no_inplace} [id CM] '' >Elemwise{add,no_inplace} [id CM] ''
......
...@@ -2015,7 +2015,7 @@ class TestLocalUselessElemwiseComparison: ...@@ -2015,7 +2015,7 @@ class TestLocalUselessElemwiseComparison:
| |Subtensor{int64} [id C] '' | |Subtensor{int64} [id C] ''
|Y [id K] |Y [id K]
Inner graphs of the scan ops: Inner graphs:
for{cpu,scan_fn} [id B] '' for{cpu,scan_fn} [id B] ''
>Sum{acc_dtype=float64} [id L] '' >Sum{acc_dtype=float64} [id L] ''
...@@ -2050,7 +2050,7 @@ class TestLocalUselessElemwiseComparison: ...@@ -2050,7 +2050,7 @@ class TestLocalUselessElemwiseComparison:
| |Shape_i{0} [id C] <TensorType(int64, scalar)> '' 0 | |Shape_i{0} [id C] <TensorType(int64, scalar)> '' 0
|Y [id M] <TensorType(float64, vector)> |Y [id M] <TensorType(float64, vector)>
Inner graphs of the scan ops: Inner graphs:
for{cpu,scan_fn} [id B] <TensorType(float64, vector)> '' for{cpu,scan_fn} [id B] <TensorType(float64, vector)> ''
>Sum{acc_dtype=float64} [id N] <TensorType(float64, scalar)> '' >Sum{acc_dtype=float64} [id N] <TensorType(float64, scalar)> ''
......
...@@ -15,6 +15,7 @@ from aesara.printing import ( ...@@ -15,6 +15,7 @@ from aesara.printing import (
pydotprint, pydotprint,
) )
from aesara.tensor.type import dmatrix, dvector, matrix from aesara.tensor.type import dmatrix, dvector, matrix
from tests.graph.utils import MyInnerGraphOp, MyOp, MyVariable
@pytest.mark.skipif(not pydot_imported, reason="pydot not available") @pytest.mark.skipif(not pydot_imported, reason="pydot not available")
...@@ -109,6 +110,9 @@ def test_min_informative_str(): ...@@ -109,6 +110,9 @@ def test_min_informative_str():
def test_debugprint(): def test_debugprint():
with pytest.raises(TypeError):
debugprint("blah")
A = matrix(name="A") A = matrix(name="A")
B = matrix(name="B") B = matrix(name="B")
C = A + B C = A + B
...@@ -277,3 +281,66 @@ def test_pprint(): ...@@ -277,3 +281,66 @@ def test_pprint():
x = dvector() x = dvector()
y = x[1] y = x[1]
assert pp(y) == "<TensorType(float64, vector)>[ScalarConstant{1}]" assert pp(y) == "<TensorType(float64, vector)>[ScalarConstant{1}]"
def test_debugprint_inner_graph():
r1, r2 = MyVariable("1"), MyVariable("2")
o1 = MyOp("op1")(r1, r2)
o1.name = "o1"
# Inner graph
igo_in_1 = MyVariable("4")
igo_in_2 = MyVariable("5")
igo_out_1 = MyOp("op2")(igo_in_1, igo_in_2)
igo_out_1.name = "igo1"
igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1])
r3, r4 = MyVariable("3"), MyVariable("4")
out = igo(r3, r4)
output_str = debugprint(out, file="str")
lines = output_str.split("\n")
exp_res = """MyInnerGraphOp [id A] ''
|3 [id B]
|4 [id C]
Inner graphs:
MyInnerGraphOp [id A] ''
>op2 [id D] 'igo1'
> |4 [id E]
> |5 [id F]
"""
for exp_line, res_line in zip(exp_res.split("\n"), lines):
assert exp_line.strip() == res_line.strip()
# Test nested inner-graph `Op`s
igo_2 = MyInnerGraphOp([r3, r4], [out])
r5 = MyVariable("5")
out_2 = igo_2(r5)
output_str = debugprint(out_2, file="str")
lines = output_str.split("\n")
exp_res = """MyInnerGraphOp [id A] ''
|5 [id B]
Inner graphs:
MyInnerGraphOp [id A] ''
>MyInnerGraphOp [id C] ''
> |3 [id D]
> |4 [id E]
MyInnerGraphOp [id C] ''
>op2 [id F] 'igo1'
> |4 [id G]
> |5 [id H]
"""
for exp_line, res_line in zip(exp_res.split("\n"), lines):
assert exp_line.strip() == res_line.strip()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论