提交 157ac1e2 authored 作者: sentient07's avatar sentient07

Updated the opts computed for the new node, plus a few more fixes

上级 79716b66
...@@ -150,15 +150,15 @@ optdb = gof.SequenceDB() ...@@ -150,15 +150,15 @@ optdb = gof.SequenceDB()
optdb.register('merge1', gof.MergeOptimizer(), optdb.register('merge1', gof.MergeOptimizer(),
0, 'fast_run', 'fast_compile', 'merge') 0, 'fast_run', 'fast_compile', 'merge')
local_useless = gof.optdb.LocalGroupDB(apply_all_opts=True)
optdb.register('useless', gof.optdb.TopoDB(local_useless),
0.6, 'fast_run', 'fast_compile')
# After scan1 opt at 0.5 and before ShapeOpt at 1 # After scan1 opt at 0.5 and before ShapeOpt at 1
# This should only remove nodes. # This should only remove nodes.
# The opt should not do anything that needs shape inference. # The opt should not do anything that needs shape inference.
# A new node without infer_shape requires that the original node # A new node without infer_shape requires that the original node
# also lacks infer_shape # also lacks infer_shape
local_useless = gof.optdb.LocalGroupDB(apply_all_opts=True)
optdb.register('useless', gof.optdb.TopoDB(local_useless),
0.6, 'fast_run', 'fast_compile')
optdb.register('merge1.1', gof.MergeOptimizer(), optdb.register('merge1.1', gof.MergeOptimizer(),
0.65, 'fast_run', 'fast_compile', 'merge') 0.65, 'fast_run', 'fast_compile', 'merge')
......
...@@ -1249,8 +1249,9 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1249,8 +1249,9 @@ class LocalOptGroup(LocalOptimizer):
self.retains_inputs = all(getattr(opt, 'retains_inputs', False) self.retains_inputs = all(getattr(opt, 'retains_inputs', False)
for opt in optimizers) for opt in optimizers)
self.apply_all_opts = kwargs.get('apply_all_opts', False) self.apply_all_opts = kwargs.pop('apply_all_opts', False)
self.track_map = OrderedDict() self.track_map = OrderedDict()
assert len(kwargs) == 0
for o in self.opts: for o in self.opts:
for c in o.tracks(): for c in o.tracks():
...@@ -1273,6 +1274,12 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1273,6 +1274,12 @@ class LocalOptGroup(LocalOptimizer):
if len(self.opts) == 0: if len(self.opts) == 0:
return return
def compute_opts(node):
    """Return the optimizers registered in ``self.track_map`` for `node`.

    The map is consulted under three keys, in this order: the class of
    ``node.op``, the op instance itself, and ``None``.  The lists found
    are concatenated in that order; a missing key contributes nothing.

    A fresh list is built on every call.  Using ``+=`` on the list
    returned by ``track_map.get`` would extend the list stored inside
    ``track_map`` in place, so the registered optimizers for
    ``type(node.op)`` would accumulate duplicates on each lookup.
    """
    track_map = self.track_map
    return (list(track_map.get(type(node.op), ())) +
            list(track_map.get(node.op, ())) +
            list(track_map.get(None, ())))
def apply_mult_opts(opt_list, node, multiple_opts=False): def apply_mult_opts(opt_list, node, multiple_opts=False):
repl = False repl = False
for opt in opt_list: for opt in opt_list:
...@@ -1286,12 +1293,11 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1286,12 +1293,11 @@ class LocalOptGroup(LocalOptimizer):
# Ensuring not the input of graph # Ensuring not the input of graph
assert repl[0].owner assert repl[0].owner
new_node = repl[0].owner new_node = repl[0].owner
apply_mult_opts(opt_list, new_node, True) new_opts = compute_opts(new_node)
apply_mult_opts(new_opts, new_node, True)
return repl return repl
opts = self.track_map.get(type(node.op), [])
opts += self.track_map.get(node.op, []) return apply_mult_opts(compute_opts(node), node, self.apply_all_opts)
opts += self.track_map.get(None, [])
return apply_mult_opts(opts, node, self.apply_all_opts)
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s id=%i" % ( print("%s%s id=%i" % (
......
...@@ -417,10 +417,7 @@ class LocalGroupDB(DB): ...@@ -417,10 +417,7 @@ class LocalGroupDB(DB):
class TopoDB(DB): class TopoDB(DB):
""" """
Generate a local optimizer of type LocalOptGroup instead Generate a local optimizer of type TopoOptimizer.
of a global optimizer.
It supports the tracks, to only get applied to some Op.
""" """
......
...@@ -1749,6 +1749,9 @@ def local_useless_fill(node): ...@@ -1749,6 +1749,9 @@ def local_useless_fill(node):
return [v] return [v]
@register_specialize
@register_stabilize
@register_canonicalize
@register_useless @register_useless
@gof.local_optimizer([T.alloc]) @gof.local_optimizer([T.alloc])
def local_useless_alloc(node): def local_useless_alloc(node):
...@@ -1929,7 +1932,7 @@ def local_subtensor_remove_broadcastable_index(node): ...@@ -1929,7 +1932,7 @@ def local_subtensor_remove_broadcastable_index(node):
@register_specialize @register_specialize
@register_canonicalize('fast_compile_gpu') @register_canonicalize('fast_compile_gpu')
# @register_useless @register_useless
@gof.local_optimizer([Subtensor, AdvancedSubtensor1]) @gof.local_optimizer([Subtensor, AdvancedSubtensor1])
def local_subtensor_make_vector(node): def local_subtensor_make_vector(node):
""" """
...@@ -4847,10 +4850,6 @@ class Canonizer(gof.LocalOptimizer): ...@@ -4847,10 +4850,6 @@ class Canonizer(gof.LocalOptimizer):
assert len(node.outputs) == 1 assert len(node.outputs) == 1
out = node.outputs[0] out = node.outputs[0]
# Condition for replacement variable not being a part of the graph
if not hasattr(out, 'clients'):
return False
# check if any of the clients of this node would be part of # check if any of the clients of this node would be part of
# this canonized graph... if so, we do nothing and wait for # this canonized graph... if so, we do nothing and wait for
# them to be transformed. # them to be transformed.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论