Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9bc72d95
提交
9bc72d95
authored
3月 20, 2008
作者:
Olivier Breuleux
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
documented opt.py and added a callback to some Optimizers so the user can know…
documented opt.py and added a callback to some Optimizers so the user can know about failed attempts to optimize
上级
6008b0ec
显示空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
242 行增加
和
112 行删除
+242
-112
opt.py
gof/opt.py
+242
-112
没有找到文件。
gof/opt.py
浏览文件 @
9bc72d95
...
@@ -9,25 +9,51 @@ import ext
...
@@ -9,25 +9,51 @@ import ext
class
Optimizer
:
class
Optimizer
:
"""
An Optimizer can be applied to an env to transform it.
It can represent an optimization or in general any kind
of transformation you could apply to an env.
"""
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
"""
Applies the optimization to the provided env. It may
use all the methods defined by the env. If the optimizer
needs to use a certain tool, such as an InstanceFinder,
it should set the __env_require__ field to a list of
what needs to be registered with the Env.
"""
pass
pass
def
optimize
(
self
,
env
):
def
optimize
(
self
,
env
):
"""
This is meant as a shortcut to:
env.satisfy(opt)
opt.apply(env)
"""
env
.
satisfy
(
self
)
env
.
satisfy
(
self
)
self
.
apply
(
env
)
self
.
apply
(
env
)
def
__call__
(
self
,
env
):
self
.
optimize
(
env
)
DummyOpt
=
Optimizer
()
DummyOpt
=
Optimizer
()
DummyOpt
.
__doc__
=
"Does nothing."
class
SeqOptimizer
(
Optimizer
,
list
):
class
SeqOptimizer
(
Optimizer
,
list
):
"""
Takes a list of Optimizer instances and applies them
sequentially.
"""
def
__init__
(
self
,
*
opts
):
if
len
(
opts
)
==
1
and
isinstance
(
opts
[
0
],
(
list
,
tuple
)):
opts
=
opts
[
0
]
list
.
__init__
(
opts
)
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
"""
Applies each optimizer in self in turn.
"""
for
optimizer
in
self
:
for
optimizer
in
self
:
optimizer
.
optimize
(
env
)
optimizer
.
optimize
(
env
)
...
@@ -40,14 +66,34 @@ class SeqOptimizer(Optimizer, list):
...
@@ -40,14 +66,34 @@ class SeqOptimizer(Optimizer, list):
class
LocalOptimizer
(
Optimizer
):
class
LocalOptimizer
(
Optimizer
):
"""
Generic Optimizer class that considers local parts of
the env. It must be subclassed and should override the
following two methods:
* candidates(env) -> returns a set of ops that can be
optimized
* apply_on_op(env, op) -> for each op in candidates,
this function will be called to perform the actual
optimization.
"""
def
candidates
(
self
,
env
):
def
candidates
(
self
,
env
):
return
env
.
ops
()
"""
Must return a set of ops that can be optimized.
"""
raise
utils
.
AbstractFunctionError
()
def
apply_on_op
(
self
,
env
,
op
):
def
apply_on_op
(
self
,
env
,
op
):
raise
Exception
(
"Please override this function."
)
"""
For each op in candidates, this function will be called to
perform the actual optimization.
"""
raise
utils
.
AbstractFunctionError
()
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
"""
Calls self.apply_on_op(env, op) for each op in self.candidates(env).
"""
for
op
in
self
.
candidates
(
env
):
for
op
in
self
.
candidates
(
env
):
if
env
.
has_op
(
op
):
if
env
.
has_op
(
op
):
self
.
apply_on_op
(
env
,
op
)
self
.
apply_on_op
(
env
,
op
)
...
@@ -55,50 +101,95 @@ class LocalOptimizer(Optimizer):
...
@@ -55,50 +101,95 @@ class LocalOptimizer(Optimizer):
class
OpSpecificOptimizer
(
LocalOptimizer
):
class
OpSpecificOptimizer
(
LocalOptimizer
):
"""
Generic optimizer that applies only to ops of a certain
type. The type in question is accessed through self.opclass.
opclass can also be a class variable of the subclass.
"""
__env_require__
=
toolbox
.
InstanceFinder
__env_require__
=
toolbox
.
InstanceFinder
opclass
=
Op
def
candidates
(
self
,
env
):
def
candidates
(
self
,
env
):
"""
Returns all instances of self.opclass.
"""
return
env
.
get_instances_of
(
self
.
opclass
)
return
env
.
get_instances_of
(
self
.
opclass
)
class
OpSubOptimizer
(
Optimizer
):
class
OpSubOptimizer
(
Optimizer
):
"""
Replaces all ops of a certain type by ops of another type that
take the same inputs as what they are replacing.
e.g. OpSubOptimizer(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
"""
__env_require__
=
toolbox
.
InstanceFinder
__env_require__
=
toolbox
.
InstanceFinder
def
__init__
(
self
,
op1
,
op2
):
def
__init__
(
self
,
op1
,
op2
,
failure_callback
=
None
):
if
not
op1
.
_default_output_idx
>=
0
:
"""
raise
TypeError
(
"OpSubOptimizer must be used with Op instances that have a default output."
)
op1 and op2 must both be Op subclasses, they must both take
# note: op2 must have the same input signature as op1
the same number of inputs and they must both have the same
number of outputs.
"""
self
.
op1
=
op1
self
.
op1
=
op1
self
.
op2
=
op2
self
.
op2
=
op2
self
.
failure_callback
=
failure_callback
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
"""
Replaces all occurrences of self.op1 by instances of self.op2
with the same inputs.
If failure_callback is not None, it will be called whenever
the Optimizer fails to do a replacement in the graph. The
arguments to the callback are: (op1_instance, replacement, exception)
"""
candidates
=
env
.
get_instances_of
(
self
.
op1
)
candidates
=
env
.
get_instances_of
(
self
.
op1
)
for
op
in
candidates
:
for
op
in
candidates
:
try
:
try
:
# note: only replaces the default 'out' port if it exists
repl
=
self
.
op2
(
*
op
.
inputs
)
r
=
self
.
op2
(
*
op
.
inputs
)
.
out
assert
len
(
op
.
outputs
)
==
len
(
repl
.
outputs
)
env
.
replace
(
op
.
out
,
r
)
for
old
,
new
in
zip
(
op
.
outputs
,
repl
.
outputs
):
except
InconsistencyError
,
e
:
env
.
replace
(
old
,
new
)
# print "Warning: OpSubOpt failed to transform %s into %s: %s" % (op, self.op2, str(e)) # warning is for debug
except
Exception
,
e
:
if
self
.
failure_callback
is
not
None
:
self
.
failure_callback
(
op
,
repl
,
e
)
pass
pass
def
str
(
self
):
return
"
%
s ->
%
s"
%
(
self
.
op1
.
__name__
,
self
.
op2
.
__name__
)
class
OpRemover
(
Optimizer
):
class
OpRemover
(
Optimizer
):
"""
Removes all ops of a certain type by transferring each of its
outputs to the corresponding input.
"""
__env_require__
=
toolbox
.
InstanceFinder
__env_require__
=
toolbox
.
InstanceFinder
def
__init__
(
self
,
opclass
):
def
__init__
(
self
,
opclass
,
failure_callback
=
None
):
"""
opclass is the class of the ops to remove. It must take as
many inputs as outputs.
"""
self
.
opclass
=
opclass
self
.
opclass
=
opclass
self
.
failure_callback
=
failure_callback
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
"""
Removes all occurrences of self.opclass.
If self.failure_callback is not None, it will be called whenever
the Optimizer fails to remove an operation in the graph. The
arguments to the callback are: (opclass_instance, exception)
"""
candidates
=
env
.
get_instances_of
(
self
.
opclass
)
candidates
=
env
.
get_instances_of
(
self
.
opclass
)
for
op
in
candidates
:
for
op
in
candidates
:
...
@@ -106,10 +197,14 @@ class OpRemover(Optimizer):
...
@@ -106,10 +197,14 @@ class OpRemover(Optimizer):
assert
len
(
op
.
inputs
)
==
len
(
op
.
outputs
)
assert
len
(
op
.
inputs
)
==
len
(
op
.
outputs
)
for
input
,
output
in
zip
(
op
.
inputs
,
op
.
outputs
):
for
input
,
output
in
zip
(
op
.
inputs
,
op
.
outputs
):
env
.
replace
(
output
,
input
)
env
.
replace
(
output
,
input
)
except
InconsistencyError
,
e
:
except
Exception
,
e
:
# print "Warning: OpRemover failed to remove %s: %s" % (op, str(e)) # warning is for debug
if
self
.
failure_callback
is
not
None
:
self
.
failure_callback
(
op
,
e
)
pass
pass
def
str
(
self
):
return
"f(
%
s(x)) -> f(x)"
%
self
.
opclass
class
PatternOptimizer
(
OpSpecificOptimizer
):
class
PatternOptimizer
(
OpSpecificOptimizer
):
...
@@ -117,13 +212,26 @@ class PatternOptimizer(OpSpecificOptimizer):
...
@@ -117,13 +212,26 @@ class PatternOptimizer(OpSpecificOptimizer):
Replaces all occurrences of the first pattern by the second pattern.
Replaces all occurrences of the first pattern by the second pattern.
"""
"""
def
__init__
(
self
,
in_pattern
,
out_pattern
):
def
__init__
(
self
,
in_pattern
,
out_pattern
,
failure_callback
=
None
):
"""
Sets in_pattern for replacement by out_pattern.
self.opclass is set to in_pattern[0] to accelerate the search.
"""
self
.
in_pattern
=
in_pattern
self
.
in_pattern
=
in_pattern
self
.
out_pattern
=
out_pattern
self
.
out_pattern
=
out_pattern
self
.
opclass
=
self
.
in_pattern
[
0
]
self
.
opclass
=
self
.
in_pattern
[
0
]
self
.
__doc__
=
self
.
__class__
.
__doc__
+
"
\n\n
This instance does: "
+
str
(
self
)
+
"
\n
"
self
.
__doc__
=
self
.
__class__
.
__doc__
+
"
\n\n
This instance does: "
+
str
(
self
)
+
"
\n
"
self
.
failure_callback
=
failure_callback
def
apply_on_op
(
self
,
env
,
op
):
def
apply_on_op
(
self
,
env
,
op
):
"""
Checks if the graph from op corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement.
If self.failure_callback is not None, if there is a match but a
replacement fails to occur, the callback will be called with
arguments (results_to_replace, replacement, exception).
"""
def
match
(
pattern
,
expr
,
u
,
first
=
False
):
def
match
(
pattern
,
expr
,
u
,
first
=
False
):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
...
@@ -168,8 +276,9 @@ class PatternOptimizer(OpSpecificOptimizer):
...
@@ -168,8 +276,9 @@ class PatternOptimizer(OpSpecificOptimizer):
if
not
isinstance
(
p
,
str
):
if
not
isinstance
(
p
,
str
):
new
=
new
.
out
new
=
new
.
out
env
.
replace
(
op
.
out
,
new
)
env
.
replace
(
op
.
out
,
new
)
except
InconsistencyError
,
e
:
except
Exception
,
e
:
# print "Warning: '%s' failed to apply on %s: %s" % (self, op, str(e)) # warning is for debug
if
self
.
failure_callback
is
not
None
:
self
.
failure_callback
(
op
.
out
,
new
,
e
)
pass
pass
...
@@ -183,6 +292,11 @@ class PatternOptimizer(OpSpecificOptimizer):
...
@@ -183,6 +292,11 @@ class PatternOptimizer(OpSpecificOptimizer):
class
ConstantFinder
(
Optimizer
):
class
ConstantFinder
(
Optimizer
):
"""
Sets as constant every orphan that is not destroyed
and sets as indestructible every input that is not
destroyed.
"""
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
if
env
.
has_feature
(
ext
.
DestroyHandler
):
if
env
.
has_feature
(
ext
.
DestroyHandler
):
...
@@ -202,6 +316,12 @@ class ConstantFinder(Optimizer):
...
@@ -202,6 +316,12 @@ class ConstantFinder(Optimizer):
class
MergeOptimizer
(
Optimizer
):
class
MergeOptimizer
(
Optimizer
):
"""
Merges parts of the graph that are identical, i.e. parts that
take the same inputs and carry out the asme computations so we
can avoid doing them more than once. Also merges results that
are constant.
"""
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
cid
=
{}
cid
=
{}
...
@@ -220,6 +340,9 @@ class MergeOptimizer(Optimizer):
...
@@ -220,6 +340,9 @@ class MergeOptimizer(Optimizer):
inv_cid
[
i
]
=
r
inv_cid
[
i
]
=
r
for
op
in
env
.
io_toposort
():
for
op
in
env
.
io_toposort
():
# this could be made more robust by having an op.hash() that
# doesn't depend on the inputs but can depend on additional properties
# of the op.
op_cid
=
(
op
.
__class__
,
tuple
([
cid
[
input
]
for
input
in
op
.
inputs
]))
op_cid
=
(
op
.
__class__
,
tuple
([
cid
[
input
]
for
input
in
op
.
inputs
]))
dup
=
inv_cid
.
get
(
op_cid
,
None
)
dup
=
inv_cid
.
get
(
op_cid
,
None
)
if
dup
is
None
:
if
dup
is
None
:
...
@@ -237,124 +360,131 @@ class MergeOptimizer(Optimizer):
...
@@ -237,124 +360,131 @@ class MergeOptimizer(Optimizer):
def
MergeOptMerge
(
opt
):
def
MergeOptMerge
(
opt
):
"""
Returns an Optimizer that merges the graph then applies the
optimizer in opt and then merges the graph again in case the
opt introduced additional similarities.
"""
merger
=
MergeOptimizer
()
merger
=
MergeOptimizer
()
return
SeqOptimizer
([
merger
,
opt
,
merger
])
return
SeqOptimizer
([
merger
,
opt
,
merger
])
class
MultiOptimizer
(
Optimizer
):
### THE FOLLOWING OPTIMIZERS ARE NEITHER USED NOR TESTED BUT PROBABLY WORK AND COULD BE USEFUL ###
def
__init__
(
self
,
**
opts
):
# class MultiOptimizer(Optimizer):
self
.
_opts
=
[]
self
.
ord
=
{}
self
.
name_to_opt
=
{}
self
.
up_to_date
=
True
for
name
,
opt
in
opts
:
self
.
register
(
name
,
opt
,
after
=
[],
before
=
[])
def
register
(
self
,
name
,
opt
,
**
relative
):
# def __init__(self, **opts):
self
.
name_to_opt
[
name
]
=
opt
# self._opts = []
# self.ord = {}
# self.name_to_opt = {}
# self.up_to_date = True
# for name, opt in opts:
# self.register(name, opt, after = [], before = [])
after
=
relative
.
get
(
'after'
,
[])
# def register(self, name, opt, **relative):
if
not
isinstance
(
after
,
(
list
,
tuple
)):
# self.name_to_opt[name] = opt
after
=
[
after
]
before
=
relative
.
get
(
'before
'
,
[])
# after = relative.get('after
', [])
if
not
isinstance
(
before
,
(
list
,
tuple
)):
# if not isinstance(after
, (list, tuple)):
before
=
[
before
]
# after = [after
]
self
.
up_to_date
=
False
# before = relative.get('before', [])
# if not isinstance(before, (list, tuple)):
# before = [before]
if
name
in
self
.
ord
:
# self.up_to_date = False
raise
Exception
(
"Cannot redefine optimization: '
%
s'"
%
name
)
self
.
ord
[
name
]
=
set
(
after
)
# if name in self.ord:
# raise Exception("Cannot redefine optimization: '%s'" % name)
for
postreq
in
before
:
# self.ord[name] = set(after)
self
.
ord
.
setdefault
(
postreq
,
set
())
.
add
(
name
)
def
get_opts
(
self
):
# for postreq in before:
if
not
self
.
up_to_date
:
# self.ord.setdefault(postreq, set()).add(name)
self
.
refresh
()
return
self
.
_opts
def
refresh
(
self
):
# def get_opts(self):
self
.
_opts
=
[
self
.
name_to_opt
[
name
]
for
name
in
utils
.
toposort
(
self
.
ord
)]
# if not self.up_to_date:
self
.
up_to_date
=
True
# self.refresh()
# return self._opts
def
apply
(
self
,
env
):
# def refresh(self
):
for
opt
in
self
.
opts
:
# self._opts = [self.name_to_opt[name] for name in utils.toposort(self.ord)]
opt
.
apply
(
env
)
# self.up_to_date = True
opts
=
property
(
get_opts
)
# def apply(self, env):
# for opt in self.opts:
# opt.apply(env)
# opts = property(get_opts)
class
TaggedMultiOptimizer
(
MultiOptimizer
):
def
__init__
(
self
,
**
opts
):
# class TaggedMultiOptimizer(MultiOptimizer):
self
.
tags
=
{}
MultiOptimizer
.
__init__
(
self
,
**
opts
)
def
register
(
self
,
name
,
opt
,
tags
=
[],
**
relative
):
# def __init__(self, **opts):
tags
=
set
(
tags
)
# self.tags = {}
tags
.
add
(
name
)
# MultiOptimizer.__init__(self, **opts)
self
.
tags
[
opt
]
=
tags
MultiOptimizer
.
register
(
self
,
name
,
opt
,
**
relative
)
def
filter
(
self
,
whitelist
,
blacklist
):
# def register(self, name, opt, tags = [], **relative):
return
[
opt
for
opt
in
self
.
opts
# tags = set(tags)
if
self
.
tags
[
opt
]
.
intersection
(
whitelist
)
# tags.add(name)
and
not
self
.
tags
[
opt
]
.
intersection
(
blacklist
)]
# self.tags[opt] = tags
# MultiOptimizer.register(self, name, opt, **relative)
def
whitelist
(
self
,
*
tags
):
# def filter(self, whitelist, blacklist):
return
[
opt
for
opt
in
self
.
opts
if
self
.
tags
[
opt
]
.
intersection
(
tags
)]
# return [opt for opt in self.opts
# if self.tags[opt].intersection(whitelist)
# and not self.tags[opt].intersection(blacklist)]
def
black
list
(
self
,
*
tags
):
# def white
list(self, *tags):
return
[
opt
for
opt
in
self
.
opts
if
not
self
.
tags
[
opt
]
.
intersection
(
tags
)]
# return [opt for opt in self.opts if
self.tags[opt].intersection(tags)]
# def blacklist(self, *tags):
# return [opt for opt in self.opts if not self.tags[opt].intersection(tags)]
class
TagFilterMultiOptimizer
(
Optimizer
):
def
__init__
(
self
,
all
,
whitelist
=
None
,
blacklist
=
None
):
# class TagFilterMultiOptimizer(Optimizer):
self
.
all
=
all
if
whitelist
is
not
None
:
# def __init__(self, all, whitelist = None, blacklist = None):
self
.
whitelist
=
set
(
whitelist
)
# self.all = all
else
:
self
.
whitelist
=
None
if
blacklist
is
not
None
:
# if whitelist is not None:
self
.
blacklist
=
set
(
blacklist
)
# self.whitelist = set(whitelist)
else
:
# else:
self
.
blacklist
=
set
()
# self.whitelist = None
def
use_whitelist
(
self
,
use
=
True
):
if
self
.
whitelist
is
None
and
use
:
self
.
whitelist
=
set
()
def
allow
(
self
,
*
tags
):
if
self
.
whitelist
is
not
None
:
self
.
whitelist
.
update
(
tags
)
self
.
blacklist
.
difference_update
(
tags
)
def
deny
(
self
,
*
tags
):
if
self
.
whitelist
is
not
None
:
self
.
whitelist
.
difference_update
(
tags
)
self
.
blacklist
.
update
(
tags
)
def
dont_care
(
self
,
*
tags
):
if
self
.
whitelist
is
not
None
:
self
.
whitelist
.
difference_update
(
tags
)
self
.
blacklist
.
difference_update
(
tags
)
def
opts
(
self
):
if
self
.
whitelist
is
not
None
:
return
self
.
all
.
filter
(
self
.
whitelist
,
self
.
blacklist
)
else
:
return
self
.
all
.
blacklist
(
*
[
tag
for
tag
in
self
.
blacklist
])
def
apply
(
self
,
env
):
# if blacklist is not None:
for
opt
in
self
.
opts
():
# self.blacklist = set(blacklist)
opt
.
apply
(
env
)
# else:
# self.blacklist = set()
# def use_whitelist(self, use = True):
# if self.whitelist is None and use:
# self.whitelist = set()
# def allow(self, *tags):
# if self.whitelist is not None:
# self.whitelist.update(tags)
# self.blacklist.difference_update(tags)
# def deny(self, *tags):
# if self.whitelist is not None:
# self.whitelist.difference_update(tags)
# self.blacklist.update(tags)
# def dont_care(self, *tags):
# if self.whitelist is not None:
# self.whitelist.difference_update(tags)
# self.blacklist.difference_update(tags)
# def opts(self):
# if self.whitelist is not None:
# return self.all.filter(self.whitelist, self.blacklist)
# else:
# return self.all.blacklist(*[tag for tag in self.blacklist])
# def apply(self, env):
# for opt in self.opts():
# opt.apply(env)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论