提交 df09663a authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Make FunctionGraph feature classes inherit from the Feature class

上级 d0e4dfc1
......@@ -21,8 +21,9 @@ import numpy as np
import theano
from theano import config
from theano.gof import graph, op, toolbox, unify
from theano.gof import graph, op, unify
from theano.gof.fg import InconsistencyError
from theano.gof.toolbox import Feature, NodeFinder
from theano.gof.utils import AssocList, flatten
from theano.misc.ordered_set import OrderedSet
......@@ -451,12 +452,11 @@ class SeqOptimizer(GlobalOptimizer, UserList):
)
class MergeFeature:
"""
Keeps track of variables in fgraph that cannot be merged together.
class MergeFeature(Feature):
"""Keeps track of variables in a `FunctionGraph` that cannot be merged together.
That way, the MergeOptimizer can remember the result of the last merge
pass on the fgraph.
That way, the `MergeOptimizer` can remember the result of the last
merge-pass on the `FunctionGraph`.
"""
......@@ -1815,14 +1815,7 @@ class PatternSub(LocalOptimizer):
)
##################
# Navigators #
##################
# Use the following classes to apply LocalOptimizers
class Updater:
class Updater(Feature):
def __init__(self, importer, pruner, chin, name=None):
self.importer = importer
self.pruner = pruner
......@@ -2070,7 +2063,7 @@ class NavigatorOptimizer(GlobalOptimizer):
def add_requirements(self, fgraph):
super().add_requirements(fgraph)
# Added by default
# fgraph.attach_feature(toolbox.ReplaceValidate())
# fgraph.attach_feature(ReplaceValidate())
if self.local_opt:
self.local_opt.add_requirements(fgraph)
......@@ -2284,10 +2277,10 @@ class OpKeyOptimizer(NavigatorOptimizer):
"""
super().add_requirements(fgraph)
fgraph.attach_feature(toolbox.NodeFinder())
fgraph.attach_feature(NodeFinder())
class ChangeTracker:
class ChangeTracker(Feature):
def __init__(self):
self.changed = False
self.nb_imported = 0
......@@ -3197,7 +3190,7 @@ def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"):
return True
class CheckStackTraceFeature:
class CheckStackTraceFeature(Feature):
def on_import(self, fgraph, node, reason):
# In optdb we only register the CheckStackTraceOptimization when
# theano.config.check_stack_trace is not off but we also double check here.
......
......@@ -982,7 +982,7 @@ class MakeVectorPrinter:
tt.pprint.assign(MakeVector, MakeVectorPrinter())
class ShapeFeature:
class ShapeFeature(toolbox.Feature):
"""Graph optimizer for removing all calls to shape().
This optimizer replaces all Shapes and Subtensors of Shapes with
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论