提交 ae2c7bdb authored 作者: Frederic's avatar Frederic

local_merge_optimizer now register correctly the MergeFeature

上级 73f7b1ce
...@@ -1165,7 +1165,7 @@ class FromFunctionLocalOptimizer(LocalOptimizer): ...@@ -1165,7 +1165,7 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
id(self)), file=stream) id(self)), file=stream)
def local_optimizer(tracks, inplace=False): def local_optimizer(tracks, inplace=False, requirements=()):
def decorator(f): def decorator(f):
""" """
WRITEME WRITEME
...@@ -1177,12 +1177,13 @@ def local_optimizer(tracks, inplace=False): ...@@ -1177,12 +1177,13 @@ def local_optimizer(tracks, inplace=False):
for t in tracks: for t in tracks:
if not (isinstance(t, op.Op) or issubclass(t, op.PureOp)): if not (isinstance(t, op.Op) or issubclass(t, op.PureOp)):
raise ValueError("Tracks are op classes or instances", f.__module__, f.__name__) raise ValueError("Tracks are op classes or instances", f.__module__, f.__name__)
requirements = () req = requirements
if inplace: if inplace:
dh_handler = dh.DestroyHandler dh_handler = dh.DestroyHandler
requirements = (lambda fgraph: req = tuple(requirements) + (
lambda fgraph:
fgraph.attach_feature(dh_handler()),) fgraph.attach_feature(dh_handler()),)
rval = FromFunctionLocalOptimizer(f, tracks, requirements) rval = FromFunctionLocalOptimizer(f, tracks, req)
rval.__name__ = f.__name__ rval.__name__ = f.__name__
return rval return rval
return decorator return decorator
......
...@@ -514,8 +514,13 @@ def register_specialize_device(lopt, *tags, **kwargs): ...@@ -514,8 +514,13 @@ def register_specialize_device(lopt, *tags, **kwargs):
# 2) after an local optimization being applied, if the # 2) after an local optimization being applied, if the
# current node is still in the graph, it will continue to the next # current node is still in the graph, it will continue to the next
# local optimizer. So this won't trigger more iteration. # local optimizer. So this won't trigger more iteration.
def add_merge_feature(fgraph):
if not hasattr(fgraph, 'merge_feature'):
fgraph.attach_feature(theano.gof.opt.MergeFeature())
@register_canonicalize('fast_compile', 'merge') @register_canonicalize('fast_compile', 'merge')
@gof.local_optimizer(None) @gof.local_optimizer(None, requirements=[add_merge_feature])
def local_merge_optimizer(node): def local_merge_optimizer(node):
if node.fgraph.merge_feature.scheduled: if node.fgraph.merge_feature.scheduled:
ret = merge_optimizer(node.fgraph) ret = merge_optimizer(node.fgraph)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论