提交 37444647 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add an option to add missing inputs during FunctionGraph operations

上级 334d3fdf
差异被折叠。
......@@ -856,9 +856,9 @@ class MergeOptimizer(GlobalOptimizer):
# Only need to check one of the var of each pairs.
# If it is a Constant, the other must also be a Constant as we merge them.
if all([isinstance(old, Constant) for old, new in pairs]):
fgraph.replace_all(pairs, "MergeOptimizer")
fgraph.replace_all(pairs, reason="MergeOptimizer")
else:
fgraph.replace_all_validate(pairs, "MergeOptimizer")
fgraph.replace_all_validate(pairs, reason="MergeOptimizer")
except InconsistencyError:
success = False
nb_fail += 1
......
......@@ -555,10 +555,12 @@ class ReplaceValidate(History, Validator):
del fgraph.replace_all_validate
del fgraph.replace_all_validate_remove
def replace_validate(self, fgraph, r, new_r, reason=None):
self.replace_all_validate(fgraph, [(r, new_r)], reason=reason)
def replace_validate(self, fgraph, r, new_r, reason=None, **kwargs):
self.replace_all_validate(fgraph, [(r, new_r)], reason=reason, **kwargs)
def replace_all_validate(self, fgraph, replacements, reason=None, verbose=None):
def replace_all_validate(
self, fgraph, replacements, reason=None, verbose=None, **kwargs
):
chk = fgraph.checkpoint()
if verbose is None:
verbose = config.optimizer_verbose
......@@ -569,7 +571,7 @@ class ReplaceValidate(History, Validator):
for r, new_r in replacements:
try:
fgraph.replace(r, new_r, reason=reason, verbose=False)
fgraph.replace(r, new_r, reason=reason, verbose=False, **kwargs)
except Exception as e:
msg = str(e)
s1 = "The type of the replacement must be the same"
......@@ -630,14 +632,14 @@ class ReplaceValidate(History, Validator):
return chk
def replace_all_validate_remove(
self, fgraph, replacements, remove, reason=None, warn=True
self, fgraph, replacements, remove, reason=None, warn=True, **kwargs
):
"""
As replace_all_validate, revert the replacement if the ops
in the list remove are still in the graph. Also print a warning.
"""
chk = fgraph.replace_all_validate(replacements, reason)
chk = fgraph.replace_all_validate(replacements, reason=reason, **kwargs)
self._nodes_removed.update(remove)
for rm in remove:
if rm in fgraph.apply_nodes or rm in fgraph.variables:
......
......@@ -111,20 +111,26 @@ class TestFunctionGraph:
var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
var5 = MyVariable("var5")
var6 = op2(var5)
var8 = MyVariable("var8")
var6 = op2(var8)
with pytest.raises(MissingInputError):
fg.import_node(var6.owner)
var6 = op2(var2)
assert not hasattr(var6.owner.tag, "imported_by")
fg.import_node(var6.owner)
assert var8 not in fg.variables
assert hasattr(var6.owner.tag, "imported_by")
assert var6 in fg.variables
fg.import_node(var6.owner, import_missing=True)
assert var8 in fg.inputs
assert var6.owner in fg.apply_nodes
assert (var6.owner, 0) in fg.get_clients(var2)
var7 = op2(var2)
assert not hasattr(var7.owner.tag, "imported_by")
fg.import_node(var7.owner)
assert hasattr(var7.owner.tag, "imported_by")
assert var7 in fg.variables
assert var7.owner in fg.apply_nodes
assert (var7.owner, 0) in fg.get_clients(var2)
def test_import_var(self):
......@@ -135,12 +141,17 @@ class TestFunctionGraph:
var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
var0 = MyVariable("var0")
with pytest.raises(MissingInputError):
var0 = MyVariable("var0")
# We can't import a new `FunctionGraph` input (i.e. something
# without an owner)
# without an owner), at least not without setting `import_missing`
fg.import_var(var0, "testing")
fg.import_var(var0, import_missing=True)
assert var0 in fg.inputs
var5 = op2()
# We can import variables with owners
fg.import_var(var5, "testing")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论