提交 82d158a2 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

clone_with_equiv, changes to Apply.clone_with_new_inputs

上级 2c813862
...@@ -4,7 +4,7 @@ from env import InconsistencyError, Env ...@@ -4,7 +4,7 @@ from env import InconsistencyError, Env
from ext import DestroyHandler, view_roots from ext import DestroyHandler, view_roots
from graph import Apply, Result, Constant, Value from graph import Apply, Result, Constant, Value
from link import Linker, LocalLinker, PerformLinker, MetaLinker, Profiler from link import Linker, LocalLinker, PerformLinker, MetaLinker, Profiler
from op import Op from op import Op, Macro
from opt import Optimizer, DummyOpt, SeqOptimizer, LocalOptimizer, OpSpecificOptimizer, OpSubOptimizer, OpRemover, PatternOptimizer, MergeOptimizer, MergeOptMerge from opt import Optimizer, DummyOpt, SeqOptimizer, LocalOptimizer, OpSpecificOptimizer, OpSubOptimizer, OpRemover, PatternOptimizer, MergeOptimizer, MergeOptMerge
from toolbox import Bookkeeper, History, Validator, ReplaceValidate, NodeFinder, PrintListener from toolbox import Bookkeeper, History, Validator, ReplaceValidate, NodeFinder, PrintListener
from type import Type, Generic, generic from type import Type, Generic, generic
......
...@@ -76,21 +76,36 @@ class Apply(utils.object2): ...@@ -76,21 +76,36 @@ class Apply(utils.object2):
cp = self.__class__(self.op, self.inputs, [output.clone() for output in self.outputs]) cp = self.__class__(self.op, self.inputs, [output.clone() for output in self.outputs])
cp.tag = copy(self.tag) cp.tag = copy(self.tag)
return cp return cp
def clone_with_new_inputs(self, inputs, check_type = True): def clone_with_new_inputs(self, inputs, strict = True):
""" """
Returns an Apply node with the same op but different inputs. Unless Returns an Apply node with the same op but different inputs. Unless
check_type is False, the type fields of all the inputs must be strict is False, the type fields of all the inputs must be
equal to the current ones. equal to the current ones.
The outputs of the clone will have the same type as the outputs of If strict is True, the outputs of the clone will have the same type as
self. the outputs of self. Else, it depends on the types of the new inputs
and the behavior of the op wrt that.
""" """
if check_type: # if check_type:
for curr, new in zip(self.inputs, inputs): # for curr, new in zip(self.inputs, inputs):
if not curr.type == new.type: # if not curr.type == new.type:
# raise TypeError("Cannot change the type of this input.", curr, new)
# new_node = self.clone()
# new_node.inputs = inputs
# return new_node
remake_node = False
for curr, new in zip(self.inputs, inputs):
if not curr.type == new.type:
if strict:
raise TypeError("Cannot change the type of this input.", curr, new) raise TypeError("Cannot change the type of this input.", curr, new)
new_node = self.clone() else:
new_node.inputs = inputs remake_node = True
if remake_node:
new_node = self.op.make_node(*inputs)
new_node.tag = copy(self.tag).__update__(new_node.tag)
else:
new_node = self.clone()
new_node.inputs = inputs
return new_node return new_node
nin = property(lambda self: len(self.inputs), doc = 'same as len(self.inputs)') nin = property(lambda self: len(self.inputs), doc = 'same as len(self.inputs)')
...@@ -367,6 +382,94 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = True): ...@@ -367,6 +382,94 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = True):
return d return d
# def clone_with_new_inputs(i, o, new_i):
# equiv = clone_with_new_inputs_get_equiv(i, o, new_i)
# return [equiv[input] for input in i], [equiv[output] for output in o]
# def clone_with_new_inputs_get_equiv(i, o, new_i, copy_orphans = True):
# # note: this does not exactly mirror Apply.clone_with_new_inputs
# # here it is possible to give different types to new_i and then
# # make_node is called on the ops instead of clone_with_new_inputs
# # whenever the type is different.
# d = {}
# for input, new_input in zip(i, new_i):
# d[input] = new_input
# def clone_helper(result):
# if result in d:
# return d[result]
# node = result.owner
# if node is None: # result is an orphan
# if copy_orphans:
# cpy = result.clone()
# d[result] = cpy
# else:
# d[result] = result
# return d[result]
# else:
# cloned_inputs = [clone_helper(input) for input in node.inputs]
# if any(input != cloned_input for input, cloned_input in zip(node.inputs, cloned_inputs)):
# new_node = node.op.make_node(*cloned_inputs)
# else:
# new_node = node.clone_with_new_inputs(cloned_inputs)
# d[node] = new_node
# for output, new_output in zip(node.outputs, new_node.outputs):
# d[output] = new_output
# return d[result]
# for output in o:
# clone_helper(output)
# return d
def clone_with_equiv(i, o, d, missing_input_policy = 'fail', orphan_policy = 'copy'):
def clone_helper(result):
if result in d:
return d[result]
node = result.owner
if node is None: # result is an input or an orphan not in d
if isinstance(result, Value):
if orphan_policy == 'copy':
d[result] = copy(result)
elif orphan_policy == 'keep':
d[result] = result
else:
raise ValueError("unknown orphan_policy: '%s'" % orphan_policy)
else:
if missing_input_policy == 'fail':
raise ValueError("missing input: %s" % result)
elif missing_input_policy == 'keep':
d[result] = result
else:
raise ValueError("unknown missing_input_policy: '%s'" % missing_input_policy)
return d[result]
else:
cloned_inputs = [clone_helper(input) for input in node.inputs]
if all(input is cloned_input for input, cloned_input in zip(node.inputs, cloned_inputs)):
new_node = node
else:
new_node = node.clone_with_new_inputs(cloned_inputs, strict = False)
# if any(input != cloned_input for input, cloned_input in zip(node.inputs, cloned_inputs)):
# new_node = node.op.make_node(*cloned_inputs)
# else:
# new_node = node.clone_with_new_inputs(cloned_inputs)
d[node] = new_node
for output, new_output in zip(node.outputs, new_node.outputs):
d[output] = new_output
return d[result]
for output in o:
clone_helper(output)
return [d[input] for input in i], [d[output] for output in o]
def general_toposort(r_out, deps): def general_toposort(r_out, deps):
""" """
@note: deps(i) should behave like a pure function (no funny business with @note: deps(i) should behave like a pure function (no funny business with
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论