Unverified 提交 4355a58f authored 作者: Anoop T P's avatar Anoop T P 提交者: GitHub

Change FunctionGraph.change_input to change_node_input (#734)

上级 e476ebcd
...@@ -1221,11 +1221,11 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1221,11 +1221,11 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
# and not(wrapped_outputs[i].borrow and wrapped_outputs[j].borrow): # and not(wrapped_outputs[i].borrow and wrapped_outputs[j].borrow):
if fgraph.outputs[j] in views_of_output_i: if fgraph.outputs[j] in views_of_output_i:
if wrapped_outputs[i].borrow and wrapped_outputs[j].borrow: if wrapped_outputs[i].borrow and wrapped_outputs[j].borrow:
fgraph.change_input( fgraph.change_node_input(
"output", i, view_op(fgraph.outputs[i]), reason=reason "output", i, view_op(fgraph.outputs[i]), reason=reason
) )
else: else:
fgraph.change_input( fgraph.change_node_input(
"output", i, deep_copy_op(fgraph.outputs[i]), reason=reason "output", i, deep_copy_op(fgraph.outputs[i]), reason=reason
) )
copied = True copied = True
...@@ -1248,7 +1248,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1248,7 +1248,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
if input_j in fgraph.inputs: if input_j in fgraph.inputs:
j = fgraph.inputs.index(input_j) j = fgraph.inputs.index(input_j)
if wrapped_outputs[i].borrow and wrapped_inputs[j].borrow: if wrapped_outputs[i].borrow and wrapped_inputs[j].borrow:
fgraph.change_input( fgraph.change_node_input(
"output", "output",
i, i,
view_op(fgraph.outputs[i]), view_op(fgraph.outputs[i]),
...@@ -1256,7 +1256,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1256,7 +1256,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
) )
break break
else: else:
fgraph.change_input( fgraph.change_node_input(
"output", "output",
i, i,
deep_copy_op(fgraph.outputs[i]), deep_copy_op(fgraph.outputs[i]),
...@@ -1264,7 +1264,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1264,7 +1264,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
) )
break break
elif wrapped_outputs[i].borrow: elif wrapped_outputs[i].borrow:
fgraph.change_input( fgraph.change_node_input(
"output", "output",
i, i,
view_op(fgraph.outputs[i]), view_op(fgraph.outputs[i]),
...@@ -1272,7 +1272,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1272,7 +1272,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
) )
break break
else: else:
fgraph.change_input( fgraph.change_node_input(
"output", "output",
i, i,
deep_copy_op(fgraph.outputs[i]), deep_copy_op(fgraph.outputs[i]),
......
...@@ -366,7 +366,7 @@ class LambdaExtract: ...@@ -366,7 +366,7 @@ class LambdaExtract:
self.reason = reason self.reason = reason
def __call__(self): def __call__(self):
return self.fgraph.change_input( return self.fgraph.change_node_input(
self.node, self.i, self.r, reason=("Revert", self.reason) self.node, self.i, self.r, reason=("Revert", self.reason)
) )
......
...@@ -418,7 +418,7 @@ class FunctionGraph(MetaObject): ...@@ -418,7 +418,7 @@ class FunctionGraph(MetaObject):
self.add_client(input, (node, i)) self.add_client(input, (node, i))
self.execute_callbacks("on_import", node, reason) self.execute_callbacks("on_import", node, reason)
def change_input( def change_node_input(
self, self,
node: Union[Apply, str], node: Union[Apply, str],
i: int, i: int,
...@@ -544,7 +544,7 @@ class FunctionGraph(MetaObject): ...@@ -544,7 +544,7 @@ class FunctionGraph(MetaObject):
assert (node == "output" and self.outputs[i] is var) or ( assert (node == "output" and self.outputs[i] is var) or (
node.inputs[i] is var node.inputs[i] is var
) )
self.change_input( self.change_node_input(
node, i, new_var, reason=reason, import_missing=import_missing node, i, new_var, reason=reason, import_missing=import_missing
) )
......
...@@ -193,24 +193,24 @@ class TestFunctionGraph: ...@@ -193,24 +193,24 @@ class TestFunctionGraph:
var6 = MyVariable2("var6") var6 = MyVariable2("var6")
with pytest.raises(TypeError): with pytest.raises(TypeError):
fg.change_input("output", 1, var6) fg.change_node_input("output", 1, var6)
with pytest.raises(TypeError): with pytest.raises(TypeError):
fg.change_input(var5.owner, 1, var6) fg.change_node_input(var5.owner, 1, var6)
old_apply_nodes = set(fg.apply_nodes) old_apply_nodes = set(fg.apply_nodes)
old_variables = set(fg.variables) old_variables = set(fg.variables)
old_var5_clients = list(fg.get_clients(var5)) old_var5_clients = list(fg.get_clients(var5))
# We're replacing with the same variable, so nothing should happen # We're replacing with the same variable, so nothing should happen
fg.change_input(var5.owner, 1, var2) fg.change_node_input(var5.owner, 1, var2)
assert old_apply_nodes == fg.apply_nodes assert old_apply_nodes == fg.apply_nodes
assert old_variables == fg.variables assert old_variables == fg.variables
assert old_var5_clients == fg.get_clients(var5) assert old_var5_clients == fg.get_clients(var5)
# Perform a valid `Apply` node input change # Perform a valid `Apply` node input change
fg.change_input(var5.owner, 1, var1) fg.change_node_input(var5.owner, 1, var1)
assert var5.owner.inputs[1] is var1 assert var5.owner.inputs[1] is var1
assert (var5.owner, 1) not in fg.get_clients(var2) assert (var5.owner, 1) not in fg.get_clients(var2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论