Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
55c5d0b3
提交
55c5d0b3
authored
3月 16, 2009
作者:
Olivier Breuleux
浏览文件
操作
浏览文件
下载
差异文件
merge
上级
b914ef2b
09bce224
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
218 行增加
和
25 行删除
+218
-25
optimization.txt
doc/tutorials/advanced/optimization.txt
+37
-0
opt.py
theano/tensor/opt.py
+181
-25
没有找到文件。
doc/tutorials/advanced/optimization.txt
浏览文件 @
55c5d0b3
...
@@ -211,12 +211,49 @@ for this somewhere in the future.
...
@@ -211,12 +211,49 @@ for this somewhere in the future.
Local optimization
Local optimization
------------------
------------------
The local version of the above code would be the following:
.. code-block:: python
class LocalSimplify(gof.LocalOptimizer):
def transform(self, node):
if node.op == div:
x, y = node.inputs
if x.owner and x.owner.op == mul:
a, b = x.owner.inputs
if y == a:
return [b]
elif y == b:
return [a]
return False
local_simplify = LocalSimplify()
The definition of transform is the inner loop of the global optimizer,
where the node is given as argument. If no changes are to be made,
False must be returned. Else, a list of what to replace the node's
outputs with must be returned.
In order to apply the local optimizer we must use it in conjunction
with a :ref:`navigator`. You can follow this :ref:`link <navigator>`
for further documentation, but basically a Navigator is a global
optimizer that loops through all nodes in the graph (or a well-defined
subset of them) and applies one or several local optimizers on them.
>>> x = double('x')
>>> y = double('y')
>>> z = double('z')
>>> a = add(z, mul(div(mul(y, x), y), div(z, x)))
>>> e = gof.Env([x, y, z], [a])
>>> e
[add(z, mul(div(mul(y, x), y), div(z, x)))]
>>> simplify = gof.TopoOptimizer([local_simplify])
>>> simplify.optimize(e)
>>> e
[add(z, mul(x, div(z, x)))]
TODO: test this.
The optimization database (optdb)
The optimization database (optdb)
...
...
theano/tensor/opt.py
浏览文件 @
55c5d0b3
...
@@ -381,19 +381,21 @@ class Canonizer(gof.LocalOptimizer):
...
@@ -381,19 +381,21 @@ class Canonizer(gof.LocalOptimizer):
Usage: Canonizer(main, inverse, reciprocal, calculate)
Usage: Canonizer(main, inverse, reciprocal, calculate)
* main: a suitable Op class that is commutative, associative and takes
* main: a suitable Op class that is commutative, associative and
one to an arbitrary number of inputs, e.g. Add or Mul
takes one to an arbitrary number of inputs, e.g. add or
mul
* inverse: an Op class such that inverse(main(x, y), y) == x
* inverse: an Op class such that inverse(main(x, y), y) == x
e.g. Sub or Div
e.g. sub or div
* reciprocal: a function such that main(x, reciprocal(y)) == inverse(x, y)
* reciprocal: a function such that main(x, reciprocal(y)) ==
e.g. Neg or Inv
inverse(x, y) e.g. neg or inv
* calculate: function that takes a list of numpy.ndarray instances for
* calculate: function that takes a list of numpy.ndarray instances
the numerator, another list for the denumerator, and calculates
for the numerator, another list for the denumerator,
inverse(main(*num), main(*denum)). It takes a keyword argument,
and calculates inverse(main(*num), main(*denum)). It
aslist. If True, the value should be returned as a list of one
takes a keyword argument, aslist. If True, the value
element, unless the value is such that value = main(). In that
should be returned as a list of one element, unless
case, the return value should be an empty list.
the value is such that value = main(). In that case,
the return value should be an empty list.
The result is a local_optimizer. It is best used with a TopoOptimizer in
The result is a local_optimizer. It is best used with a TopoOptimizer in
in_to_out order.
in_to_out order.
...
@@ -422,38 +424,114 @@ class Canonizer(gof.LocalOptimizer):
...
@@ -422,38 +424,114 @@ class Canonizer(gof.LocalOptimizer):
self
.
use_reciprocal
=
use_reciprocal
self
.
use_reciprocal
=
use_reciprocal
def
tracks
(
self
):
def
tracks
(
self
):
#return [[None], [None, None], [None]*3, [None]*4, [None]*5]
return
[[
self
.
main
,
None
],
[
self
.
inverse
,
None
],
[
self
.
reciprocal
,
None
]]
return
[[
self
.
main
,
None
],
[
self
.
inverse
,
None
],
[
self
.
reciprocal
,
None
]]
def
get_num_denum
(
self
,
input
):
def
get_num_denum
(
self
,
input
):
"""
This extract two lists, num and denum, such that the input is:
self.inverse(self.main(*num), self.main(*denum)). It returns
the two lists in a (num, denum) pair.
For example, for main, inverse and reciprocal = *, / and inv(),
input -> returned value (num, denum)
x*y -> ([x, y], [])
inv(x) -> ([], [x])
inv(x) * inv(y) -> ([], [x, y])
x*y/z -> ([x, y], [z])
log(x) / y * (z + x) / y -> ([log(x), z + x], [y, y])
(((a / b) * c) / d) -> ([a, c], [b, d])
a / (b / c) -> ([a, c], [b])
log(x) -> ([log(x)], [])
x**y -> ([x**y], [])
"""
if
input
.
owner
is
None
or
input
.
owner
.
op
not
in
[
self
.
main
,
self
.
inverse
,
self
.
reciprocal
]:
if
input
.
owner
is
None
or
input
.
owner
.
op
not
in
[
self
.
main
,
self
.
inverse
,
self
.
reciprocal
]:
if
input
.
owner
and
isinstance
(
input
.
owner
.
op
,
T
.
DimShuffle
):
if
input
.
owner
and
isinstance
(
input
.
owner
.
op
,
T
.
DimShuffle
):
dsn
=
input
.
owner
# If input is a DimShuffle of some input which does something like this:
dsop
=
dsn
.
op
# * change a vector of length N into a 1xN row matrix
dsi0
=
dsn
.
inputs
[
0
]
# * change a scalar into a 1x1x1 tensor
# * in general, complete the shape of a tensor with broadcastable 1s to the *left*
# Then we will simply discard the DimShuffle and return the num/denum of its input
dsn
=
input
.
owner
# dimshuffle node
dsop
=
dsn
.
op
# dimshuffle op
dsi0
=
dsn
.
inputs
[
0
]
# the first input of the dimshuffle i.e. the ndarray to redim
# The compatible order is a DimShuffle "new_order" of the form:
# ('x', ..., 'x', 0, 1, 2, ..., dimshuffle_input.type.ndim)
# That kind of DimShuffle only adds broadcastable
# dimensions on the left, without discarding any
# existing broadcastable dimension and is inserted
# automatically by Elemwise when the inputs have
# different numbers of dimensions (hence why we can
# discard its information - we know we can retrieve it
# later on).
compatible_order
=
(
'x'
,)
*
(
input
.
type
.
ndim
-
dsi0
.
type
.
ndim
)
+
tuple
(
range
(
dsi0
.
type
.
ndim
))
compatible_order
=
(
'x'
,)
*
(
input
.
type
.
ndim
-
dsi0
.
type
.
ndim
)
+
tuple
(
range
(
dsi0
.
type
.
ndim
))
if
dsop
.
new_order
==
compatible_order
:
if
dsop
.
new_order
==
compatible_order
:
# If the "new_order" is the one we recognize,
# we return the num_denum of the dimshuffled input.
return
self
.
get_num_denum
(
input
.
owner
.
inputs
[
0
])
return
self
.
get_num_denum
(
input
.
owner
.
inputs
[
0
])
else
:
else
:
# This is when the input isn't produced by main, inverse or reciprocal.
return
[
input
],
[]
return
[
input
],
[]
else
:
else
:
return
[
input
],
[]
return
[
input
],
[]
num
=
[]
num
=
[]
denum
=
[]
denum
=
[]
parent
=
input
.
owner
parent
=
input
.
owner
# We get the (num, denum) pairs for each input
pairs
=
[
self
.
get_num_denum
(
input
)
for
input
in
parent
.
inputs
]
pairs
=
[
self
.
get_num_denum
(
input
)
for
input
in
parent
.
inputs
]
if
parent
.
op
==
self
.
main
:
if
parent
.
op
==
self
.
main
:
# If we have main(x, y), numx, denumx, numy and denumy
# then num is concat(numx, numy) and denum is concat(denumx, denumy)
# note that main() can have any number of arguments >= 0
# concat is list concatenation
num
=
reduce
(
list
.
__iadd__
,
map
(
operator
.
itemgetter
(
0
),
pairs
))
num
=
reduce
(
list
.
__iadd__
,
map
(
operator
.
itemgetter
(
0
),
pairs
))
denum
=
reduce
(
list
.
__iadd__
,
map
(
operator
.
itemgetter
(
1
),
pairs
))
denum
=
reduce
(
list
.
__iadd__
,
map
(
operator
.
itemgetter
(
1
),
pairs
))
elif
parent
.
op
==
self
.
inverse
:
elif
parent
.
op
==
self
.
inverse
:
# If we have inverse(x, y), numx, denumx, numy and denumy
# then num is concat(numx, denumy) and denum is concat(denumx, numy)
# note that inverse() is binary
num
=
pairs
[
0
][
0
]
+
pairs
[
1
][
1
]
num
=
pairs
[
0
][
0
]
+
pairs
[
1
][
1
]
denum
=
pairs
[
0
][
1
]
+
pairs
[
1
][
0
]
denum
=
pairs
[
0
][
1
]
+
pairs
[
1
][
0
]
elif
parent
.
op
==
self
.
reciprocal
:
elif
parent
.
op
==
self
.
reciprocal
:
# If we have reciprocal(x), numx, denumx
# then num is denumx and denum is numx
# note that reciprocal() is unary
num
=
pairs
[
0
][
1
]
num
=
pairs
[
0
][
1
]
denum
=
pairs
[
0
][
0
]
denum
=
pairs
[
0
][
0
]
return
num
,
denum
return
num
,
denum
def
merge_num_denum
(
self
,
num
,
denum
):
def
merge_num_denum
(
self
,
num
,
denum
):
"""
Utility function which takes two lists, num and denum, and
returns something which is equivalent to inverse(main(*num),
main(*denum)), but depends on the length of num and the length
of denum (in order to minimize the number of operations).
Let n = len(num) and d = len(denum):
n=0, d=0: neutral element (given by self.calculate([], []))
(for example, this would be 0 if main is addition
and 1 if main is multiplication)
n=1, d=0: num[0]
n=0, d=1: reciprocal(denum[0])
n=1, d=1: inverse(num[0], denum[0])
n=0, d>1: reciprocal(main(*denum))
n>1, d=0: main(*num)
n=1, d>1: inverse(num[0], main(*denum))
n>1, d=1: inverse(main(*num), denum[0])
n>1, d>1: inverse(main(*num), main(*denum))
Given the values of n and d to which they are associated, all
of the above are equivalent to:
inverse(main(*num), main(*denum))
"""
ln
,
ld
=
len
(
num
),
len
(
denum
)
ln
,
ld
=
len
(
num
),
len
(
denum
)
if
not
ln
and
not
ld
:
if
not
ln
and
not
ld
:
return
T
.
as_tensor
(
self
.
calculate
([],
[]))
return
T
.
as_tensor
(
self
.
calculate
([],
[]))
...
@@ -475,20 +553,52 @@ class Canonizer(gof.LocalOptimizer):
...
@@ -475,20 +553,52 @@ class Canonizer(gof.LocalOptimizer):
@classmethod
@classmethod
def
get_constant
(
cls
,
v
):
def
get_constant
(
cls
,
v
):
"""
Returns a numeric constant if v is a gof.Constant or, well, a
numeric constant. If v is a plain Result, returns None.
"""
if
isinstance
(
v
,
N
.
generic
):
if
isinstance
(
v
,
N
.
generic
):
return
v
return
v
# doesn't the not hasattr() condition below catch this?
if
isinstance
(
v
,
gof
.
Constant
):
if
isinstance
(
v
,
gof
.
Constant
):
return
v
.
data
return
v
.
data
if
not
hasattr
(
v
,
'owner'
):
if
not
hasattr
(
v
,
'owner'
):
return
v
return
v
if
v
.
owner
and
isinstance
(
v
.
owner
.
op
,
DimShuffle
):
return
cls
.
get_constant
(
v
.
owner
.
inputs
[
0
])
# NOTE: the following code was buggy, but while I was fixing
# it I realized it is probably made useless by constant
# folding, so screw that. Commented-out code is the half-fixed
# version.
# if v.owner and isinstance(v.owner.op, DimShuffle):
# # see the comments in get_num_denum
# # TODO: this should apply the
# dsn = v.owner
# dsop = dsn.op
# dsi0 = dsn.inputs[0]
# compatible_order = ('x',) * (input.type.ndim - dsi0.type.ndim) + tuple(range(dsi0.type.ndim))
# if dsop.new_order == compatible_order:
# return cls.get_constant(v.owner.inputs[0])
return
None
return
None
def
simplify
(
self
,
num
,
denum
):
def
simplify
(
self
,
num
,
denum
):
"""
Shorthand for: self.simplify_constants(*self.simplify_factors(num, denum))
"""
return
self
.
simplify_constants
(
*
self
.
simplify_factors
(
num
,
denum
))
return
self
.
simplify_constants
(
*
self
.
simplify_factors
(
num
,
denum
))
def
simplify_factors
(
self
,
num
,
denum
):
def
simplify_factors
(
self
,
num
,
denum
):
"""
For any Result r which is both in num and denum, removes it
from both lists. Modifies the lists inplace. Returns the
modified lists. For example:
[x], [x] -> [], []
[x, y], [x] -> [y], []
[a, b], [c, d] -> [a, b], [c, d]
"""
for
v
in
list
(
num
):
for
v
in
list
(
num
):
if
v
in
denum
:
if
v
in
denum
:
num
.
remove
(
v
)
num
.
remove
(
v
)
...
@@ -496,28 +606,64 @@ class Canonizer(gof.LocalOptimizer):
...
@@ -496,28 +606,64 @@ class Canonizer(gof.LocalOptimizer):
return
num
,
denum
return
num
,
denum
def
simplify_constants
(
self
,
orig_num
,
orig_denum
):
def
simplify_constants
(
self
,
orig_num
,
orig_denum
):
"""
Finds all constants in orig_num and orig_denum (using
get_constant) and puts them together into a single
constant. The constant is inserted as the first element of the
numerator. If the constant is the neutral element, it is
removed from the numerator. Examples:
Let main be multiplication:
[2, 3, x], [] -> [6, x], []
[x, y, 2], [4, z] -> [0.5, x, y], [z]
[x, 2, y], [z, 2] -> [x, y], [z]
"""
# Lists representing the numerator and denumerator
num
,
denum
=
list
(
orig_num
),
list
(
orig_denum
)
num
,
denum
=
list
(
orig_num
),
list
(
orig_denum
)
# Lists representing the *constant* elements of num and denum
numct
,
denumct
=
[],
[]
numct
,
denumct
=
[],
[]
ncc
,
dcc
=
0
,
0
for
v
in
orig_num
:
for
v
in
orig_num
:
ct
=
self
.
get_constant
(
v
)
ct
=
self
.
get_constant
(
v
)
if
ct
is
not
None
:
if
ct
is
not
None
:
ncc
+=
1
# We found a constant in the numerator!
# We remove it from num
num
.
remove
(
v
)
num
.
remove
(
v
)
# We add it to numct
numct
.
append
(
ct
)
numct
.
append
(
ct
)
for
v
in
orig_denum
:
for
v
in
orig_denum
:
ct
=
self
.
get_constant
(
v
)
ct
=
self
.
get_constant
(
v
)
if
ct
is
not
None
:
if
ct
is
not
None
:
dcc
+=
1
denum
.
remove
(
v
)
denum
.
remove
(
v
)
denumct
.
append
(
ct
)
denumct
.
append
(
ct
)
if
self
.
use_reciprocal
or
num
:
if
self
.
use_reciprocal
or
num
:
# This will calculate either:
# [inverse(main(*numct), main(*denumct))]
# [] - if inverse(main(*numct), main(*denumct)) is the neutral element
ct
=
self
.
calculate
(
numct
,
denumct
,
aslist
=
True
)
ct
=
self
.
calculate
(
numct
,
denumct
,
aslist
=
True
)
else
:
else
:
# This happens if we don't allow the reciprocal and the
# numerator is empty. That means we will need to represent
# reciprocal(x) like inverse(neutral_element, x) so
# we can't allow ct == []
# TODO: why is this branch needed when merge_num_denum does it for us?
ct
=
[
self
.
calculate
(
numct
,
denumct
,
aslist
=
False
)]
ct
=
[
self
.
calculate
(
numct
,
denumct
,
aslist
=
False
)]
# if len(ct) and ncc == 1 and dcc == 0:
# TODO: why are we not wrapping ct in a gof.Constant right now?
# return orig_num, orig_denum
if
orig_num
and
len
(
numct
)
==
1
and
ct
and
N
.
all
(
ct
==
self
.
get_constant
(
orig_num
[
0
])):
if
orig_num
and
len
(
numct
)
==
1
and
len
(
denumct
)
==
0
and
ct
and
N
.
all
(
ct
==
self
.
get_constant
(
orig_num
[
0
])):
# this is an important trick :( if it so happens that:
# * there's exactly one constant on the numerator and none on the denominator
# * it's not the neutral element (ct is an empty list in that case)
# * the constant is the same as the first argument in the numerator
# Then we return very exactly the original num/denum
# If we don't do that the optimizer will just loop infinitely because
# it will not catch on that there are no changes to be made and everytime
# it will want to replace something by the same thing...
return
orig_num
,
orig_denum
return
orig_num
,
orig_denum
return
ct
+
num
,
denum
return
ct
+
num
,
denum
...
@@ -528,6 +674,11 @@ class Canonizer(gof.LocalOptimizer):
...
@@ -528,6 +674,11 @@ class Canonizer(gof.LocalOptimizer):
if
op
not
in
[
self
.
main
,
self
.
inverse
,
self
.
reciprocal
]:
if
op
not
in
[
self
.
main
,
self
.
inverse
,
self
.
reciprocal
]:
return
False
return
False
# I'm not sure if this is actually needed but the following
# block of code puts into "reorg" whether or not we are going
# to change the structure of the graph. For example if we have
# inverse operating on an inverse, we can make it so that only
# one inverse is used, so we'll reorganize that.
iops
=
set
(
input
.
owner
.
op
for
input
in
inputs
if
input
.
owner
)
iops
=
set
(
input
.
owner
.
op
for
input
in
inputs
if
input
.
owner
)
reorg
=
False
reorg
=
False
if
op
==
self
.
main
:
if
op
==
self
.
main
:
...
@@ -537,8 +688,11 @@ class Canonizer(gof.LocalOptimizer):
...
@@ -537,8 +688,11 @@ class Canonizer(gof.LocalOptimizer):
elif
op
==
self
.
reciprocal
:
elif
op
==
self
.
reciprocal
:
reorg
=
len
(
iops
.
intersection
([
self
.
inverse
,
self
.
reciprocal
]))
!=
0
reorg
=
len
(
iops
.
intersection
([
self
.
inverse
,
self
.
reciprocal
]))
!=
0
# just in case
assert
len
(
node
.
outputs
)
==
1
assert
len
(
node
.
outputs
)
==
1
# Here we make the canonical version of the graph around this node
# See the documentation of get_num_denum and simplify
orig_num
,
orig_denum
=
self
.
get_num_denum
(
node
.
outputs
[
0
])
orig_num
,
orig_denum
=
self
.
get_num_denum
(
node
.
outputs
[
0
])
num
,
denum
=
list
(
orig_num
),
list
(
orig_denum
)
num
,
denum
=
list
(
orig_num
),
list
(
orig_denum
)
num
,
denum
=
self
.
simplify
(
num
,
denum
)
num
,
denum
=
self
.
simplify
(
num
,
denum
)
...
@@ -547,6 +701,8 @@ class Canonizer(gof.LocalOptimizer):
...
@@ -547,6 +701,8 @@ class Canonizer(gof.LocalOptimizer):
return
len
(
x
)
==
len
(
y
)
and
all
(
N
.
all
(
xe
==
ye
)
for
xe
,
ye
in
zip
(
x
,
y
))
return
len
(
x
)
==
len
(
y
)
and
all
(
N
.
all
(
xe
==
ye
)
for
xe
,
ye
in
zip
(
x
,
y
))
if
not
reorg
and
same
(
orig_num
,
num
)
and
same
(
orig_denum
,
denum
):
if
not
reorg
and
same
(
orig_num
,
num
)
and
same
(
orig_denum
,
denum
):
# We return False if there are no changes
# TODO: what's the purpose of reorg? isn't same() sufficient?
return
False
return
False
new
=
self
.
merge_num_denum
(
num
,
denum
)
new
=
self
.
merge_num_denum
(
num
,
denum
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论