Unverified 提交 4355a58f authored 作者: Anoop T P's avatar Anoop T P 提交者: GitHub

Change FunctionGraph.change_input to change_node_input (#734)

上级 e476ebcd
......@@ -1221,11 +1221,11 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
# and not(wrapped_outputs[i].borrow and wrapped_outputs[j].borrow):
if fgraph.outputs[j] in views_of_output_i:
if wrapped_outputs[i].borrow and wrapped_outputs[j].borrow:
fgraph.change_input(
fgraph.change_node_input(
"output", i, view_op(fgraph.outputs[i]), reason=reason
)
else:
fgraph.change_input(
fgraph.change_node_input(
"output", i, deep_copy_op(fgraph.outputs[i]), reason=reason
)
copied = True
......@@ -1248,7 +1248,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
if input_j in fgraph.inputs:
j = fgraph.inputs.index(input_j)
if wrapped_outputs[i].borrow and wrapped_inputs[j].borrow:
fgraph.change_input(
fgraph.change_node_input(
"output",
i,
view_op(fgraph.outputs[i]),
......@@ -1256,7 +1256,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
)
break
else:
fgraph.change_input(
fgraph.change_node_input(
"output",
i,
deep_copy_op(fgraph.outputs[i]),
......@@ -1264,7 +1264,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
)
break
elif wrapped_outputs[i].borrow:
fgraph.change_input(
fgraph.change_node_input(
"output",
i,
view_op(fgraph.outputs[i]),
......@@ -1272,7 +1272,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
)
break
else:
fgraph.change_input(
fgraph.change_node_input(
"output",
i,
deep_copy_op(fgraph.outputs[i]),
......
......@@ -366,7 +366,7 @@ class LambdaExtract:
self.reason = reason
def __call__(self):
return self.fgraph.change_input(
return self.fgraph.change_node_input(
self.node, self.i, self.r, reason=("Revert", self.reason)
)
......
......@@ -418,7 +418,7 @@ class FunctionGraph(MetaObject):
self.add_client(input, (node, i))
self.execute_callbacks("on_import", node, reason)
def change_input(
def change_node_input(
self,
node: Union[Apply, str],
i: int,
......@@ -544,7 +544,7 @@ class FunctionGraph(MetaObject):
assert (node == "output" and self.outputs[i] is var) or (
node.inputs[i] is var
)
self.change_input(
self.change_node_input(
node, i, new_var, reason=reason, import_missing=import_missing
)
......
......@@ -193,24 +193,24 @@ class TestFunctionGraph:
var6 = MyVariable2("var6")
with pytest.raises(TypeError):
fg.change_input("output", 1, var6)
fg.change_node_input("output", 1, var6)
with pytest.raises(TypeError):
fg.change_input(var5.owner, 1, var6)
fg.change_node_input(var5.owner, 1, var6)
old_apply_nodes = set(fg.apply_nodes)
old_variables = set(fg.variables)
old_var5_clients = list(fg.get_clients(var5))
# We're replacing with the same variable, so nothing should happen
fg.change_input(var5.owner, 1, var2)
fg.change_node_input(var5.owner, 1, var2)
assert old_apply_nodes == fg.apply_nodes
assert old_variables == fg.variables
assert old_var5_clients == fg.get_clients(var5)
# Perform a valid `Apply` node input change
fg.change_input(var5.owner, 1, var1)
fg.change_node_input(var5.owner, 1, var1)
assert var5.owner.inputs[1] is var1
assert (var5.owner, 1) not in fg.get_clients(var2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论