提交 55ca059a authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Change Optimizer to GlobalOptimizer

上级 48ac71e3
......@@ -57,7 +57,7 @@ Global optimization
A global optimization (or optimizer) is an object which defines the following
methods:
.. class:: Optimizer
.. class:: GlobalOptimizer
.. method:: apply(fgraph)
......@@ -75,7 +75,7 @@ methods:
This is the interface function called by Theano.
*Default:* this is defined by Optimizer as ``add_requirement(fgraph);
*Default:* this is defined by GlobalOptimizer as ``add_requirement(fgraph);
apply(fgraph)``.
See the section about :class:`FunctionGraph` to understand how to define these
......@@ -125,7 +125,7 @@ simplification described above:
from theano import gof
from theano.gof import toolbox
class Simplify(gof.Optimizer):
class Simplify(gof.GlobalOptimizer):
def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate())
def apply(self, fgraph):
......@@ -471,7 +471,7 @@ Here are a few examples of how to use a Query on optdb to produce an
Optimizer:
.. testcode::
from theano.gof import Query
from theano.compile import optdb
......
......@@ -5,7 +5,7 @@ from theano.gof.optdb import DB, opt
class TestDB:
def test_name_clashes(self):
class Opt(opt.Optimizer): # inheritance buys __hash__
class Opt(opt.GlobalOptimizer): # inheritance buys __hash__
name = "blah"
db = DB()
......
......@@ -93,13 +93,13 @@ predefined_optimizers = {
def register_optimizer(name, opt):
"""Add a `Optimizer` which can be referred to by `name` in `Mode`."""
"""Add a `GlobalOptimizer` which can be referred to by `name` in `Mode`."""
if name in predefined_optimizers:
raise ValueError(f"Optimizer name already taken: {name}")
predefined_optimizers[name] = opt
class AddDestroyHandler(gof.Optimizer):
class AddDestroyHandler(gof.GlobalOptimizer):
"""
This optimizer performs two important functions:
......@@ -134,7 +134,7 @@ class AddDestroyHandler(gof.Optimizer):
fgraph.attach_feature(gof.DestroyHandler())
class AddFeatureOptimizer(gof.Optimizer):
class AddFeatureOptimizer(gof.GlobalOptimizer):
"""
This optimizer adds a provided feature to the function graph.
"""
......@@ -147,7 +147,7 @@ class AddFeatureOptimizer(gof.Optimizer):
fgraph.attach_feature(self.feature)
class PrintCurrentFunctionGraph(gof.Optimizer):
class PrintCurrentFunctionGraph(gof.GlobalOptimizer):
"""
This optimizer is for debugging.
......
......@@ -24,6 +24,7 @@ from theano.gof.op import (
from theano.gof.opt import (
CheckStackTraceOptimization,
EquilibriumOptimizer,
GlobalOptimizer,
LocalOptGroup,
LocalOptimizer,
MergeOptimizer,
......@@ -31,7 +32,6 @@ from theano.gof.opt import (
OpKeyOptimizer,
OpRemove,
OpSub,
Optimizer,
PatternSub,
SeqOptimizer,
TopoOptimizer,
......
......@@ -43,10 +43,10 @@ class LocalMetaOptimizerSkipAssertionError(AssertionError):
"""
class Optimizer:
class GlobalOptimizer:
"""
An L{Optimizer} can be applied to an L{FunctionGraph} to transform it.
A L{GlobalOptimizer} can be applied to an L{FunctionGraph} to transform it.
It can represent an optimization or in general any kind
of transformation you could apply to an L{FunctionGraph}.
......@@ -73,7 +73,7 @@ class Optimizer:
Applies the optimization to the provided L{FunctionGraph}. It may
use all the methods defined by the L{FunctionGraph}. If the
L{Optimizer} needs to use a certain tool, such as an
L{GlobalOptimizer} needs to use a certain tool, such as an
L{InstanceFinder}, it can do so in its L{add_requirements} method.
"""
......@@ -125,11 +125,8 @@ class Optimizer:
)
class FromFunctionOptimizer(Optimizer):
"""
WRITEME
"""
class FromFunctionOptimizer(GlobalOptimizer):
"""A `GlobalOptimizer` constructed from a given function."""
def __init__(self, fn, requirements=()):
self.apply = fn
......@@ -171,14 +168,8 @@ def inplace_optimizer(f):
return rval
class SeqOptimizer(Optimizer, list):
# inherit from Optimizer first to get Optimizer.__hash__
"""
Takes a list of L{Optimizer} instances and applies them
sequentially.
"""
class SeqOptimizer(GlobalOptimizer, list):
"""A `GlobalOptimizer` that applies a list of optimizers sequentially."""
@staticmethod
def warn(exc, self, optimizer):
......@@ -214,7 +205,7 @@ class SeqOptimizer(Optimizer, list):
def apply(self, fgraph):
"""
Applies each L{Optimizer} in self in turn.
Applies each L{GlobalOptimizer} in self in turn.
"""
l = []
......@@ -823,7 +814,7 @@ class MergeFeature:
return new_inputs
class MergeOptimizer(Optimizer):
class MergeOptimizer(GlobalOptimizer):
"""
Merges parts of the graph that are identical and redundant.
......@@ -1945,7 +1936,7 @@ class Updater:
self.chin = None
class NavigatorOptimizer(Optimizer):
class NavigatorOptimizer(GlobalOptimizer):
"""
Abstract class.
......@@ -2835,7 +2826,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
+ list(opt.final_optimizers)
+ list(opt.cleanup_optimizers)
)
if o.print_profile.__code__ is not Optimizer.print_profile.__code__
if o.print_profile.__code__ is not GlobalOptimizer.print_profile.__code__
]
if not gf_opts:
return
......@@ -3310,7 +3301,7 @@ class CheckStrackTraceFeature:
)
class CheckStackTraceOptimization(Optimizer):
class CheckStackTraceOptimization(GlobalOptimizer):
"""Optimizer that serves to add CheckStackTraceOptimization as an fgraph feature."""
def add_requirements(self, fgraph):
......
......@@ -43,10 +43,10 @@ class DB:
tags specified will enable that optimization.
"""
# N.B. obj is not an instance of class Optimizer.
# N.B. obj is not an instance of class `GlobalOptimizer`.
# It is an instance of a DB.In the tests for example,
# this is not always the case.
if not isinstance(obj, (DB, opt.Optimizer, opt.LocalOptimizer)):
if not isinstance(obj, (DB, opt.GlobalOptimizer, opt.LocalOptimizer)):
raise TypeError("Object cannot be registered in OptDB", obj)
if name in self.__db__:
raise ValueError(
......@@ -285,8 +285,8 @@ class EquilibriumDB(DB):
Notes
-----
We can put LocalOptimizer and Optimizer as EquilibriumOptimizer
suppor both.
We can use `LocalOptimizer` and `GlobalOptimizer` since `EquilibriumOptimizer`
supports both.
It is probably not a good idea to have ignore_newtrees=False and
tracks_on_change_inputs=True
......@@ -473,7 +473,7 @@ class LocalGroupDB(DB):
class TopoDB(DB):
"""
Generate a Global Optimizer of type TopoOptimizer.
Generate a `GlobalOptimizer` of type TopoOptimizer.
"""
......
import theano
from theano.compile import optdb
from theano.compile.ops import shape_i_op
from theano.gof.opt import Optimizer, inherit_stack_trace, local_optimizer
from theano.gof.opt import GlobalOptimizer, inherit_stack_trace, local_optimizer
from theano.gpuarray.basic_ops import (
GpuAllocEmpty,
GpuArrayType,
......@@ -817,7 +817,7 @@ def local_dnn_argmax(op, ctx_name, inputs, outputs):
return [as_gpuarray_variable(arg.astype("int64"), ctx_name)]
class NoCuDNNRaise(Optimizer):
class NoCuDNNRaise(GlobalOptimizer):
def apply(self, fgraph):
"""
Raise a error if cudnn can't be used.
......
......@@ -15,7 +15,7 @@ from theano import config, gof, scalar, tensor
from theano.breakpoint import PdbBreakpoint
from theano.compile import optdb
from theano.compile.ops import shape_i
from theano.gof import Optimizer, graph, local_optimizer, toolbox
from theano.gof import GlobalOptimizer, graph, local_optimizer, toolbox
from theano.gof.opt import LocalMetaOptimizer, copy_stack_trace, inherit_stack_trace
from theano.gpuarray.basic_ops import (
GpuAlloc,
......@@ -210,7 +210,7 @@ gpu_neg = GpuElemwise(neg)
gpu_true_div = GpuElemwise(true_div)
class InputToGpuOptimizer(Optimizer):
class InputToGpuOptimizer(GlobalOptimizer):
"""
Transfer the input to the gpu to start the rolling wave.
......@@ -260,7 +260,7 @@ gpu_seqopt.register(
)
class GraphToGPU(Optimizer):
class GraphToGPU(GlobalOptimizer):
"""
Transfer the graph as a whole to GPU instead of transferring node by node.
......
......@@ -592,7 +592,7 @@ def cond_merge_ifs_false(node):
return op(*old_ins, **dict(return_list=True))
class CondMerge(gof.Optimizer):
class CondMerge(gof.GlobalOptimizer):
""" Graph Optimizer that merges different cond ops """
def add_requirements(self, fgraph):
......
......@@ -3,7 +3,7 @@ import logging
import theano.tensor
from theano import tensor
from theano.gof import Apply, Op, local_optimizer
from theano.gof.opt import Optimizer
from theano.gof.opt import GlobalOptimizer
from theano.tensor import DimShuffle, Dot
from theano.tensor.blas import Dot22
from theano.tensor.nlinalg import (
......@@ -171,13 +171,13 @@ class HintsFeature:
# 2) we are putting things back after a failed transaction.
class HintsOptimizer(Optimizer):
class HintsOptimizer(GlobalOptimizer):
"""
Optimizer that serves to add HintsFeature as an fgraph feature.
"""
def __init__(self):
Optimizer.__init__(self)
super().__init__()
def add_requirements(self, fgraph):
fgraph.attach_feature(HintsFeature())
......
......@@ -63,7 +63,11 @@ from theano.compile import optdb
from theano.compile.function.types import deep_copy_op
from theano.gof import DestroyHandler, InconsistencyError, toolbox
from theano.gof.graph import equal_computations
from theano.gof.opt import Optimizer, pre_constant_merge, pre_greedy_local_optimizer
from theano.gof.opt import (
GlobalOptimizer,
pre_constant_merge,
pre_greedy_local_optimizer,
)
from theano.scan.op import Scan
from theano.scan.utils import (
clone,
......@@ -224,14 +228,14 @@ def remove_constants_and_unused_inputs_scan(node):
# This is a global opt for historical reason
# It should be possible to change it to a local opt.
class PushOutNonSeqScan(gof.Optimizer):
class PushOutNonSeqScan(gof.GlobalOptimizer):
"""
A global optimizer for pushing out the variables inside the scan that depend
only on non-sequences.
"""
def __init__(self):
gof.Optimizer.__init__(self)
super().__init__()
def add_requirements(self, fgraph):
fgraph.attach_feature(gof.toolbox.ReplaceValidate())
......@@ -440,14 +444,14 @@ class PushOutNonSeqScan(gof.Optimizer):
# This is a global opt for historical reason
# It should be possible to change it to a local opt.
class PushOutSeqScan(gof.Optimizer):
class PushOutSeqScan(gof.GlobalOptimizer):
"""
A global optimizer for pushing out the variables inside the
scan that depend only on constants and sequences.
"""
def __init__(self):
gof.Optimizer.__init__(self)
super().__init__()
def add_requirements(self, fgraph):
fgraph.attach_feature(gof.toolbox.ReplaceValidate())
......@@ -696,14 +700,14 @@ class PushOutSeqScan(gof.Optimizer):
return False
class PushOutScanOutput(gof.Optimizer):
class PushOutScanOutput(gof.GlobalOptimizer):
"""
This is an optimization that can push operations performed
at the end of the inner graph of scan to outside of scan.
"""
def __init__(self):
gof.Optimizer.__init__(self)
super().__init__()
def add_requirements(self, fgraph):
fgraph.attach_feature(gof.toolbox.ReplaceValidate())
......@@ -958,14 +962,14 @@ class PushOutScanOutput(gof.Optimizer):
return new_scan_node
class ScanInplaceOptimizer(Optimizer):
class ScanInplaceOptimizer(GlobalOptimizer):
"""
Graph optimizer for Scan (makes it run inplace).
"""
def __init__(self, typeInfer=None, gpua_flag=False):
Optimizer.__init__(self)
super().__init__()
self.typeInfer = typeInfer
self.gpua_flag = gpua_flag
......@@ -1124,14 +1128,14 @@ class ScanInplaceOptimizer(Optimizer):
node = self.attempt_scan_inplace(fgraph, node, [pos], alloc_ops)
class ScanSaveMem(gof.Optimizer):
class ScanSaveMem(gof.GlobalOptimizer):
"""
Graph Optimizer that reduces scan memory consumption.
"""
def __init__(self):
gof.Optimizer.__init__(self)
super().__init__()
def add_requirements(self, fgraph):
fgraph.attach_feature(gof.toolbox.ReplaceValidate())
......@@ -1680,7 +1684,7 @@ class ScanSaveMem(gof.Optimizer):
self.process_node(fgraph, node)
class ScanMerge(gof.Optimizer):
class ScanMerge(gof.GlobalOptimizer):
"""
Graph Optimizer that merges different scan ops.
......@@ -2135,14 +2139,14 @@ def scan_merge_inouts(node):
return na.outer_outputs
class PushOutDot1(gof.Optimizer):
class PushOutDot1(gof.GlobalOptimizer):
"""
Graph optimizer for Scan(makes it run inplace).
"""
def __init__(self):
Optimizer.__init__(self)
super().__init__()
def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate())
......
......@@ -147,9 +147,9 @@ from theano.compile.mode import optdb
from theano.gof import (
Apply,
EquilibriumOptimizer,
GlobalOptimizer,
InconsistencyError,
Op,
Optimizer,
ReplacementDidNotRemoveError,
SequenceDB,
local_optimizer,
......@@ -1449,11 +1449,11 @@ def _gemm_from_node2(node):
return None, t1 - t0, 0, 0
class GemmOptimizer(Optimizer):
class GemmOptimizer(GlobalOptimizer):
"""Graph optimizer for inserting Gemm operations."""
def __init__(self):
Optimizer.__init__(self)
super().__init__()
self.warned = False
def add_requirements(self, fgraph):
......
......@@ -33,7 +33,7 @@ from theano.gof import (
)
from theano.gof.op import Op
from theano.gof.opt import (
Optimizer,
GlobalOptimizer,
copy_stack_trace,
in2out,
local_optimizer,
......@@ -214,7 +214,7 @@ def broadcast_like(value, template, fgraph, dtype=None):
return rval
class InplaceElemwiseOptimizer(Optimizer):
class InplaceElemwiseOptimizer(GlobalOptimizer):
"""
We parametrise it to make it work for Elemwise and GpuElemwise op.
"""
......@@ -1664,7 +1664,7 @@ class ShapeFeature:
return True
class ShapeOptimizer(Optimizer):
class ShapeOptimizer(GlobalOptimizer):
"""Optimizer that serves to add ShapeFeature as an fgraph feature."""
def add_requirements(self, fgraph):
......@@ -1674,7 +1674,7 @@ class ShapeOptimizer(Optimizer):
pass
class UnShapeOptimizer(Optimizer):
class UnShapeOptimizer(GlobalOptimizer):
"""Optimizer remove ShapeFeature as an fgraph feature."""
def apply(self, fgraph):
......@@ -7729,11 +7729,11 @@ def elemwise_max_input_fct(node):
local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fct)
class FusionOptimizer(Optimizer):
class FusionOptimizer(GlobalOptimizer):
"""Graph optimizer for Fusion of elemwise operations."""
def __init__(self, local_optimizer):
Optimizer.__init__(self)
super().__init__()
self.optimizer = local_optimizer
def add_requirements(self, fgraph):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论