提交 55ca059a authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Change Optimizer to GlobalOptimizer

上级 48ac71e3
...@@ -57,7 +57,7 @@ Global optimization ...@@ -57,7 +57,7 @@ Global optimization
A global optimization (or optimizer) is an object which defines the following A global optimization (or optimizer) is an object which defines the following
methods: methods:
.. class:: Optimizer .. class:: GlobalOptimizer
.. method:: apply(fgraph) .. method:: apply(fgraph)
...@@ -75,7 +75,7 @@ methods: ...@@ -75,7 +75,7 @@ methods:
This is the interface function called by Theano. This is the interface function called by Theano.
*Default:* this is defined by Optimizer as ``add_requirement(fgraph); *Default:* this is defined by GlobalOptimizer as ``add_requirement(fgraph);
apply(fgraph)``. apply(fgraph)``.
See the section about :class:`FunctionGraph` to understand how to define these See the section about :class:`FunctionGraph` to understand how to define these
...@@ -125,7 +125,7 @@ simplification described above: ...@@ -125,7 +125,7 @@ simplification described above:
from theano import gof from theano import gof
from theano.gof import toolbox from theano.gof import toolbox
class Simplify(gof.Optimizer): class Simplify(gof.GlobalOptimizer):
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate()) fgraph.attach_feature(toolbox.ReplaceValidate())
def apply(self, fgraph): def apply(self, fgraph):
......
...@@ -5,7 +5,7 @@ from theano.gof.optdb import DB, opt ...@@ -5,7 +5,7 @@ from theano.gof.optdb import DB, opt
class TestDB: class TestDB:
def test_name_clashes(self): def test_name_clashes(self):
class Opt(opt.Optimizer): # inheritance buys __hash__ class Opt(opt.GlobalOptimizer): # inheritance buys __hash__
name = "blah" name = "blah"
db = DB() db = DB()
......
...@@ -93,13 +93,13 @@ predefined_optimizers = { ...@@ -93,13 +93,13 @@ predefined_optimizers = {
def register_optimizer(name, opt): def register_optimizer(name, opt):
"""Add a `Optimizer` which can be referred to by `name` in `Mode`.""" """Add a `GlobalOptimizer` which can be referred to by `name` in `Mode`."""
if name in predefined_optimizers: if name in predefined_optimizers:
raise ValueError(f"Optimizer name already taken: {name}") raise ValueError(f"Optimizer name already taken: {name}")
predefined_optimizers[name] = opt predefined_optimizers[name] = opt
class AddDestroyHandler(gof.Optimizer): class AddDestroyHandler(gof.GlobalOptimizer):
""" """
This optimizer performs two important functions: This optimizer performs two important functions:
...@@ -134,7 +134,7 @@ class AddDestroyHandler(gof.Optimizer): ...@@ -134,7 +134,7 @@ class AddDestroyHandler(gof.Optimizer):
fgraph.attach_feature(gof.DestroyHandler()) fgraph.attach_feature(gof.DestroyHandler())
class AddFeatureOptimizer(gof.Optimizer): class AddFeatureOptimizer(gof.GlobalOptimizer):
""" """
This optimizer adds a provided feature to the function graph. This optimizer adds a provided feature to the function graph.
""" """
...@@ -147,7 +147,7 @@ class AddFeatureOptimizer(gof.Optimizer): ...@@ -147,7 +147,7 @@ class AddFeatureOptimizer(gof.Optimizer):
fgraph.attach_feature(self.feature) fgraph.attach_feature(self.feature)
class PrintCurrentFunctionGraph(gof.Optimizer): class PrintCurrentFunctionGraph(gof.GlobalOptimizer):
""" """
This optimizer is for debugging. This optimizer is for debugging.
......
...@@ -24,6 +24,7 @@ from theano.gof.op import ( ...@@ -24,6 +24,7 @@ from theano.gof.op import (
from theano.gof.opt import ( from theano.gof.opt import (
CheckStackTraceOptimization, CheckStackTraceOptimization,
EquilibriumOptimizer, EquilibriumOptimizer,
GlobalOptimizer,
LocalOptGroup, LocalOptGroup,
LocalOptimizer, LocalOptimizer,
MergeOptimizer, MergeOptimizer,
...@@ -31,7 +32,6 @@ from theano.gof.opt import ( ...@@ -31,7 +32,6 @@ from theano.gof.opt import (
OpKeyOptimizer, OpKeyOptimizer,
OpRemove, OpRemove,
OpSub, OpSub,
Optimizer,
PatternSub, PatternSub,
SeqOptimizer, SeqOptimizer,
TopoOptimizer, TopoOptimizer,
......
...@@ -43,10 +43,10 @@ class LocalMetaOptimizerSkipAssertionError(AssertionError): ...@@ -43,10 +43,10 @@ class LocalMetaOptimizerSkipAssertionError(AssertionError):
""" """
class Optimizer: class GlobalOptimizer:
""" """
An L{Optimizer} can be applied to an L{FunctionGraph} to transform it. A L{GlobalOptimizer} can be applied to an L{FunctionGraph} to transform it.
It can represent an optimization or in general any kind It can represent an optimization or in general any kind
of transformation you could apply to an L{FunctionGraph}. of transformation you could apply to an L{FunctionGraph}.
...@@ -73,7 +73,7 @@ class Optimizer: ...@@ -73,7 +73,7 @@ class Optimizer:
Applies the optimization to the provided L{FunctionGraph}. It may Applies the optimization to the provided L{FunctionGraph}. It may
use all the methods defined by the L{FunctionGraph}. If the use all the methods defined by the L{FunctionGraph}. If the
L{Optimizer} needs to use a certain tool, such as an L{GlobalOptimizer} needs to use a certain tool, such as an
L{InstanceFinder}, it can do so in its L{add_requirements} method. L{InstanceFinder}, it can do so in its L{add_requirements} method.
""" """
...@@ -125,11 +125,8 @@ class Optimizer: ...@@ -125,11 +125,8 @@ class Optimizer:
) )
class FromFunctionOptimizer(Optimizer): class FromFunctionOptimizer(GlobalOptimizer):
""" """A `GlobalOptimizer` constructed from a given function."""
WRITEME
"""
def __init__(self, fn, requirements=()): def __init__(self, fn, requirements=()):
self.apply = fn self.apply = fn
...@@ -171,14 +168,8 @@ def inplace_optimizer(f): ...@@ -171,14 +168,8 @@ def inplace_optimizer(f):
return rval return rval
class SeqOptimizer(Optimizer, list): class SeqOptimizer(GlobalOptimizer, list):
# inherit from Optimizer first to get Optimizer.__hash__ """A `GlobalOptimizer` that applies a list of optimizers sequentially."""
"""
Takes a list of L{Optimizer} instances and applies them
sequentially.
"""
@staticmethod @staticmethod
def warn(exc, self, optimizer): def warn(exc, self, optimizer):
...@@ -214,7 +205,7 @@ class SeqOptimizer(Optimizer, list): ...@@ -214,7 +205,7 @@ class SeqOptimizer(Optimizer, list):
def apply(self, fgraph): def apply(self, fgraph):
""" """
Applies each L{Optimizer} in self in turn. Applies each L{GlobalOptimizer} in self in turn.
""" """
l = [] l = []
...@@ -823,7 +814,7 @@ class MergeFeature: ...@@ -823,7 +814,7 @@ class MergeFeature:
return new_inputs return new_inputs
class MergeOptimizer(Optimizer): class MergeOptimizer(GlobalOptimizer):
""" """
Merges parts of the graph that are identical and redundant. Merges parts of the graph that are identical and redundant.
...@@ -1945,7 +1936,7 @@ class Updater: ...@@ -1945,7 +1936,7 @@ class Updater:
self.chin = None self.chin = None
class NavigatorOptimizer(Optimizer): class NavigatorOptimizer(GlobalOptimizer):
""" """
Abstract class. Abstract class.
...@@ -2835,7 +2826,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2835,7 +2826,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
+ list(opt.final_optimizers) + list(opt.final_optimizers)
+ list(opt.cleanup_optimizers) + list(opt.cleanup_optimizers)
) )
if o.print_profile.__code__ is not Optimizer.print_profile.__code__ if o.print_profile.__code__ is not GlobalOptimizer.print_profile.__code__
] ]
if not gf_opts: if not gf_opts:
return return
...@@ -3310,7 +3301,7 @@ class CheckStrackTraceFeature: ...@@ -3310,7 +3301,7 @@ class CheckStrackTraceFeature:
) )
class CheckStackTraceOptimization(Optimizer): class CheckStackTraceOptimization(GlobalOptimizer):
"""Optimizer that serves to add CheckStackTraceOptimization as an fgraph feature.""" """Optimizer that serves to add CheckStackTraceOptimization as an fgraph feature."""
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
......
...@@ -43,10 +43,10 @@ class DB: ...@@ -43,10 +43,10 @@ class DB:
tags specified will enable that optimization. tags specified will enable that optimization.
""" """
# N.B. obj is not an instance of class Optimizer. # N.B. obj is not an instance of class `GlobalOptimizer`.
# It is an instance of a DB.In the tests for example, # It is an instance of a DB.In the tests for example,
# this is not always the case. # this is not always the case.
if not isinstance(obj, (DB, opt.Optimizer, opt.LocalOptimizer)): if not isinstance(obj, (DB, opt.GlobalOptimizer, opt.LocalOptimizer)):
raise TypeError("Object cannot be registered in OptDB", obj) raise TypeError("Object cannot be registered in OptDB", obj)
if name in self.__db__: if name in self.__db__:
raise ValueError( raise ValueError(
...@@ -285,8 +285,8 @@ class EquilibriumDB(DB): ...@@ -285,8 +285,8 @@ class EquilibriumDB(DB):
Notes Notes
----- -----
We can put LocalOptimizer and Optimizer as EquilibriumOptimizer We can use `LocalOptimizer` and `GlobalOptimizer` since `EquilibriumOptimizer`
suppor both. supports both.
It is probably not a good idea to have ignore_newtrees=False and It is probably not a good idea to have ignore_newtrees=False and
tracks_on_change_inputs=True tracks_on_change_inputs=True
...@@ -473,7 +473,7 @@ class LocalGroupDB(DB): ...@@ -473,7 +473,7 @@ class LocalGroupDB(DB):
class TopoDB(DB): class TopoDB(DB):
""" """
Generate a Global Optimizer of type TopoOptimizer. Generate a `GlobalOptimizer` of type TopoOptimizer.
""" """
......
import theano import theano
from theano.compile import optdb from theano.compile import optdb
from theano.compile.ops import shape_i_op from theano.compile.ops import shape_i_op
from theano.gof.opt import Optimizer, inherit_stack_trace, local_optimizer from theano.gof.opt import GlobalOptimizer, inherit_stack_trace, local_optimizer
from theano.gpuarray.basic_ops import ( from theano.gpuarray.basic_ops import (
GpuAllocEmpty, GpuAllocEmpty,
GpuArrayType, GpuArrayType,
...@@ -817,7 +817,7 @@ def local_dnn_argmax(op, ctx_name, inputs, outputs): ...@@ -817,7 +817,7 @@ def local_dnn_argmax(op, ctx_name, inputs, outputs):
return [as_gpuarray_variable(arg.astype("int64"), ctx_name)] return [as_gpuarray_variable(arg.astype("int64"), ctx_name)]
class NoCuDNNRaise(Optimizer): class NoCuDNNRaise(GlobalOptimizer):
def apply(self, fgraph): def apply(self, fgraph):
""" """
Raise a error if cudnn can't be used. Raise a error if cudnn can't be used.
......
...@@ -15,7 +15,7 @@ from theano import config, gof, scalar, tensor ...@@ -15,7 +15,7 @@ from theano import config, gof, scalar, tensor
from theano.breakpoint import PdbBreakpoint from theano.breakpoint import PdbBreakpoint
from theano.compile import optdb from theano.compile import optdb
from theano.compile.ops import shape_i from theano.compile.ops import shape_i
from theano.gof import Optimizer, graph, local_optimizer, toolbox from theano.gof import GlobalOptimizer, graph, local_optimizer, toolbox
from theano.gof.opt import LocalMetaOptimizer, copy_stack_trace, inherit_stack_trace from theano.gof.opt import LocalMetaOptimizer, copy_stack_trace, inherit_stack_trace
from theano.gpuarray.basic_ops import ( from theano.gpuarray.basic_ops import (
GpuAlloc, GpuAlloc,
...@@ -210,7 +210,7 @@ gpu_neg = GpuElemwise(neg) ...@@ -210,7 +210,7 @@ gpu_neg = GpuElemwise(neg)
gpu_true_div = GpuElemwise(true_div) gpu_true_div = GpuElemwise(true_div)
class InputToGpuOptimizer(Optimizer): class InputToGpuOptimizer(GlobalOptimizer):
""" """
Transfer the input to the gpu to start the rolling wave. Transfer the input to the gpu to start the rolling wave.
...@@ -260,7 +260,7 @@ gpu_seqopt.register( ...@@ -260,7 +260,7 @@ gpu_seqopt.register(
) )
class GraphToGPU(Optimizer): class GraphToGPU(GlobalOptimizer):
""" """
Transfer the graph as a whole to GPU instead of transferring node by node. Transfer the graph as a whole to GPU instead of transferring node by node.
......
...@@ -592,7 +592,7 @@ def cond_merge_ifs_false(node): ...@@ -592,7 +592,7 @@ def cond_merge_ifs_false(node):
return op(*old_ins, **dict(return_list=True)) return op(*old_ins, **dict(return_list=True))
class CondMerge(gof.Optimizer): class CondMerge(gof.GlobalOptimizer):
""" Graph Optimizer that merges different cond ops """ """ Graph Optimizer that merges different cond ops """
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
......
...@@ -3,7 +3,7 @@ import logging ...@@ -3,7 +3,7 @@ import logging
import theano.tensor import theano.tensor
from theano import tensor from theano import tensor
from theano.gof import Apply, Op, local_optimizer from theano.gof import Apply, Op, local_optimizer
from theano.gof.opt import Optimizer from theano.gof.opt import GlobalOptimizer
from theano.tensor import DimShuffle, Dot from theano.tensor import DimShuffle, Dot
from theano.tensor.blas import Dot22 from theano.tensor.blas import Dot22
from theano.tensor.nlinalg import ( from theano.tensor.nlinalg import (
...@@ -171,13 +171,13 @@ class HintsFeature: ...@@ -171,13 +171,13 @@ class HintsFeature:
# 2) we are putting things back after a failed transaction. # 2) we are putting things back after a failed transaction.
class HintsOptimizer(Optimizer): class HintsOptimizer(GlobalOptimizer):
""" """
Optimizer that serves to add HintsFeature as an fgraph feature. Optimizer that serves to add HintsFeature as an fgraph feature.
""" """
def __init__(self): def __init__(self):
Optimizer.__init__(self) super().__init__()
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(HintsFeature()) fgraph.attach_feature(HintsFeature())
......
...@@ -63,7 +63,11 @@ from theano.compile import optdb ...@@ -63,7 +63,11 @@ from theano.compile import optdb
from theano.compile.function.types import deep_copy_op from theano.compile.function.types import deep_copy_op
from theano.gof import DestroyHandler, InconsistencyError, toolbox from theano.gof import DestroyHandler, InconsistencyError, toolbox
from theano.gof.graph import equal_computations from theano.gof.graph import equal_computations
from theano.gof.opt import Optimizer, pre_constant_merge, pre_greedy_local_optimizer from theano.gof.opt import (
GlobalOptimizer,
pre_constant_merge,
pre_greedy_local_optimizer,
)
from theano.scan.op import Scan from theano.scan.op import Scan
from theano.scan.utils import ( from theano.scan.utils import (
clone, clone,
...@@ -224,14 +228,14 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -224,14 +228,14 @@ def remove_constants_and_unused_inputs_scan(node):
# This is a global opt for historical reason # This is a global opt for historical reason
# It should be possible to change it to a local opt. # It should be possible to change it to a local opt.
class PushOutNonSeqScan(gof.Optimizer): class PushOutNonSeqScan(gof.GlobalOptimizer):
""" """
A global optimizer for pushing out the variables inside the scan that depend A global optimizer for pushing out the variables inside the scan that depend
only on non-sequences. only on non-sequences.
""" """
def __init__(self): def __init__(self):
gof.Optimizer.__init__(self) super().__init__()
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(gof.toolbox.ReplaceValidate()) fgraph.attach_feature(gof.toolbox.ReplaceValidate())
...@@ -440,14 +444,14 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -440,14 +444,14 @@ class PushOutNonSeqScan(gof.Optimizer):
# This is a global opt for historical reason # This is a global opt for historical reason
# It should be possible to change it to a local opt. # It should be possible to change it to a local opt.
class PushOutSeqScan(gof.Optimizer): class PushOutSeqScan(gof.GlobalOptimizer):
""" """
A global optimizer for pushing out the variables inside the A global optimizer for pushing out the variables inside the
scan that depend only on constants and sequences. scan that depend only on constants and sequences.
""" """
def __init__(self): def __init__(self):
gof.Optimizer.__init__(self) super().__init__()
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(gof.toolbox.ReplaceValidate()) fgraph.attach_feature(gof.toolbox.ReplaceValidate())
...@@ -696,14 +700,14 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -696,14 +700,14 @@ class PushOutSeqScan(gof.Optimizer):
return False return False
class PushOutScanOutput(gof.Optimizer): class PushOutScanOutput(gof.GlobalOptimizer):
""" """
This is an optimization that can push operations performed This is an optimization that can push operations performed
at the end of the inner graph of scan to outside of scan. at the end of the inner graph of scan to outside of scan.
""" """
def __init__(self): def __init__(self):
gof.Optimizer.__init__(self) super().__init__()
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(gof.toolbox.ReplaceValidate()) fgraph.attach_feature(gof.toolbox.ReplaceValidate())
...@@ -958,14 +962,14 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -958,14 +962,14 @@ class PushOutScanOutput(gof.Optimizer):
return new_scan_node return new_scan_node
class ScanInplaceOptimizer(Optimizer): class ScanInplaceOptimizer(GlobalOptimizer):
""" """
Graph optimizer for Scan (makes it run inplace). Graph optimizer for Scan (makes it run inplace).
""" """
def __init__(self, typeInfer=None, gpua_flag=False): def __init__(self, typeInfer=None, gpua_flag=False):
Optimizer.__init__(self) super().__init__()
self.typeInfer = typeInfer self.typeInfer = typeInfer
self.gpua_flag = gpua_flag self.gpua_flag = gpua_flag
...@@ -1124,14 +1128,14 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1124,14 +1128,14 @@ class ScanInplaceOptimizer(Optimizer):
node = self.attempt_scan_inplace(fgraph, node, [pos], alloc_ops) node = self.attempt_scan_inplace(fgraph, node, [pos], alloc_ops)
class ScanSaveMem(gof.Optimizer): class ScanSaveMem(gof.GlobalOptimizer):
""" """
Graph Optimizer that reduces scan memory consumption. Graph Optimizer that reduces scan memory consumption.
""" """
def __init__(self): def __init__(self):
gof.Optimizer.__init__(self) super().__init__()
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(gof.toolbox.ReplaceValidate()) fgraph.attach_feature(gof.toolbox.ReplaceValidate())
...@@ -1680,7 +1684,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1680,7 +1684,7 @@ class ScanSaveMem(gof.Optimizer):
self.process_node(fgraph, node) self.process_node(fgraph, node)
class ScanMerge(gof.Optimizer): class ScanMerge(gof.GlobalOptimizer):
""" """
Graph Optimizer that merges different scan ops. Graph Optimizer that merges different scan ops.
...@@ -2135,14 +2139,14 @@ def scan_merge_inouts(node): ...@@ -2135,14 +2139,14 @@ def scan_merge_inouts(node):
return na.outer_outputs return na.outer_outputs
class PushOutDot1(gof.Optimizer): class PushOutDot1(gof.GlobalOptimizer):
""" """
Graph optimizer for Scan(makes it run inplace). Graph optimizer for Scan(makes it run inplace).
""" """
def __init__(self): def __init__(self):
Optimizer.__init__(self) super().__init__()
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate()) fgraph.attach_feature(toolbox.ReplaceValidate())
......
...@@ -147,9 +147,9 @@ from theano.compile.mode import optdb ...@@ -147,9 +147,9 @@ from theano.compile.mode import optdb
from theano.gof import ( from theano.gof import (
Apply, Apply,
EquilibriumOptimizer, EquilibriumOptimizer,
GlobalOptimizer,
InconsistencyError, InconsistencyError,
Op, Op,
Optimizer,
ReplacementDidNotRemoveError, ReplacementDidNotRemoveError,
SequenceDB, SequenceDB,
local_optimizer, local_optimizer,
...@@ -1449,11 +1449,11 @@ def _gemm_from_node2(node): ...@@ -1449,11 +1449,11 @@ def _gemm_from_node2(node):
return None, t1 - t0, 0, 0 return None, t1 - t0, 0, 0
class GemmOptimizer(Optimizer): class GemmOptimizer(GlobalOptimizer):
"""Graph optimizer for inserting Gemm operations.""" """Graph optimizer for inserting Gemm operations."""
def __init__(self): def __init__(self):
Optimizer.__init__(self) super().__init__()
self.warned = False self.warned = False
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
......
...@@ -33,7 +33,7 @@ from theano.gof import ( ...@@ -33,7 +33,7 @@ from theano.gof import (
) )
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.opt import ( from theano.gof.opt import (
Optimizer, GlobalOptimizer,
copy_stack_trace, copy_stack_trace,
in2out, in2out,
local_optimizer, local_optimizer,
...@@ -214,7 +214,7 @@ def broadcast_like(value, template, fgraph, dtype=None): ...@@ -214,7 +214,7 @@ def broadcast_like(value, template, fgraph, dtype=None):
return rval return rval
class InplaceElemwiseOptimizer(Optimizer): class InplaceElemwiseOptimizer(GlobalOptimizer):
""" """
We parametrise it to make it work for Elemwise and GpuElemwise op. We parametrise it to make it work for Elemwise and GpuElemwise op.
""" """
...@@ -1664,7 +1664,7 @@ class ShapeFeature: ...@@ -1664,7 +1664,7 @@ class ShapeFeature:
return True return True
class ShapeOptimizer(Optimizer): class ShapeOptimizer(GlobalOptimizer):
"""Optimizer that serves to add ShapeFeature as an fgraph feature.""" """Optimizer that serves to add ShapeFeature as an fgraph feature."""
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
...@@ -1674,7 +1674,7 @@ class ShapeOptimizer(Optimizer): ...@@ -1674,7 +1674,7 @@ class ShapeOptimizer(Optimizer):
pass pass
class UnShapeOptimizer(Optimizer): class UnShapeOptimizer(GlobalOptimizer):
"""Optimizer remove ShapeFeature as an fgraph feature.""" """Optimizer remove ShapeFeature as an fgraph feature."""
def apply(self, fgraph): def apply(self, fgraph):
...@@ -7729,11 +7729,11 @@ def elemwise_max_input_fct(node): ...@@ -7729,11 +7729,11 @@ def elemwise_max_input_fct(node):
local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fct) local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fct)
class FusionOptimizer(Optimizer): class FusionOptimizer(GlobalOptimizer):
"""Graph optimizer for Fusion of elemwise operations.""" """Graph optimizer for Fusion of elemwise operations."""
def __init__(self, local_optimizer): def __init__(self, local_optimizer):
Optimizer.__init__(self) super().__init__()
self.optimizer = local_optimizer self.optimizer = local_optimizer
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论