提交 df09663a authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Make FunctionGraph feature classes inherit from the Feature class

上级 d0e4dfc1
...@@ -21,8 +21,9 @@ import numpy as np ...@@ -21,8 +21,9 @@ import numpy as np
import theano import theano
from theano import config from theano import config
from theano.gof import graph, op, toolbox, unify from theano.gof import graph, op, unify
from theano.gof.fg import InconsistencyError from theano.gof.fg import InconsistencyError
from theano.gof.toolbox import Feature, NodeFinder
from theano.gof.utils import AssocList, flatten from theano.gof.utils import AssocList, flatten
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
...@@ -451,12 +452,11 @@ class SeqOptimizer(GlobalOptimizer, UserList): ...@@ -451,12 +452,11 @@ class SeqOptimizer(GlobalOptimizer, UserList):
) )
class MergeFeature: class MergeFeature(Feature):
""" """Keeps track of variables in a `FunctionGraph` that cannot be merged together.
Keeps track of variables in fgraph that cannot be merged together.
That way, the MergeOptimizer can remember the result of the last merge That way, the `MergeOptimizer` can remember the result of the last
pass on the fgraph. merge-pass on the `FunctionGraph`.
""" """
...@@ -1815,14 +1815,7 @@ class PatternSub(LocalOptimizer): ...@@ -1815,14 +1815,7 @@ class PatternSub(LocalOptimizer):
) )
################## class Updater(Feature):
# Navigators #
##################
# Use the following classes to apply LocalOptimizers
class Updater:
def __init__(self, importer, pruner, chin, name=None): def __init__(self, importer, pruner, chin, name=None):
self.importer = importer self.importer = importer
self.pruner = pruner self.pruner = pruner
...@@ -2070,7 +2063,7 @@ class NavigatorOptimizer(GlobalOptimizer): ...@@ -2070,7 +2063,7 @@ class NavigatorOptimizer(GlobalOptimizer):
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
super().add_requirements(fgraph) super().add_requirements(fgraph)
# Added by default # Added by default
# fgraph.attach_feature(toolbox.ReplaceValidate()) # fgraph.attach_feature(ReplaceValidate())
if self.local_opt: if self.local_opt:
self.local_opt.add_requirements(fgraph) self.local_opt.add_requirements(fgraph)
...@@ -2284,10 +2277,10 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -2284,10 +2277,10 @@ class OpKeyOptimizer(NavigatorOptimizer):
""" """
super().add_requirements(fgraph) super().add_requirements(fgraph)
fgraph.attach_feature(toolbox.NodeFinder()) fgraph.attach_feature(NodeFinder())
class ChangeTracker: class ChangeTracker(Feature):
def __init__(self): def __init__(self):
self.changed = False self.changed = False
self.nb_imported = 0 self.nb_imported = 0
...@@ -3197,7 +3190,7 @@ def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"): ...@@ -3197,7 +3190,7 @@ def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"):
return True return True
class CheckStackTraceFeature: class CheckStackTraceFeature(Feature):
def on_import(self, fgraph, node, reason): def on_import(self, fgraph, node, reason):
# In optdb we only register the CheckStackTraceOptimization when # In optdb we only register the CheckStackTraceOptimization when
# theano.config.check_stack_trace is not off but we also double check here. # theano.config.check_stack_trace is not off but we also double check here.
......
...@@ -982,7 +982,7 @@ class MakeVectorPrinter: ...@@ -982,7 +982,7 @@ class MakeVectorPrinter:
tt.pprint.assign(MakeVector, MakeVectorPrinter()) tt.pprint.assign(MakeVector, MakeVectorPrinter())
class ShapeFeature: class ShapeFeature(toolbox.Feature):
"""Graph optimizer for removing all calls to shape(). """Graph optimizer for removing all calls to shape().
This optimizer replaces all Shapes and Subtensors of Shapes with This optimizer replaces all Shapes and Subtensors of Shapes with
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论