提交 01020a18 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add memo option, automatic inputs, and copy options to FunctionGraph

上级 52bad109
差异被折叠。
...@@ -41,6 +41,10 @@ class TestFunctionGraph: ...@@ -41,6 +41,10 @@ class TestFunctionGraph:
var3 = op1(var1) var3 = op1(var1)
FunctionGraph([var3], [var2], clone=False) FunctionGraph([var3], [var2], clone=False)
with pytest.raises(ValueError):
var3 = op1(var1)
FunctionGraph([var3], clone=False)
def test_init(self): def test_init(self):
var1 = MyVariable("var1") var1 = MyVariable("var1")
var2 = MyVariable("var2") var2 = MyVariable("var2")
...@@ -58,6 +62,19 @@ class TestFunctionGraph: ...@@ -58,6 +62,19 @@ class TestFunctionGraph:
assert fg.get_clients(var3) == [(var4.owner, 0), ("output", 0)] assert fg.get_clients(var3) == [(var4.owner, 0), ("output", 0)]
assert fg.get_clients(var4) == [("output", 1)] assert fg.get_clients(var4) == [("output", 1)]
fg = FunctionGraph(outputs=[var3, var4], clone=False)
assert fg.inputs == [var1, var2]
memo = {}
fg = FunctionGraph(outputs=[var3, var4], clone=True, memo=memo)
assert memo[var1].type == var1.type
assert memo[var1].name == var1.name
assert memo[var2].type == var2.type
assert memo[var2].name == var2.name
assert var3 in memo
assert var4 in memo
def test_remove_client(self): def test_remove_client(self):
var1 = MyVariable("var1") var1 = MyVariable("var1")
var2 = MyVariable("var2") var2 = MyVariable("var2")
......
...@@ -58,7 +58,7 @@ class MyOp(Op): ...@@ -58,7 +58,7 @@ class MyOp(Op):
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
outputs[0] = np.array(inputs) outputs[0] = np.array(inputs, dtype=np.object)
def __str__(self): def __str__(self):
return self.name return self.name
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论