提交 cc70b958 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Combine MyInnerGraphOp implementations and use FunctionGraphs and nominal variables

上级 3b0e071e
...@@ -26,7 +26,7 @@ from aesara.graph.basic import ( ...@@ -26,7 +26,7 @@ from aesara.graph.basic import (
vars_between, vars_between,
walk, walk,
) )
from aesara.graph.op import HasInnerGraph, Op from aesara.graph.op import Op
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.tensor.math import max_and_argmax from aesara.tensor.math import max_and_argmax
from aesara.tensor.type import ( from aesara.tensor.type import (
...@@ -41,6 +41,7 @@ from aesara.tensor.type import ( ...@@ -41,6 +41,7 @@ from aesara.tensor.type import (
from aesara.tensor.type_other import NoneConst from aesara.tensor.type_other import NoneConst
from aesara.tensor.var import TensorVariable from aesara.tensor.var import TensorVariable
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.graph.utils import MyInnerGraphOp
class MyType(Type): class MyType(Type):
...@@ -85,36 +86,6 @@ class MyOp(Op): ...@@ -85,36 +86,6 @@ class MyOp(Op):
MyOp = MyOp() MyOp = MyOp()
class MyInnerGraphOp(Op, HasInnerGraph):
__props__ = ()
def __init__(self, inner_inputs, inner_outputs):
self._inner_inputs = inner_inputs
self._inner_outputs = inner_outputs
def make_node(self, *inputs):
for input in inputs:
assert isinstance(input, Variable)
assert isinstance(input.type, MyType)
outputs = [MyVariable(sum(input.type.thingy for input in inputs))]
return Apply(self, list(inputs), outputs)
def perform(self, *args, **kwargs):
raise NotImplementedError("No Python implementation available.")
@property
def fn(self):
raise NotImplementedError("No Python implementation available.")
@property
def inner_inputs(self):
return self._inner_inputs
@property
def inner_outputs(self):
return self._inner_outputs
class X: class X:
def leaf_formatter(self, leaf): def leaf_formatter(self, leaf):
return str(leaf.type) return str(leaf.type)
...@@ -550,7 +521,8 @@ def test_get_var_by_name(): ...@@ -550,7 +521,8 @@ def test_get_var_by_name():
(res,) = get_var_by_name([o1, o2], "igo1") (res,) = get_var_by_name([o1, o2], "igo1")
assert res == igo_out_1 exp_res = igo.fgraph.outputs[0]
assert res == exp_res
class TestCloneReplace: class TestCloneReplace:
......
import numpy as np import numpy as np
from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.basic import Apply, Constant, NominalVariable, Variable, clone_replace
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import HasInnerGraph, Op from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.type import Type from aesara.graph.type import Type
...@@ -21,6 +22,9 @@ class MyType(Type): ...@@ -21,6 +22,9 @@ class MyType(Type):
def __hash__(self): def __hash__(self):
return hash(MyType) return hash(MyType)
def __repr__(self):
return "MyType()"
class MyType2(Type): class MyType2(Type):
def filter(self, data): def filter(self, data):
...@@ -32,6 +36,9 @@ class MyType2(Type): ...@@ -32,6 +36,9 @@ class MyType2(Type):
def __hash__(self): def __hash__(self):
return hash(MyType) return hash(MyType)
def __repr__(self):
return "MyType2()"
def MyVariable(name): def MyVariable(name):
return Variable(MyType(), None, None, name=name) return Variable(MyType(), None, None, name=name)
...@@ -123,13 +130,16 @@ class MyInnerGraphOp(Op, HasInnerGraph): ...@@ -123,13 +130,16 @@ class MyInnerGraphOp(Op, HasInnerGraph):
__props__ = () __props__ = ()
def __init__(self, inner_inputs, inner_outputs): def __init__(self, inner_inputs, inner_outputs):
self._inner_inputs = inner_inputs input_replacements = [
self._inner_outputs = inner_outputs (v, NominalVariable(n, v.type))
for n, v in enumerate(inner_inputs)
if not isinstance(v, Constant)
]
outputs = clone_replace(inner_outputs, replace=input_replacements)
_, inputs = zip(*input_replacements) if input_replacements else (None, [])
self.fgraph = FunctionGraph(inputs, outputs, clone=False)
def make_node(self, *inputs): def make_node(self, *inputs):
for input in inputs:
assert isinstance(input, Variable)
assert isinstance(input.type, MyType)
outputs = [inputs[0].type()] outputs = [inputs[0].type()]
return Apply(self, list(inputs), outputs) return Apply(self, list(inputs), outputs)
...@@ -142,8 +152,8 @@ class MyInnerGraphOp(Op, HasInnerGraph): ...@@ -142,8 +152,8 @@ class MyInnerGraphOp(Op, HasInnerGraph):
@property @property
def inner_inputs(self): def inner_inputs(self):
return self._inner_inputs return self.fgraph.inputs
@property @property
def inner_outputs(self): def inner_outputs(self):
return self._inner_outputs return self.fgraph.outputs
...@@ -326,8 +326,8 @@ Inner graphs: ...@@ -326,8 +326,8 @@ Inner graphs:
MyInnerGraphOp [id A] MyInnerGraphOp [id A]
>op2 [id D] 'igo1' >op2 [id D] 'igo1'
> |4 [id E] > |*0-<MyType()> [id E]
> |5 [id F] > |*1-<MyType()> [id F]
""" """
for exp_line, res_line in zip(exp_res.split("\n"), lines): for exp_line, res_line in zip(exp_res.split("\n"), lines):
...@@ -349,13 +349,13 @@ Inner graphs: ...@@ -349,13 +349,13 @@ Inner graphs:
MyInnerGraphOp [id A] MyInnerGraphOp [id A]
>MyInnerGraphOp [id C] >MyInnerGraphOp [id C]
> |3 [id D] > |*0-<MyType()> [id D]
> |4 [id E] > |*1-<MyType()> [id E]
MyInnerGraphOp [id C] MyInnerGraphOp [id C]
>op2 [id F] 'igo1' >op2 [id F] 'igo1'
> |4 [id G] > |*0-<MyType()> [id D]
> |5 [id H] > |*1-<MyType()> [id E]
""" """
for exp_line, res_line in zip(exp_res.split("\n"), lines): for exp_line, res_line in zip(exp_res.split("\n"), lines):
...@@ -368,7 +368,6 @@ def test_get_var_by_id(): ...@@ -368,7 +368,6 @@ def test_get_var_by_id():
o1 = MyOp("op1")(r1, r2) o1 = MyOp("op1")(r1, r2)
o1.name = "o1" o1.name = "o1"
# Inner graph
igo_in_1 = MyVariable("v4") igo_in_1 = MyVariable("v4")
igo_in_2 = MyVariable("v5") igo_in_2 = MyVariable("v5")
igo_out_1 = MyOp("op2")(igo_in_1, igo_in_2) igo_out_1 = MyOp("op2")(igo_in_1, igo_in_2)
...@@ -379,21 +378,6 @@ def test_get_var_by_id(): ...@@ -379,21 +378,6 @@ def test_get_var_by_id():
r3 = MyVariable("v3") r3 = MyVariable("v3")
o2 = igo(r3, o1) o2 = igo(r3, o1)
# import aesara; aesara.dprint([o1, o2])
# op1 [id A] 'o1'
# |1 [id B]
# |2 [id C]
# MyInnerGraphOp [id D]
# |3 [id E]
# |op1 [id A] 'o1'
#
# Inner graphs:
#
# MyInnerGraphOp [id D]
# >op2 [id F] 'igo1'
# > |4 [id G]
# > |5 [id H]
res = get_node_by_id(o1, "blah") res = get_node_by_id(o1, "blah")
assert res is None assert res is None
...@@ -404,7 +388,8 @@ def test_get_var_by_id(): ...@@ -404,7 +388,8 @@ def test_get_var_by_id():
res = get_node_by_id([o1, o2], "F") res = get_node_by_id([o1, o2], "F")
assert res == igo_out_1.owner exp_res = igo.fgraph.outputs[0].owner
assert res == exp_res
def test_PatternPrinter(): def test_PatternPrinter():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论