提交 cc70b958 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Combine MyInnerGraphOp implementations and use FunctionGraphs and nominal variables

上级 3b0e071e
......@@ -26,7 +26,7 @@ from aesara.graph.basic import (
vars_between,
walk,
)
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.op import Op
from aesara.graph.type import Type
from aesara.tensor.math import max_and_argmax
from aesara.tensor.type import (
......@@ -41,6 +41,7 @@ from aesara.tensor.type import (
from aesara.tensor.type_other import NoneConst
from aesara.tensor.var import TensorVariable
from tests import unittest_tools as utt
from tests.graph.utils import MyInnerGraphOp
class MyType(Type):
......@@ -85,36 +86,6 @@ class MyOp(Op):
MyOp = MyOp()
class MyInnerGraphOp(Op, HasInnerGraph):
__props__ = ()
def __init__(self, inner_inputs, inner_outputs):
self._inner_inputs = inner_inputs
self._inner_outputs = inner_outputs
def make_node(self, *inputs):
for input in inputs:
assert isinstance(input, Variable)
assert isinstance(input.type, MyType)
outputs = [MyVariable(sum(input.type.thingy for input in inputs))]
return Apply(self, list(inputs), outputs)
def perform(self, *args, **kwargs):
raise NotImplementedError("No Python implementation available.")
@property
def fn(self):
raise NotImplementedError("No Python implementation available.")
@property
def inner_inputs(self):
return self._inner_inputs
@property
def inner_outputs(self):
return self._inner_outputs
class X:
def leaf_formatter(self, leaf):
return str(leaf.type)
......@@ -550,7 +521,8 @@ def test_get_var_by_name():
(res,) = get_var_by_name([o1, o2], "igo1")
assert res == igo_out_1
exp_res = igo.fgraph.outputs[0]
assert res == exp_res
class TestCloneReplace:
......
import numpy as np
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.basic import Apply, Constant, NominalVariable, Variable, clone_replace
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.type import Type
......@@ -21,6 +22,9 @@ class MyType(Type):
def __hash__(self):
return hash(MyType)
def __repr__(self):
return "MyType()"
class MyType2(Type):
def filter(self, data):
......@@ -32,6 +36,9 @@ class MyType2(Type):
def __hash__(self):
return hash(MyType)
def __repr__(self):
return "MyType2()"
def MyVariable(name):
return Variable(MyType(), None, None, name=name)
......@@ -123,13 +130,16 @@ class MyInnerGraphOp(Op, HasInnerGraph):
__props__ = ()
def __init__(self, inner_inputs, inner_outputs):
self._inner_inputs = inner_inputs
self._inner_outputs = inner_outputs
input_replacements = [
(v, NominalVariable(n, v.type))
for n, v in enumerate(inner_inputs)
if not isinstance(v, Constant)
]
outputs = clone_replace(inner_outputs, replace=input_replacements)
_, inputs = zip(*input_replacements) if input_replacements else (None, [])
self.fgraph = FunctionGraph(inputs, outputs, clone=False)
def make_node(self, *inputs):
for input in inputs:
assert isinstance(input, Variable)
assert isinstance(input.type, MyType)
outputs = [inputs[0].type()]
return Apply(self, list(inputs), outputs)
......@@ -142,8 +152,8 @@ class MyInnerGraphOp(Op, HasInnerGraph):
@property
def inner_inputs(self):
return self._inner_inputs
return self.fgraph.inputs
@property
def inner_outputs(self):
return self._inner_outputs
return self.fgraph.outputs
......@@ -326,8 +326,8 @@ Inner graphs:
MyInnerGraphOp [id A]
>op2 [id D] 'igo1'
> |4 [id E]
> |5 [id F]
> |*0-<MyType()> [id E]
> |*1-<MyType()> [id F]
"""
for exp_line, res_line in zip(exp_res.split("\n"), lines):
......@@ -349,13 +349,13 @@ Inner graphs:
MyInnerGraphOp [id A]
>MyInnerGraphOp [id C]
> |3 [id D]
> |4 [id E]
> |*0-<MyType()> [id D]
> |*1-<MyType()> [id E]
MyInnerGraphOp [id C]
>op2 [id F] 'igo1'
> |4 [id G]
> |5 [id H]
> |*0-<MyType()> [id D]
> |*1-<MyType()> [id E]
"""
for exp_line, res_line in zip(exp_res.split("\n"), lines):
......@@ -368,7 +368,6 @@ def test_get_var_by_id():
o1 = MyOp("op1")(r1, r2)
o1.name = "o1"
# Inner graph
igo_in_1 = MyVariable("v4")
igo_in_2 = MyVariable("v5")
igo_out_1 = MyOp("op2")(igo_in_1, igo_in_2)
......@@ -379,21 +378,6 @@ def test_get_var_by_id():
r3 = MyVariable("v3")
o2 = igo(r3, o1)
# import aesara; aesara.dprint([o1, o2])
# op1 [id A] 'o1'
# |1 [id B]
# |2 [id C]
# MyInnerGraphOp [id D]
# |3 [id E]
# |op1 [id A] 'o1'
#
# Inner graphs:
#
# MyInnerGraphOp [id D]
# >op2 [id F] 'igo1'
# > |4 [id G]
# > |5 [id H]
res = get_node_by_id(o1, "blah")
assert res is None
......@@ -404,7 +388,8 @@ def test_get_var_by_id():
res = get_node_by_id([o1, o2], "F")
assert res == igo_out_1.owner
exp_res = igo.fgraph.outputs[0].owner
assert res == exp_res
def test_PatternPrinter():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论