Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
55ca059a
提交
55ca059a
authored
11月 13, 2020
作者:
Brandon T. Willard
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Change Optimizer to GlobalOptimizer
上级
48ac71e3
显示空白字符变更
内嵌
并排
正在显示
13 个修改的文件
包含
62 行增加
和
67 行删除
+62
-67
optimization.txt
doc/extending/optimization.txt
+3
-3
test_optdb.py
tests/gof/test_optdb.py
+1
-1
mode.py
theano/compile/mode.py
+4
-4
__init__.py
theano/gof/__init__.py
+1
-1
opt.py
theano/gof/opt.py
+12
-21
optdb.py
theano/gof/optdb.py
+5
-5
dnn_opt.py
theano/gpuarray/dnn_opt.py
+2
-2
opt.py
theano/gpuarray/opt.py
+3
-3
ifelse.py
theano/ifelse.py
+1
-1
ops.py
theano/sandbox/linalg/ops.py
+3
-3
opt.py
theano/scan/opt.py
+18
-14
blas.py
theano/tensor/blas.py
+3
-3
opt.py
theano/tensor/opt.py
+6
-6
没有找到文件。
doc/extending/optimization.txt
浏览文件 @
55ca059a
...
@@ -57,7 +57,7 @@ Global optimization
...
@@ -57,7 +57,7 @@ Global optimization
A global optimization (or optimizer) is an object which defines the following
A global optimization (or optimizer) is an object which defines the following
methods:
methods:
.. class:: Optimizer
.. class::
Global
Optimizer
.. method:: apply(fgraph)
.. method:: apply(fgraph)
...
@@ -75,7 +75,7 @@ methods:
...
@@ -75,7 +75,7 @@ methods:
This is the interface function called by Theano.
This is the interface function called by Theano.
*Default:* this is defined by Optimizer as ``add_requirement(fgraph);
*Default:* this is defined by
Global
Optimizer as ``add_requirement(fgraph);
apply(fgraph)``.
apply(fgraph)``.
See the section about :class:`FunctionGraph` to understand how to define these
See the section about :class:`FunctionGraph` to understand how to define these
...
@@ -125,7 +125,7 @@ simplification described above:
...
@@ -125,7 +125,7 @@ simplification described above:
from theano import gof
from theano import gof
from theano.gof import toolbox
from theano.gof import toolbox
class Simplify(gof.Optimizer):
class Simplify(gof.
Global
Optimizer):
def add_requirements(self, fgraph):
def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(toolbox.ReplaceValidate())
def apply(self, fgraph):
def apply(self, fgraph):
...
...
tests/gof/test_optdb.py
浏览文件 @
55ca059a
...
@@ -5,7 +5,7 @@ from theano.gof.optdb import DB, opt
...
@@ -5,7 +5,7 @@ from theano.gof.optdb import DB, opt
class
TestDB
:
class
TestDB
:
def
test_name_clashes
(
self
):
def
test_name_clashes
(
self
):
class
Opt
(
opt
.
Optimizer
):
# inheritance buys __hash__
class
Opt
(
opt
.
Global
Optimizer
):
# inheritance buys __hash__
name
=
"blah"
name
=
"blah"
db
=
DB
()
db
=
DB
()
...
...
theano/compile/mode.py
浏览文件 @
55ca059a
...
@@ -93,13 +93,13 @@ predefined_optimizers = {
...
@@ -93,13 +93,13 @@ predefined_optimizers = {
def
register_optimizer
(
name
,
opt
):
def
register_optimizer
(
name
,
opt
):
"""Add a `Optimizer` which can be referred to by `name` in `Mode`."""
"""Add a `
Global
Optimizer` which can be referred to by `name` in `Mode`."""
if
name
in
predefined_optimizers
:
if
name
in
predefined_optimizers
:
raise
ValueError
(
f
"Optimizer name already taken: {name}"
)
raise
ValueError
(
f
"Optimizer name already taken: {name}"
)
predefined_optimizers
[
name
]
=
opt
predefined_optimizers
[
name
]
=
opt
class
AddDestroyHandler
(
gof
.
Optimizer
):
class
AddDestroyHandler
(
gof
.
Global
Optimizer
):
"""
"""
This optimizer performs two important functions:
This optimizer performs two important functions:
...
@@ -134,7 +134,7 @@ class AddDestroyHandler(gof.Optimizer):
...
@@ -134,7 +134,7 @@ class AddDestroyHandler(gof.Optimizer):
fgraph
.
attach_feature
(
gof
.
DestroyHandler
())
fgraph
.
attach_feature
(
gof
.
DestroyHandler
())
class
AddFeatureOptimizer
(
gof
.
Optimizer
):
class
AddFeatureOptimizer
(
gof
.
Global
Optimizer
):
"""
"""
This optimizer adds a provided feature to the function graph.
This optimizer adds a provided feature to the function graph.
"""
"""
...
@@ -147,7 +147,7 @@ class AddFeatureOptimizer(gof.Optimizer):
...
@@ -147,7 +147,7 @@ class AddFeatureOptimizer(gof.Optimizer):
fgraph
.
attach_feature
(
self
.
feature
)
fgraph
.
attach_feature
(
self
.
feature
)
class
PrintCurrentFunctionGraph
(
gof
.
Optimizer
):
class
PrintCurrentFunctionGraph
(
gof
.
Global
Optimizer
):
"""
"""
This optimizer is for debugging.
This optimizer is for debugging.
...
...
theano/gof/__init__.py
浏览文件 @
55ca059a
...
@@ -24,6 +24,7 @@ from theano.gof.op import (
...
@@ -24,6 +24,7 @@ from theano.gof.op import (
from
theano.gof.opt
import
(
from
theano.gof.opt
import
(
CheckStackTraceOptimization
,
CheckStackTraceOptimization
,
EquilibriumOptimizer
,
EquilibriumOptimizer
,
GlobalOptimizer
,
LocalOptGroup
,
LocalOptGroup
,
LocalOptimizer
,
LocalOptimizer
,
MergeOptimizer
,
MergeOptimizer
,
...
@@ -31,7 +32,6 @@ from theano.gof.opt import (
...
@@ -31,7 +32,6 @@ from theano.gof.opt import (
OpKeyOptimizer
,
OpKeyOptimizer
,
OpRemove
,
OpRemove
,
OpSub
,
OpSub
,
Optimizer
,
PatternSub
,
PatternSub
,
SeqOptimizer
,
SeqOptimizer
,
TopoOptimizer
,
TopoOptimizer
,
...
...
theano/gof/opt.py
浏览文件 @
55ca059a
...
@@ -43,10 +43,10 @@ class LocalMetaOptimizerSkipAssertionError(AssertionError):
...
@@ -43,10 +43,10 @@ class LocalMetaOptimizerSkipAssertionError(AssertionError):
"""
"""
class
Optimizer
:
class
Global
Optimizer
:
"""
"""
A
n L{
Optimizer} can be applied to an L{FunctionGraph} to transform it.
A
L{Global
Optimizer} can be applied to an L{FunctionGraph} to transform it.
It can represent an optimization or in general any kind
It can represent an optimization or in general any kind
of transformation you could apply to an L{FunctionGraph}.
of transformation you could apply to an L{FunctionGraph}.
...
@@ -73,7 +73,7 @@ class Optimizer:
...
@@ -73,7 +73,7 @@ class Optimizer:
Applies the optimization to the provided L{FunctionGraph}. It may
Applies the optimization to the provided L{FunctionGraph}. It may
use all the methods defined by the L{FunctionGraph}. If the
use all the methods defined by the L{FunctionGraph}. If the
L{Optimizer} needs to use a certain tool, such as an
L{
Global
Optimizer} needs to use a certain tool, such as an
L{InstanceFinder}, it can do so in its L{add_requirements} method.
L{InstanceFinder}, it can do so in its L{add_requirements} method.
"""
"""
...
@@ -125,11 +125,8 @@ class Optimizer:
...
@@ -125,11 +125,8 @@ class Optimizer:
)
)
class
FromFunctionOptimizer
(
Optimizer
):
class
FromFunctionOptimizer
(
GlobalOptimizer
):
"""
"""A `GlobalOptimizer` constructed from a given function."""
WRITEME
"""
def
__init__
(
self
,
fn
,
requirements
=
()):
def
__init__
(
self
,
fn
,
requirements
=
()):
self
.
apply
=
fn
self
.
apply
=
fn
...
@@ -171,14 +168,8 @@ def inplace_optimizer(f):
...
@@ -171,14 +168,8 @@ def inplace_optimizer(f):
return
rval
return
rval
class
SeqOptimizer
(
Optimizer
,
list
):
class
SeqOptimizer
(
GlobalOptimizer
,
list
):
# inherit from Optimizer first to get Optimizer.__hash__
"""A `GlobalOptimizer` that applies a list of optimizers sequentially."""
"""
Takes a list of L{Optimizer} instances and applies them
sequentially.
"""
@staticmethod
@staticmethod
def
warn
(
exc
,
self
,
optimizer
):
def
warn
(
exc
,
self
,
optimizer
):
...
@@ -214,7 +205,7 @@ class SeqOptimizer(Optimizer, list):
...
@@ -214,7 +205,7 @@ class SeqOptimizer(Optimizer, list):
def
apply
(
self
,
fgraph
):
def
apply
(
self
,
fgraph
):
"""
"""
Applies each L{Optimizer} in self in turn.
Applies each L{
Global
Optimizer} in self in turn.
"""
"""
l
=
[]
l
=
[]
...
@@ -823,7 +814,7 @@ class MergeFeature:
...
@@ -823,7 +814,7 @@ class MergeFeature:
return
new_inputs
return
new_inputs
class
MergeOptimizer
(
Optimizer
):
class
MergeOptimizer
(
Global
Optimizer
):
"""
"""
Merges parts of the graph that are identical and redundant.
Merges parts of the graph that are identical and redundant.
...
@@ -1945,7 +1936,7 @@ class Updater:
...
@@ -1945,7 +1936,7 @@ class Updater:
self
.
chin
=
None
self
.
chin
=
None
class
NavigatorOptimizer
(
Optimizer
):
class
NavigatorOptimizer
(
Global
Optimizer
):
"""
"""
Abstract class.
Abstract class.
...
@@ -2835,7 +2826,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -2835,7 +2826,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
+
list
(
opt
.
final_optimizers
)
+
list
(
opt
.
final_optimizers
)
+
list
(
opt
.
cleanup_optimizers
)
+
list
(
opt
.
cleanup_optimizers
)
)
)
if
o
.
print_profile
.
__code__
is
not
Optimizer
.
print_profile
.
__code__
if
o
.
print_profile
.
__code__
is
not
Global
Optimizer
.
print_profile
.
__code__
]
]
if
not
gf_opts
:
if
not
gf_opts
:
return
return
...
@@ -3310,7 +3301,7 @@ class CheckStrackTraceFeature:
...
@@ -3310,7 +3301,7 @@ class CheckStrackTraceFeature:
)
)
class
CheckStackTraceOptimization
(
Optimizer
):
class
CheckStackTraceOptimization
(
Global
Optimizer
):
"""Optimizer that serves to add CheckStackTraceOptimization as an fgraph feature."""
"""Optimizer that serves to add CheckStackTraceOptimization as an fgraph feature."""
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
):
...
...
theano/gof/optdb.py
浏览文件 @
55ca059a
...
@@ -43,10 +43,10 @@ class DB:
...
@@ -43,10 +43,10 @@ class DB:
tags specified will enable that optimization.
tags specified will enable that optimization.
"""
"""
# N.B. obj is not an instance of class
Optimizer
.
# N.B. obj is not an instance of class
`GlobalOptimizer`
.
# It is an instance of a DB.In the tests for example,
# It is an instance of a DB.In the tests for example,
# this is not always the case.
# this is not always the case.
if
not
isinstance
(
obj
,
(
DB
,
opt
.
Optimizer
,
opt
.
LocalOptimizer
)):
if
not
isinstance
(
obj
,
(
DB
,
opt
.
Global
Optimizer
,
opt
.
LocalOptimizer
)):
raise
TypeError
(
"Object cannot be registered in OptDB"
,
obj
)
raise
TypeError
(
"Object cannot be registered in OptDB"
,
obj
)
if
name
in
self
.
__db__
:
if
name
in
self
.
__db__
:
raise
ValueError
(
raise
ValueError
(
...
@@ -285,8 +285,8 @@ class EquilibriumDB(DB):
...
@@ -285,8 +285,8 @@ class EquilibriumDB(DB):
Notes
Notes
-----
-----
We can
put LocalOptimizer and Optimizer as EquilibriumOptimizer
We can
use `LocalOptimizer` and `GlobalOptimizer` since `EquilibriumOptimizer`
suppor both.
suppor
ts
both.
It is probably not a good idea to have ignore_newtrees=False and
It is probably not a good idea to have ignore_newtrees=False and
tracks_on_change_inputs=True
tracks_on_change_inputs=True
...
@@ -473,7 +473,7 @@ class LocalGroupDB(DB):
...
@@ -473,7 +473,7 @@ class LocalGroupDB(DB):
class
TopoDB
(
DB
):
class
TopoDB
(
DB
):
"""
"""
Generate a
Global Optimizer
of type TopoOptimizer.
Generate a
`GlobalOptimizer`
of type TopoOptimizer.
"""
"""
...
...
theano/gpuarray/dnn_opt.py
浏览文件 @
55ca059a
import
theano
import
theano
from
theano.compile
import
optdb
from
theano.compile
import
optdb
from
theano.compile.ops
import
shape_i_op
from
theano.compile.ops
import
shape_i_op
from
theano.gof.opt
import
Optimizer
,
inherit_stack_trace
,
local_optimizer
from
theano.gof.opt
import
Global
Optimizer
,
inherit_stack_trace
,
local_optimizer
from
theano.gpuarray.basic_ops
import
(
from
theano.gpuarray.basic_ops
import
(
GpuAllocEmpty
,
GpuAllocEmpty
,
GpuArrayType
,
GpuArrayType
,
...
@@ -817,7 +817,7 @@ def local_dnn_argmax(op, ctx_name, inputs, outputs):
...
@@ -817,7 +817,7 @@ def local_dnn_argmax(op, ctx_name, inputs, outputs):
return
[
as_gpuarray_variable
(
arg
.
astype
(
"int64"
),
ctx_name
)]
return
[
as_gpuarray_variable
(
arg
.
astype
(
"int64"
),
ctx_name
)]
class
NoCuDNNRaise
(
Optimizer
):
class
NoCuDNNRaise
(
Global
Optimizer
):
def
apply
(
self
,
fgraph
):
def
apply
(
self
,
fgraph
):
"""
"""
Raise a error if cudnn can't be used.
Raise a error if cudnn can't be used.
...
...
theano/gpuarray/opt.py
浏览文件 @
55ca059a
...
@@ -15,7 +15,7 @@ from theano import config, gof, scalar, tensor
...
@@ -15,7 +15,7 @@ from theano import config, gof, scalar, tensor
from
theano.breakpoint
import
PdbBreakpoint
from
theano.breakpoint
import
PdbBreakpoint
from
theano.compile
import
optdb
from
theano.compile
import
optdb
from
theano.compile.ops
import
shape_i
from
theano.compile.ops
import
shape_i
from
theano.gof
import
Optimizer
,
graph
,
local_optimizer
,
toolbox
from
theano.gof
import
Global
Optimizer
,
graph
,
local_optimizer
,
toolbox
from
theano.gof.opt
import
LocalMetaOptimizer
,
copy_stack_trace
,
inherit_stack_trace
from
theano.gof.opt
import
LocalMetaOptimizer
,
copy_stack_trace
,
inherit_stack_trace
from
theano.gpuarray.basic_ops
import
(
from
theano.gpuarray.basic_ops
import
(
GpuAlloc
,
GpuAlloc
,
...
@@ -210,7 +210,7 @@ gpu_neg = GpuElemwise(neg)
...
@@ -210,7 +210,7 @@ gpu_neg = GpuElemwise(neg)
gpu_true_div
=
GpuElemwise
(
true_div
)
gpu_true_div
=
GpuElemwise
(
true_div
)
class
InputToGpuOptimizer
(
Optimizer
):
class
InputToGpuOptimizer
(
Global
Optimizer
):
"""
"""
Transfer the input to the gpu to start the rolling wave.
Transfer the input to the gpu to start the rolling wave.
...
@@ -260,7 +260,7 @@ gpu_seqopt.register(
...
@@ -260,7 +260,7 @@ gpu_seqopt.register(
)
)
class
GraphToGPU
(
Optimizer
):
class
GraphToGPU
(
Global
Optimizer
):
"""
"""
Transfer the graph as a whole to GPU instead of transferring node by node.
Transfer the graph as a whole to GPU instead of transferring node by node.
...
...
theano/ifelse.py
浏览文件 @
55ca059a
...
@@ -592,7 +592,7 @@ def cond_merge_ifs_false(node):
...
@@ -592,7 +592,7 @@ def cond_merge_ifs_false(node):
return
op
(
*
old_ins
,
**
dict
(
return_list
=
True
))
return
op
(
*
old_ins
,
**
dict
(
return_list
=
True
))
class
CondMerge
(
gof
.
Optimizer
):
class
CondMerge
(
gof
.
Global
Optimizer
):
""" Graph Optimizer that merges different cond ops """
""" Graph Optimizer that merges different cond ops """
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
):
...
...
theano/sandbox/linalg/ops.py
浏览文件 @
55ca059a
...
@@ -3,7 +3,7 @@ import logging
...
@@ -3,7 +3,7 @@ import logging
import
theano.tensor
import
theano.tensor
from
theano
import
tensor
from
theano
import
tensor
from
theano.gof
import
Apply
,
Op
,
local_optimizer
from
theano.gof
import
Apply
,
Op
,
local_optimizer
from
theano.gof.opt
import
Optimizer
from
theano.gof.opt
import
Global
Optimizer
from
theano.tensor
import
DimShuffle
,
Dot
from
theano.tensor
import
DimShuffle
,
Dot
from
theano.tensor.blas
import
Dot22
from
theano.tensor.blas
import
Dot22
from
theano.tensor.nlinalg
import
(
from
theano.tensor.nlinalg
import
(
...
@@ -171,13 +171,13 @@ class HintsFeature:
...
@@ -171,13 +171,13 @@ class HintsFeature:
# 2) we are putting things back after a failed transaction.
# 2) we are putting things back after a failed transaction.
class
HintsOptimizer
(
Optimizer
):
class
HintsOptimizer
(
Global
Optimizer
):
"""
"""
Optimizer that serves to add HintsFeature as an fgraph feature.
Optimizer that serves to add HintsFeature as an fgraph feature.
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
Optimizer
.
__init__
(
self
)
super
()
.
__init__
(
)
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
):
fgraph
.
attach_feature
(
HintsFeature
())
fgraph
.
attach_feature
(
HintsFeature
())
...
...
theano/scan/opt.py
浏览文件 @
55ca059a
...
@@ -63,7 +63,11 @@ from theano.compile import optdb
...
@@ -63,7 +63,11 @@ from theano.compile import optdb
from
theano.compile.function.types
import
deep_copy_op
from
theano.compile.function.types
import
deep_copy_op
from
theano.gof
import
DestroyHandler
,
InconsistencyError
,
toolbox
from
theano.gof
import
DestroyHandler
,
InconsistencyError
,
toolbox
from
theano.gof.graph
import
equal_computations
from
theano.gof.graph
import
equal_computations
from
theano.gof.opt
import
Optimizer
,
pre_constant_merge
,
pre_greedy_local_optimizer
from
theano.gof.opt
import
(
GlobalOptimizer
,
pre_constant_merge
,
pre_greedy_local_optimizer
,
)
from
theano.scan.op
import
Scan
from
theano.scan.op
import
Scan
from
theano.scan.utils
import
(
from
theano.scan.utils
import
(
clone
,
clone
,
...
@@ -224,14 +228,14 @@ def remove_constants_and_unused_inputs_scan(node):
...
@@ -224,14 +228,14 @@ def remove_constants_and_unused_inputs_scan(node):
# This is a global opt for historical reason
# This is a global opt for historical reason
# It should be possible to change it to a local opt.
# It should be possible to change it to a local opt.
class
PushOutNonSeqScan
(
gof
.
Optimizer
):
class
PushOutNonSeqScan
(
gof
.
Global
Optimizer
):
"""
"""
A global optimizer for pushing out the variables inside the scan that depend
A global optimizer for pushing out the variables inside the scan that depend
only on non-sequences.
only on non-sequences.
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
gof
.
Optimizer
.
__init__
(
self
)
super
()
.
__init__
(
)
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
):
fgraph
.
attach_feature
(
gof
.
toolbox
.
ReplaceValidate
())
fgraph
.
attach_feature
(
gof
.
toolbox
.
ReplaceValidate
())
...
@@ -440,14 +444,14 @@ class PushOutNonSeqScan(gof.Optimizer):
...
@@ -440,14 +444,14 @@ class PushOutNonSeqScan(gof.Optimizer):
# This is a global opt for historical reason
# This is a global opt for historical reason
# It should be possible to change it to a local opt.
# It should be possible to change it to a local opt.
class
PushOutSeqScan
(
gof
.
Optimizer
):
class
PushOutSeqScan
(
gof
.
Global
Optimizer
):
"""
"""
A global optimizer for pushing out the variables inside the
A global optimizer for pushing out the variables inside the
scan that depend only on constants and sequences.
scan that depend only on constants and sequences.
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
gof
.
Optimizer
.
__init__
(
self
)
super
()
.
__init__
(
)
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
):
fgraph
.
attach_feature
(
gof
.
toolbox
.
ReplaceValidate
())
fgraph
.
attach_feature
(
gof
.
toolbox
.
ReplaceValidate
())
...
@@ -696,14 +700,14 @@ class PushOutSeqScan(gof.Optimizer):
...
@@ -696,14 +700,14 @@ class PushOutSeqScan(gof.Optimizer):
return
False
return
False
class
PushOutScanOutput
(
gof
.
Optimizer
):
class
PushOutScanOutput
(
gof
.
Global
Optimizer
):
"""
"""
This is an optimization that can push operations performed
This is an optimization that can push operations performed
at the end of the inner graph of scan to outside of scan.
at the end of the inner graph of scan to outside of scan.
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
gof
.
Optimizer
.
__init__
(
self
)
super
()
.
__init__
(
)
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
):
fgraph
.
attach_feature
(
gof
.
toolbox
.
ReplaceValidate
())
fgraph
.
attach_feature
(
gof
.
toolbox
.
ReplaceValidate
())
...
@@ -958,14 +962,14 @@ class PushOutScanOutput(gof.Optimizer):
...
@@ -958,14 +962,14 @@ class PushOutScanOutput(gof.Optimizer):
return
new_scan_node
return
new_scan_node
class
ScanInplaceOptimizer
(
Optimizer
):
class
ScanInplaceOptimizer
(
Global
Optimizer
):
"""
"""
Graph optimizer for Scan (makes it run inplace).
Graph optimizer for Scan (makes it run inplace).
"""
"""
def
__init__
(
self
,
typeInfer
=
None
,
gpua_flag
=
False
):
def
__init__
(
self
,
typeInfer
=
None
,
gpua_flag
=
False
):
Optimizer
.
__init__
(
self
)
super
()
.
__init__
(
)
self
.
typeInfer
=
typeInfer
self
.
typeInfer
=
typeInfer
self
.
gpua_flag
=
gpua_flag
self
.
gpua_flag
=
gpua_flag
...
@@ -1124,14 +1128,14 @@ class ScanInplaceOptimizer(Optimizer):
...
@@ -1124,14 +1128,14 @@ class ScanInplaceOptimizer(Optimizer):
node
=
self
.
attempt_scan_inplace
(
fgraph
,
node
,
[
pos
],
alloc_ops
)
node
=
self
.
attempt_scan_inplace
(
fgraph
,
node
,
[
pos
],
alloc_ops
)
class
ScanSaveMem
(
gof
.
Optimizer
):
class
ScanSaveMem
(
gof
.
Global
Optimizer
):
"""
"""
Graph Optimizer that reduces scan memory consumption.
Graph Optimizer that reduces scan memory consumption.
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
gof
.
Optimizer
.
__init__
(
self
)
super
()
.
__init__
(
)
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
):
fgraph
.
attach_feature
(
gof
.
toolbox
.
ReplaceValidate
())
fgraph
.
attach_feature
(
gof
.
toolbox
.
ReplaceValidate
())
...
@@ -1680,7 +1684,7 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -1680,7 +1684,7 @@ class ScanSaveMem(gof.Optimizer):
self
.
process_node
(
fgraph
,
node
)
self
.
process_node
(
fgraph
,
node
)
class
ScanMerge
(
gof
.
Optimizer
):
class
ScanMerge
(
gof
.
Global
Optimizer
):
"""
"""
Graph Optimizer that merges different scan ops.
Graph Optimizer that merges different scan ops.
...
@@ -2135,14 +2139,14 @@ def scan_merge_inouts(node):
...
@@ -2135,14 +2139,14 @@ def scan_merge_inouts(node):
return
na
.
outer_outputs
return
na
.
outer_outputs
class
PushOutDot1
(
gof
.
Optimizer
):
class
PushOutDot1
(
gof
.
Global
Optimizer
):
"""
"""
Graph optimizer for Scan(makes it run inplace).
Graph optimizer for Scan(makes it run inplace).
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
Optimizer
.
__init__
(
self
)
super
()
.
__init__
(
)
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
):
fgraph
.
attach_feature
(
toolbox
.
ReplaceValidate
())
fgraph
.
attach_feature
(
toolbox
.
ReplaceValidate
())
...
...
theano/tensor/blas.py
浏览文件 @
55ca059a
...
@@ -147,9 +147,9 @@ from theano.compile.mode import optdb
...
@@ -147,9 +147,9 @@ from theano.compile.mode import optdb
from
theano.gof
import
(
from
theano.gof
import
(
Apply
,
Apply
,
EquilibriumOptimizer
,
EquilibriumOptimizer
,
GlobalOptimizer
,
InconsistencyError
,
InconsistencyError
,
Op
,
Op
,
Optimizer
,
ReplacementDidNotRemoveError
,
ReplacementDidNotRemoveError
,
SequenceDB
,
SequenceDB
,
local_optimizer
,
local_optimizer
,
...
@@ -1449,11 +1449,11 @@ def _gemm_from_node2(node):
...
@@ -1449,11 +1449,11 @@ def _gemm_from_node2(node):
return
None
,
t1
-
t0
,
0
,
0
return
None
,
t1
-
t0
,
0
,
0
class
GemmOptimizer
(
Optimizer
):
class
GemmOptimizer
(
Global
Optimizer
):
"""Graph optimizer for inserting Gemm operations."""
"""Graph optimizer for inserting Gemm operations."""
def
__init__
(
self
):
def
__init__
(
self
):
Optimizer
.
__init__
(
self
)
super
()
.
__init__
(
)
self
.
warned
=
False
self
.
warned
=
False
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
):
...
...
theano/tensor/opt.py
浏览文件 @
55ca059a
...
@@ -33,7 +33,7 @@ from theano.gof import (
...
@@ -33,7 +33,7 @@ from theano.gof import (
)
)
from
theano.gof.op
import
Op
from
theano.gof.op
import
Op
from
theano.gof.opt
import
(
from
theano.gof.opt
import
(
Optimizer
,
Global
Optimizer
,
copy_stack_trace
,
copy_stack_trace
,
in2out
,
in2out
,
local_optimizer
,
local_optimizer
,
...
@@ -214,7 +214,7 @@ def broadcast_like(value, template, fgraph, dtype=None):
...
@@ -214,7 +214,7 @@ def broadcast_like(value, template, fgraph, dtype=None):
return
rval
return
rval
class
InplaceElemwiseOptimizer
(
Optimizer
):
class
InplaceElemwiseOptimizer
(
Global
Optimizer
):
"""
"""
We parametrise it to make it work for Elemwise and GpuElemwise op.
We parametrise it to make it work for Elemwise and GpuElemwise op.
"""
"""
...
@@ -1664,7 +1664,7 @@ class ShapeFeature:
...
@@ -1664,7 +1664,7 @@ class ShapeFeature:
return
True
return
True
class
ShapeOptimizer
(
Optimizer
):
class
ShapeOptimizer
(
Global
Optimizer
):
"""Optimizer that serves to add ShapeFeature as an fgraph feature."""
"""Optimizer that serves to add ShapeFeature as an fgraph feature."""
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
):
...
@@ -1674,7 +1674,7 @@ class ShapeOptimizer(Optimizer):
...
@@ -1674,7 +1674,7 @@ class ShapeOptimizer(Optimizer):
pass
pass
class
UnShapeOptimizer
(
Optimizer
):
class
UnShapeOptimizer
(
Global
Optimizer
):
"""Optimizer remove ShapeFeature as an fgraph feature."""
"""Optimizer remove ShapeFeature as an fgraph feature."""
def
apply
(
self
,
fgraph
):
def
apply
(
self
,
fgraph
):
...
@@ -7729,11 +7729,11 @@ def elemwise_max_input_fct(node):
...
@@ -7729,11 +7729,11 @@ def elemwise_max_input_fct(node):
local_elemwise_fusion
=
local_elemwise_fusion_op
(
Elemwise
,
elemwise_max_input_fct
)
local_elemwise_fusion
=
local_elemwise_fusion_op
(
Elemwise
,
elemwise_max_input_fct
)
class
FusionOptimizer
(
Optimizer
):
class
FusionOptimizer
(
Global
Optimizer
):
"""Graph optimizer for Fusion of elemwise operations."""
"""Graph optimizer for Fusion of elemwise operations."""
def
__init__
(
self
,
local_optimizer
):
def
__init__
(
self
,
local_optimizer
):
Optimizer
.
__init__
(
self
)
super
()
.
__init__
(
)
self
.
optimizer
=
local_optimizer
self
.
optimizer
=
local_optimizer
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论