Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9bc72d95
提交
9bc72d95
authored
3月 20, 2008
作者:
Olivier Breuleux
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
documented opt.py and added a callback to some Optimizers so the user can know…
documented opt.py and added a callback to some Optimizers so the user can know about failed attempts to optimize
上级
6008b0ec
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
243 行增加
和
113 行删除
+243
-113
opt.py
gof/opt.py
+243
-113
没有找到文件。
gof/opt.py
浏览文件 @
9bc72d95
...
...
@@ -9,25 +9,51 @@ import ext
class
Optimizer
:
"""
An Optimizer can be applied to an env to transform it.
It can represent an optimization or in general any kind
of transformation you could apply to an env.
"""
def
apply
(
self
,
env
):
"""
Applies the optimization to the provided env. It may
use all the methods defined by the env. If the optimizer
needs to use a certain tool, such as an InstanceFinder,
it should set the __env_require__ field to a list of
what needs to be registered with the Env.
"""
pass
def
optimize
(
self
,
env
):
"""
This is meant as a shortcut to:
env.satisfy(opt)
opt.apply(env)
"""
env
.
satisfy
(
self
)
self
.
apply
(
env
)
def
__call__
(
self
,
env
):
self
.
optimize
(
env
)
DummyOpt
=
Optimizer
()
DummyOpt
.
__doc__
=
"Does nothing."
class
SeqOptimizer
(
Optimizer
,
list
):
"""
Takes a list of Optimizer instances and applies them
sequentially.
"""
def
__init__
(
self
,
*
opts
):
if
len
(
opts
)
==
1
and
isinstance
(
opts
[
0
],
(
list
,
tuple
)):
opts
=
opts
[
0
]
list
.
__init__
(
opts
)
def
apply
(
self
,
env
):
"""
Applies each optimizer in self in turn.
"""
for
optimizer
in
self
:
optimizer
.
optimize
(
env
)
...
...
@@ -40,14 +66,34 @@ class SeqOptimizer(Optimizer, list):
class
LocalOptimizer
(
Optimizer
):
"""
Generic Optimizer class that considers local parts of
the env. It must be subclassed and should override the
following two methods:
* candidates(env) -> returns a set of ops that can be
optimized
* apply_on_op(env, op) -> for each op in candidates,
this function will be called to perform the actual
optimization.
"""
def
candidates
(
self
,
env
):
return
env
.
ops
()
"""
Must return a set of ops that can be optimized.
"""
raise
utils
.
AbstractFunctionError
()
def
apply_on_op
(
self
,
env
,
op
):
raise
Exception
(
"Please override this function."
)
"""
For each op in candidates, this function will be called to
perform the actual optimization.
"""
raise
utils
.
AbstractFunctionError
()
def
apply
(
self
,
env
):
"""
Calls self.apply_on_op(env, op) for each op in self.candidates(env).
"""
for
op
in
self
.
candidates
(
env
):
if
env
.
has_op
(
op
):
self
.
apply_on_op
(
env
,
op
)
...
...
@@ -55,50 +101,95 @@ class LocalOptimizer(Optimizer):
class
OpSpecificOptimizer
(
LocalOptimizer
):
"""
Generic optimizer that applies only to ops of a certain
type. The type in question is accessed through self.opclass.
opclass can also be a class variable of the subclass.
"""
__env_require__
=
toolbox
.
InstanceFinder
opclass
=
Op
def
candidates
(
self
,
env
):
"""
Returns all instances of self.opclass.
"""
return
env
.
get_instances_of
(
self
.
opclass
)
class
OpSubOptimizer
(
Optimizer
):
"""
Replaces all ops of a certain type by ops of another type that
take the same inputs as what they are replacing.
e.g. OpSubOptimizer(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
"""
__env_require__
=
toolbox
.
InstanceFinder
def
__init__
(
self
,
op1
,
op2
):
if
not
op1
.
_default_output_idx
>=
0
:
raise
TypeError
(
"OpSubOptimizer must be used with Op instances that have a default output."
)
# note: op2 must have the same input signature as op1
def
__init__
(
self
,
op1
,
op2
,
failure_callback
=
None
):
"""
op1 and op2 must both be Op subclasses, they must both take
the same number of inputs and they must both have the same
number of outputs.
"""
self
.
op1
=
op1
self
.
op2
=
op2
self
.
failure_callback
=
failure_callback
def
apply
(
self
,
env
):
"""
Replaces all occurrences of self.op1 by instances of self.op2
with the same inputs.
If failure_callback is not None, it will be called whenever
the Optimizer fails to do a replacement in the graph. The
arguments to the callback are: (op1_instance, replacement, exception)
"""
candidates
=
env
.
get_instances_of
(
self
.
op1
)
for
op
in
candidates
:
try
:
# note: only replaces the default 'out' port if it exists
r
=
self
.
op2
(
*
op
.
inputs
)
.
out
env
.
replace
(
op
.
out
,
r
)
except
InconsistencyError
,
e
:
# print "Warning: OpSubOpt failed to transform %s into %s: %s" % (op, self.op2, str(e)) # warning is for debug
repl
=
self
.
op2
(
*
op
.
inputs
)
assert
len
(
op
.
outputs
)
==
len
(
repl
.
outputs
)
for
old
,
new
in
zip
(
op
.
outputs
,
repl
.
outputs
):
env
.
replace
(
old
,
new
)
except
Exception
,
e
:
if
self
.
failure_callback
is
not
None
:
self
.
failure_callback
(
op
,
repl
,
e
)
pass
def
str
(
self
):
return
"
%
s ->
%
s"
%
(
self
.
op1
.
__name__
,
self
.
op2
.
__name__
)
class
OpRemover
(
Optimizer
):
"""
Removes all ops of a certain type by transferring each of its
outputs to the corresponding input.
"""
__env_require__
=
toolbox
.
InstanceFinder
def
__init__
(
self
,
opclass
):
def
__init__
(
self
,
opclass
,
failure_callback
=
None
):
"""
opclass is the class of the ops to remove. It must take as
many inputs as outputs.
"""
self
.
opclass
=
opclass
self
.
failure_callback
=
failure_callback
def
apply
(
self
,
env
):
"""
Removes all occurrences of self.opclass.
If self.failure_callback is not None, it will be called whenever
the Optimizer fails to remove an operation in the graph. The
arguments to the callback are: (opclass_instance, exception)
"""
candidates
=
env
.
get_instances_of
(
self
.
opclass
)
for
op
in
candidates
:
...
...
@@ -106,10 +197,14 @@ class OpRemover(Optimizer):
assert
len
(
op
.
inputs
)
==
len
(
op
.
outputs
)
for
input
,
output
in
zip
(
op
.
inputs
,
op
.
outputs
):
env
.
replace
(
output
,
input
)
except
InconsistencyError
,
e
:
# print "Warning: OpRemover failed to remove %s: %s" % (op, str(e)) # warning is for debug
except
Exception
,
e
:
if
self
.
failure_callback
is
not
None
:
self
.
failure_callback
(
op
,
e
)
pass
def
str
(
self
):
return
"f(
%
s(x)) -> f(x)"
%
self
.
opclass
class
PatternOptimizer
(
OpSpecificOptimizer
):
...
...
@@ -117,13 +212,26 @@ class PatternOptimizer(OpSpecificOptimizer):
Replaces all occurrences of the first pattern by the second pattern.
"""
def
__init__
(
self
,
in_pattern
,
out_pattern
):
def
__init__
(
self
,
in_pattern
,
out_pattern
,
failure_callback
=
None
):
"""
Sets in_pattern for replacement by out_pattern.
self.opclass is set to in_pattern[0] to accelerate the search.
"""
self
.
in_pattern
=
in_pattern
self
.
out_pattern
=
out_pattern
self
.
opclass
=
self
.
in_pattern
[
0
]
self
.
__doc__
=
self
.
__class__
.
__doc__
+
"
\n\n
This instance does: "
+
str
(
self
)
+
"
\n
"
self
.
failure_callback
=
failure_callback
def
apply_on_op
(
self
,
env
,
op
):
"""
Checks if the graph from op corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement.
If self.failure_callback is not None, if there is a match but a
replacement fails to occur, the callback will be called with
arguments (results_to_replace, replacement, exception).
"""
def
match
(
pattern
,
expr
,
u
,
first
=
False
):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
...
...
@@ -168,8 +276,9 @@ class PatternOptimizer(OpSpecificOptimizer):
if
not
isinstance
(
p
,
str
):
new
=
new
.
out
env
.
replace
(
op
.
out
,
new
)
except
InconsistencyError
,
e
:
# print "Warning: '%s' failed to apply on %s: %s" % (self, op, str(e)) # warning is for debug
except
Exception
,
e
:
if
self
.
failure_callback
is
not
None
:
self
.
failure_callback
(
op
.
out
,
new
,
e
)
pass
...
...
@@ -183,6 +292,11 @@ class PatternOptimizer(OpSpecificOptimizer):
class
ConstantFinder
(
Optimizer
):
"""
Sets as constant every orphan that is not destroyed
and sets as indestructible every input that is not
destroyed.
"""
def
apply
(
self
,
env
):
if
env
.
has_feature
(
ext
.
DestroyHandler
):
...
...
@@ -202,6 +316,12 @@ class ConstantFinder(Optimizer):
class
MergeOptimizer
(
Optimizer
):
"""
Merges parts of the graph that are identical, i.e. parts that
take the same inputs and carry out the asme computations so we
can avoid doing them more than once. Also merges results that
are constant.
"""
def
apply
(
self
,
env
):
cid
=
{}
...
...
@@ -220,6 +340,9 @@ class MergeOptimizer(Optimizer):
inv_cid
[
i
]
=
r
for
op
in
env
.
io_toposort
():
# this could be made more robust by having an op.hash() that
# doesn't depend on the inputs but can depend on additional properties
# of the op.
op_cid
=
(
op
.
__class__
,
tuple
([
cid
[
input
]
for
input
in
op
.
inputs
]))
dup
=
inv_cid
.
get
(
op_cid
,
None
)
if
dup
is
None
:
...
...
@@ -237,124 +360,131 @@ class MergeOptimizer(Optimizer):
def
MergeOptMerge
(
opt
):
"""
Returns an Optimizer that merges the graph then applies the
optimizer in opt and then merges the graph again in case the
opt introduced additional similarities.
"""
merger
=
MergeOptimizer
()
return
SeqOptimizer
([
merger
,
opt
,
merger
])
class
MultiOptimizer
(
Optimizer
):
### THE FOLLOWING OPTIMIZERS ARE NEITHER USED NOR TESTED BUT PROBABLY WORK AND COULD BE USEFUL ###
# class MultiOptimizer(Optimizer):
def
__init__
(
self
,
**
opts
):
self
.
_opts
=
[]
self
.
ord
=
{}
self
.
name_to_opt
=
{}
self
.
up_to_date
=
True
for
name
,
opt
in
opts
:
self
.
register
(
name
,
opt
,
after
=
[],
before
=
[])
#
def __init__(self, **opts):
#
self._opts = []
#
self.ord = {}
#
self.name_to_opt = {}
#
self.up_to_date = True
#
for name, opt in opts:
#
self.register(name, opt, after = [], before = [])
def
register
(
self
,
name
,
opt
,
**
relative
):
self
.
name_to_opt
[
name
]
=
opt
#
def register(self, name, opt, **relative):
#
self.name_to_opt[name] = opt
after
=
relative
.
get
(
'after'
,
[])
if
not
isinstance
(
after
,
(
list
,
tuple
)):
after
=
[
after
]
#
after = relative.get('after', [])
#
if not isinstance(after, (list, tuple)):
#
after = [after]
before
=
relative
.
get
(
'before'
,
[])
if
not
isinstance
(
before
,
(
list
,
tuple
)):
before
=
[
before
]
#
before = relative.get('before', [])
#
if not isinstance(before, (list, tuple)):
#
before = [before]
self
.
up_to_date
=
False
#
self.up_to_date = False
if
name
in
self
.
ord
:
raise
Exception
(
"Cannot redefine optimization: '
%
s'"
%
name
)
#
if name in self.ord:
#
raise Exception("Cannot redefine optimization: '%s'" % name)
self
.
ord
[
name
]
=
set
(
after
)
#
self.ord[name] = set(after)
for
postreq
in
before
:
self
.
ord
.
setdefault
(
postreq
,
set
())
.
add
(
name
)
#
for postreq in before:
#
self.ord.setdefault(postreq, set()).add(name)
def
get_opts
(
self
):
if
not
self
.
up_to_date
:
self
.
refresh
()
return
self
.
_opts
#
def get_opts(self):
#
if not self.up_to_date:
#
self.refresh()
#
return self._opts
def
refresh
(
self
):
self
.
_opts
=
[
self
.
name_to_opt
[
name
]
for
name
in
utils
.
toposort
(
self
.
ord
)]
self
.
up_to_date
=
True
#
def refresh(self):
#
self._opts = [self.name_to_opt[name] for name in utils.toposort(self.ord)]
#
self.up_to_date = True
def
apply
(
self
,
env
):
for
opt
in
self
.
opts
:
opt
.
apply
(
env
)
#
def apply(self, env):
#
for opt in self.opts:
#
opt.apply(env)
opts
=
property
(
get_opts
)
#
opts = property(get_opts)
class
TaggedMultiOptimizer
(
MultiOptimizer
):
#
class TaggedMultiOptimizer(MultiOptimizer):
def
__init__
(
self
,
**
opts
):
self
.
tags
=
{}
MultiOptimizer
.
__init__
(
self
,
**
opts
)
#
def __init__(self, **opts):
#
self.tags = {}
#
MultiOptimizer.__init__(self, **opts)
def
register
(
self
,
name
,
opt
,
tags
=
[],
**
relative
):
tags
=
set
(
tags
)
tags
.
add
(
name
)
self
.
tags
[
opt
]
=
tags
MultiOptimizer
.
register
(
self
,
name
,
opt
,
**
relative
)
#
def register(self, name, opt, tags = [], **relative):
#
tags = set(tags)
#
tags.add(name)
#
self.tags[opt] = tags
#
MultiOptimizer.register(self, name, opt, **relative)
def
filter
(
self
,
whitelist
,
blacklist
):
return
[
opt
for
opt
in
self
.
opts
if
self
.
tags
[
opt
]
.
intersection
(
whitelist
)
and
not
self
.
tags
[
opt
]
.
intersection
(
blacklist
)]
#
def filter(self, whitelist, blacklist):
#
return [opt for opt in self.opts
#
if self.tags[opt].intersection(whitelist)
#
and not self.tags[opt].intersection(blacklist)]
def
whitelist
(
self
,
*
tags
):
return
[
opt
for
opt
in
self
.
opts
if
self
.
tags
[
opt
]
.
intersection
(
tags
)]
#
def whitelist(self, *tags):
#
return [opt for opt in self.opts if self.tags[opt].intersection(tags)]
def
blacklist
(
self
,
*
tags
):
return
[
opt
for
opt
in
self
.
opts
if
not
self
.
tags
[
opt
]
.
intersection
(
tags
)]
#
def blacklist(self, *tags):
#
return [opt for opt in self.opts if not self.tags[opt].intersection(tags)]
class
TagFilterMultiOptimizer
(
Optimizer
):
#
class TagFilterMultiOptimizer(Optimizer):
def
__init__
(
self
,
all
,
whitelist
=
None
,
blacklist
=
None
):
self
.
all
=
all
#
def __init__(self, all, whitelist = None, blacklist = None):
#
self.all = all
if
whitelist
is
not
None
:
self
.
whitelist
=
set
(
whitelist
)
else
:
self
.
whitelist
=
None
if
blacklist
is
not
None
:
self
.
blacklist
=
set
(
blacklist
)
else
:
self
.
blacklist
=
set
()
def
use_whitelist
(
self
,
use
=
True
):
if
self
.
whitelist
is
None
and
use
:
self
.
whitelist
=
set
()
def
allow
(
self
,
*
tags
):
if
self
.
whitelist
is
not
None
:
self
.
whitelist
.
update
(
tags
)
self
.
blacklist
.
difference_update
(
tags
)
def
deny
(
self
,
*
tags
):
if
self
.
whitelist
is
not
None
:
self
.
whitelist
.
difference_update
(
tags
)
self
.
blacklist
.
update
(
tags
)
def
dont_care
(
self
,
*
tags
):
if
self
.
whitelist
is
not
None
:
self
.
whitelist
.
difference_update
(
tags
)
self
.
blacklist
.
difference_update
(
tags
)
def
opts
(
self
):
if
self
.
whitelist
is
not
None
:
return
self
.
all
.
filter
(
self
.
whitelist
,
self
.
blacklist
)
else
:
return
self
.
all
.
blacklist
(
*
[
tag
for
tag
in
self
.
blacklist
])
#
if whitelist is not None:
#
self.whitelist = set(whitelist)
#
else:
#
self.whitelist = None
#
if blacklist is not None:
#
self.blacklist = set(blacklist)
#
else:
#
self.blacklist = set()
#
def use_whitelist(self, use = True):
#
if self.whitelist is None and use:
#
self.whitelist = set()
#
def allow(self, *tags):
#
if self.whitelist is not None:
#
self.whitelist.update(tags)
#
self.blacklist.difference_update(tags)
#
def deny(self, *tags):
#
if self.whitelist is not None:
#
self.whitelist.difference_update(tags)
#
self.blacklist.update(tags)
#
def dont_care(self, *tags):
#
if self.whitelist is not None:
#
self.whitelist.difference_update(tags)
#
self.blacklist.difference_update(tags)
#
def opts(self):
#
if self.whitelist is not None:
#
return self.all.filter(self.whitelist, self.blacklist)
#
else:
#
return self.all.blacklist(*[tag for tag in self.blacklist])
def
apply
(
self
,
env
):
for
opt
in
self
.
opts
():
opt
.
apply
(
env
)
#
def apply(self, env):
#
for opt in self.opts():
#
opt.apply(env)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论