提交 37444647 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add an option to add missing inputs during FunctionGraph operations

上级 334d3fdf
差异被折叠。
...@@ -856,9 +856,9 @@ class MergeOptimizer(GlobalOptimizer): ...@@ -856,9 +856,9 @@ class MergeOptimizer(GlobalOptimizer):
# Only need to check one of the var of each pairs. # Only need to check one of the var of each pairs.
# If it is a Constant, the other must also be a Constant as we merge them. # If it is a Constant, the other must also be a Constant as we merge them.
if all([isinstance(old, Constant) for old, new in pairs]): if all([isinstance(old, Constant) for old, new in pairs]):
fgraph.replace_all(pairs, "MergeOptimizer") fgraph.replace_all(pairs, reason="MergeOptimizer")
else: else:
fgraph.replace_all_validate(pairs, "MergeOptimizer") fgraph.replace_all_validate(pairs, reason="MergeOptimizer")
except InconsistencyError: except InconsistencyError:
success = False success = False
nb_fail += 1 nb_fail += 1
......
...@@ -555,10 +555,12 @@ class ReplaceValidate(History, Validator): ...@@ -555,10 +555,12 @@ class ReplaceValidate(History, Validator):
del fgraph.replace_all_validate del fgraph.replace_all_validate
del fgraph.replace_all_validate_remove del fgraph.replace_all_validate_remove
def replace_validate(self, fgraph, r, new_r, reason=None): def replace_validate(self, fgraph, r, new_r, reason=None, **kwargs):
self.replace_all_validate(fgraph, [(r, new_r)], reason=reason) self.replace_all_validate(fgraph, [(r, new_r)], reason=reason, **kwargs)
def replace_all_validate(self, fgraph, replacements, reason=None, verbose=None): def replace_all_validate(
self, fgraph, replacements, reason=None, verbose=None, **kwargs
):
chk = fgraph.checkpoint() chk = fgraph.checkpoint()
if verbose is None: if verbose is None:
verbose = config.optimizer_verbose verbose = config.optimizer_verbose
...@@ -569,7 +571,7 @@ class ReplaceValidate(History, Validator): ...@@ -569,7 +571,7 @@ class ReplaceValidate(History, Validator):
for r, new_r in replacements: for r, new_r in replacements:
try: try:
fgraph.replace(r, new_r, reason=reason, verbose=False) fgraph.replace(r, new_r, reason=reason, verbose=False, **kwargs)
except Exception as e: except Exception as e:
msg = str(e) msg = str(e)
s1 = "The type of the replacement must be the same" s1 = "The type of the replacement must be the same"
...@@ -630,14 +632,14 @@ class ReplaceValidate(History, Validator): ...@@ -630,14 +632,14 @@ class ReplaceValidate(History, Validator):
return chk return chk
def replace_all_validate_remove( def replace_all_validate_remove(
self, fgraph, replacements, remove, reason=None, warn=True self, fgraph, replacements, remove, reason=None, warn=True, **kwargs
): ):
""" """
As replace_all_validate, revert the replacement if the ops As replace_all_validate, revert the replacement if the ops
in the list remove are still in the graph. Also print a warning. in the list remove are still in the graph. Also print a warning.
""" """
chk = fgraph.replace_all_validate(replacements, reason) chk = fgraph.replace_all_validate(replacements, reason=reason, **kwargs)
self._nodes_removed.update(remove) self._nodes_removed.update(remove)
for rm in remove: for rm in remove:
if rm in fgraph.apply_nodes or rm in fgraph.variables: if rm in fgraph.apply_nodes or rm in fgraph.variables:
......
...@@ -111,20 +111,26 @@ class TestFunctionGraph: ...@@ -111,20 +111,26 @@ class TestFunctionGraph:
var5 = op3(var4, var2, var2) var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False) fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
var5 = MyVariable("var5") var8 = MyVariable("var8")
var6 = op2(var5) var6 = op2(var8)
with pytest.raises(MissingInputError): with pytest.raises(MissingInputError):
fg.import_node(var6.owner) fg.import_node(var6.owner)
var6 = op2(var2) assert var8 not in fg.variables
assert not hasattr(var6.owner.tag, "imported_by")
fg.import_node(var6.owner)
assert hasattr(var6.owner.tag, "imported_by") fg.import_node(var6.owner, import_missing=True)
assert var6 in fg.variables assert var8 in fg.inputs
assert var6.owner in fg.apply_nodes assert var6.owner in fg.apply_nodes
assert (var6.owner, 0) in fg.get_clients(var2)
var7 = op2(var2)
assert not hasattr(var7.owner.tag, "imported_by")
fg.import_node(var7.owner)
assert hasattr(var7.owner.tag, "imported_by")
assert var7 in fg.variables
assert var7.owner in fg.apply_nodes
assert (var7.owner, 0) in fg.get_clients(var2)
def test_import_var(self): def test_import_var(self):
...@@ -135,12 +141,17 @@ class TestFunctionGraph: ...@@ -135,12 +141,17 @@ class TestFunctionGraph:
var5 = op3(var4, var2, var2) var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False) fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
var0 = MyVariable("var0")
with pytest.raises(MissingInputError): with pytest.raises(MissingInputError):
var0 = MyVariable("var0")
# We can't import a new `FunctionGraph` input (i.e. something # We can't import a new `FunctionGraph` input (i.e. something
# without an owner) # without an owner), at least not without setting `import_missing`
fg.import_var(var0, "testing") fg.import_var(var0, "testing")
fg.import_var(var0, import_missing=True)
assert var0 in fg.inputs
var5 = op2() var5 = op2()
# We can import variables with owners # We can import variables with owners
fg.import_var(var5, "testing") fg.import_var(var5, "testing")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论