Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
bf3df169
提交
bf3df169
authored
4月 17, 2008
作者:
Olivier Breuleux
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fixes here and there, doc
上级
7158e6d8
显示空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
166 行增加
和
213 行删除
+166
-213
_test_scalar_opt.py
_test_scalar_opt.py
+4
-3
base_tensor.py
base_tensor.py
+5
-1
opt.py
opt.py
+82
-46
scalar_opt.py
scalar_opt.py
+71
-0
tensor.py
tensor.py
+4
-163
没有找到文件。
_test_scalar_opt.py
浏览文件 @
bf3df169
...
@@ -42,7 +42,7 @@ class _test_opts(unittest.TestCase):
...
@@ -42,7 +42,7 @@ class _test_opts(unittest.TestCase):
# x, y, z = inputs()
# x, y, z = inputs()
# a, b, c, d = more_inputs()
# a, b, c, d = more_inputs()
# # e = (2.0 * x) / (2.0 * y)
# # e = (2.0 * x) / (2.0 * y)
#
#
e = (2.0 * x) / (4.0 * y)
# e = (2.0 * x) / (4.0 * y)
# # e = x / (y / z)
# # e = x / (y / z)
# # e = (x * y) / x
# # e = (x * y) / x
# # e = (x / y) * (y / z) * (z / x)
# # e = (x / y) * (y / z) * (z / x)
...
@@ -71,11 +71,12 @@ class _test_opts(unittest.TestCase):
...
@@ -71,11 +71,12 @@ class _test_opts(unittest.TestCase):
# # e = (a - b) + (b - c) + (c - d)
# # e = (a - b) + (b - c) + (c - d)
# # e = x + -y
# # e = x + -y
# # e = a - b - b + a + b + c + b - c
# # e = a - b - b + a + b + c + b - c
# e = x + log(y) - x + y
# # e = x + log(y) - x + y
# e = 2.0 + x + 4.0
# g = Env([x, y, z, a, b, c, d], [e])
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# print g
# gof.ConstantFinder().optimize(g)
# gof.ConstantFinder().optimize(g)
# addfn = lambda *inputs:
reduce(lambda x, y: x + y, (0,) +
inputs)
# addfn = lambda *inputs:
sum(
inputs)
# subfn = lambda x, y: x - y
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
...
...
base_tensor.py
浏览文件 @
bf3df169
...
@@ -58,7 +58,7 @@ class BaseTensor(Result):
...
@@ -58,7 +58,7 @@ class BaseTensor(Result):
# filter
# filter
#
#
def
filter
(
self
,
arr
):
def
filter
(
self
,
arr
):
"""
cast to an L{numpy.ndarray} and ensure arr has correct rank, shape
"""
"""
Cast to an L{numpy.ndarray} and ensure arr has correct rank and shape.
"""
if
not
(
isinstance
(
arr
,
numpy
.
ndarray
)
\
if
not
(
isinstance
(
arr
,
numpy
.
ndarray
)
\
and
arr
.
dtype
==
self
.
dtype
):
and
arr
.
dtype
==
self
.
dtype
):
arr
=
numpy
.
asarray
(
arr
,
dtype
=
self
.
dtype
)
arr
=
numpy
.
asarray
(
arr
,
dtype
=
self
.
dtype
)
...
@@ -102,6 +102,9 @@ class BaseTensor(Result):
...
@@ -102,6 +102,9 @@ class BaseTensor(Result):
# Description for constant folding
# Description for constant folding
#
#
def
desc
(
self
):
def
desc
(
self
):
"""
Returns a hashable description of this BaseTensor.
"""
if
self
.
data
is
not
None
:
if
self
.
data
is
not
None
:
return
(
BaseTensor
,
self
.
dtype
,
self
.
broadcastable
,
self
.
data
.
data
[:])
return
(
BaseTensor
,
self
.
dtype
,
self
.
broadcastable
,
self
.
data
.
data
[:])
else
:
else
:
...
@@ -210,6 +213,7 @@ class BaseTensor(Result):
...
@@ -210,6 +213,7 @@ class BaseTensor(Result):
};
};
"""
"""
return
template
%
dict
(
nbits
=
64
,
half_nbits
=
32
)
+
template
%
dict
(
nbits
=
128
,
half_nbits
=
64
)
return
template
%
dict
(
nbits
=
64
,
half_nbits
=
32
)
+
template
%
dict
(
nbits
=
128
,
half_nbits
=
64
)
# todo: use C templating
############################
############################
...
...
opt.py
浏览文件 @
bf3df169
...
@@ -7,6 +7,19 @@ import scalar
...
@@ -7,6 +7,19 @@ import scalar
class
InplaceOptimizer
(
opt
.
OpSpecificOptimizer
):
class
InplaceOptimizer
(
opt
.
OpSpecificOptimizer
):
"""
Usage: inplace_optimizer.optimize(env)
Attempts to replace all Broadcast ops by versions of them
that operate inplace. It operates greedily: for each Broadcast
Op that is encountered, for each output, tries each input to
see if it can operate inplace on that input. If so, makes the
change and goes to the next output or Broadcast Op.
Examples:
x + y + z -> x += y += z
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
"""
opclass
=
Broadcast
opclass
=
Broadcast
...
@@ -24,6 +37,7 @@ class InplaceOptimizer(opt.OpSpecificOptimizer):
...
@@ -24,6 +37,7 @@ class InplaceOptimizer(opt.OpSpecificOptimizer):
continue
continue
candidate_inputs
.
remove
(
candidate_input
)
candidate_inputs
.
remove
(
candidate_input
)
op
=
new_op
op
=
new_op
baseline
=
inplace_pattern
break
break
inplace_optimizer
=
InplaceOptimizer
()
inplace_optimizer
=
InplaceOptimizer
()
...
@@ -32,6 +46,8 @@ inplace_optimizer = InplaceOptimizer()
...
@@ -32,6 +46,8 @@ inplace_optimizer = InplaceOptimizer()
class
DimShuffleLifter
(
opt
.
Optimizer
):
class
DimShuffleLifter
(
opt
.
Optimizer
):
"""
"""
Usage: lift_dimshuffle.optimize(env)
"Lifts" DimShuffle through Broadcast operations and merges
"Lifts" DimShuffle through Broadcast operations and merges
consecutive DimShuffles. Basically, applies the following
consecutive DimShuffles. Basically, applies the following
transformations on the whole graph:
transformations on the whole graph:
...
@@ -47,9 +63,6 @@ class DimShuffleLifter(opt.Optimizer):
...
@@ -47,9 +63,6 @@ class DimShuffleLifter(opt.Optimizer):
seen
=
set
()
seen
=
set
()
def
merge
(
ord1
,
ord2
):
return
[
x
==
'x'
and
'x'
or
ord1
[
x
]
for
x
in
ord2
]
def
lift
(
r
):
def
lift
(
r
):
if
r
in
seen
:
if
r
in
seen
:
return
return
...
@@ -62,6 +75,7 @@ class DimShuffleLifter(opt.Optimizer):
...
@@ -62,6 +75,7 @@ class DimShuffleLifter(opt.Optimizer):
if
isinstance
(
op
,
DimShuffle
):
if
isinstance
(
op
,
DimShuffle
):
in_op
=
op
.
inputs
[
0
]
.
owner
in_op
=
op
.
inputs
[
0
]
.
owner
if
isinstance
(
in_op
,
DimShuffle
):
if
isinstance
(
in_op
,
DimShuffle
):
# DimShuffle(DimShuffle(x)) => DimShuffle(x)
new_order
=
[
x
==
'x'
and
'x'
or
in_op
.
new_order
[
x
]
for
x
in
op
.
new_order
]
new_order
=
[
x
==
'x'
and
'x'
or
in_op
.
new_order
[
x
]
for
x
in
op
.
new_order
]
if
new_order
==
range
(
len
(
new_order
)):
if
new_order
==
range
(
len
(
new_order
)):
repl
=
in_op
.
inputs
[
0
]
repl
=
in_op
.
inputs
[
0
]
...
@@ -71,6 +85,7 @@ class DimShuffleLifter(opt.Optimizer):
...
@@ -71,6 +85,7 @@ class DimShuffleLifter(opt.Optimizer):
lift
(
repl
)
lift
(
repl
)
return
return
elif
isinstance
(
in_op
,
Broadcast
):
elif
isinstance
(
in_op
,
Broadcast
):
# DimShuffle(Broadcast(x, y)) => Broadcast(DimShuffle(x), DimShuffle(y))
repl
=
Broadcast
(
in_op
.
scalar_opclass
,
repl
=
Broadcast
(
in_op
.
scalar_opclass
,
[
DimShuffle
(
input
,
op
.
new_order
)
.
out
for
input
in
in_op
.
inputs
],
[
DimShuffle
(
input
,
op
.
new_order
)
.
out
for
input
in
in_op
.
inputs
],
in_op
.
inplace_pattern
)
.
out
in_op
.
inplace_pattern
)
.
out
...
@@ -87,9 +102,24 @@ lift_dimshuffle = DimShuffleLifter()
...
@@ -87,9 +102,24 @@ lift_dimshuffle = DimShuffleLifter()
def
find_cliques
(
env
,
through_broadcast
=
False
):
def
find_cliques
(
env
,
through_broadcast
=
False
):
"""
Usage: find_cliques(env, through_broadcast = False)
Returns a list of pairs where each pair contains a list
of inputs and a list of outputs such that Env(inputs, outputs)
contains nothing but Broadcast Ops.
If through_broadcast is False, the cliques will only be
allowed to broadcast over the inputs, which means, for
example, that vector operations will not be mixed with
matrix operations.
"""
def
seek_from
(
r
):
def
seek_from
(
r
):
# walks through the graph until it encounters a
# non-Broadcast operation or (if through_broadcast
# is False) a Result which needs to be broadcasted.
op
=
r
.
owner
op
=
r
.
owner
if
r
in
env
.
inputs
\
if
r
in
env
.
inputs
\
or
r
in
env
.
orphans
()
\
or
r
in
env
.
orphans
()
\
...
@@ -103,6 +133,10 @@ def find_cliques(env, through_broadcast = False):
...
@@ -103,6 +133,10 @@ def find_cliques(env, through_broadcast = False):
ret
=
set
()
ret
=
set
()
if
not
through_broadcast
:
if
not
through_broadcast
:
# check each dimension over all the inputs - if the broadcastable
# fields are not all 0 or all 1 for a particular dimension, then
# broadcasting will be performed along it on the inputs where the
# value is 1 and we will stop.
if
any
(
any
(
bc
)
and
not
all
(
bc
)
if
any
(
any
(
bc
)
and
not
all
(
bc
)
for
bc
in
zip
(
*
[
input
.
broadcastable
for
input
in
op
.
inputs
])):
for
bc
in
zip
(
*
[
input
.
broadcastable
for
input
in
op
.
inputs
])):
ret
.
update
(
op
.
inputs
)
ret
.
update
(
op
.
inputs
)
...
@@ -111,6 +145,7 @@ def find_cliques(env, through_broadcast = False):
...
@@ -111,6 +145,7 @@ def find_cliques(env, through_broadcast = False):
for
input
in
op
.
inputs
:
for
input
in
op
.
inputs
:
res
=
seek_from
(
input
)
res
=
seek_from
(
input
)
if
res
is
None
:
if
res
is
None
:
# input is a leaf of our search
ret
.
add
(
input
)
ret
.
add
(
input
)
else
:
else
:
ret
.
update
(
res
)
ret
.
update
(
res
)
...
@@ -124,11 +159,14 @@ def find_cliques(env, through_broadcast = False):
...
@@ -124,11 +159,14 @@ def find_cliques(env, through_broadcast = False):
return
return
clique_inputs
=
seek_from
(
r
)
clique_inputs
=
seek_from
(
r
)
if
clique_inputs
is
None
:
if
clique_inputs
is
None
:
# Not in a clique, keep going
op
=
r
.
owner
op
=
r
.
owner
if
op
is
not
None
:
if
op
is
not
None
:
for
input
in
op
.
inputs
:
for
input
in
op
.
inputs
:
find_cliques_helper
(
input
)
find_cliques_helper
(
input
)
else
:
else
:
# We found a clique, add it to the list and
# jump to the leaves.
cliques
.
append
((
clique_inputs
,
[
r
]))
cliques
.
append
((
clique_inputs
,
[
r
]))
for
input
in
clique_inputs
:
for
input
in
clique_inputs
:
find_cliques_helper
(
input
)
find_cliques_helper
(
input
)
...
@@ -142,6 +180,24 @@ def find_cliques(env, through_broadcast = False):
...
@@ -142,6 +180,24 @@ def find_cliques(env, through_broadcast = False):
class
CliqueOptimizer
(
opt
.
Optimizer
):
class
CliqueOptimizer
(
opt
.
Optimizer
):
"""
Usage: CliqueOptimizer(through_broadcast = False,
scalar_optimizer = None,
make_composite = False).optimize(env)
Finds cliques of Broadcast operations in the env and does either
or both of two things:
* Apply scalar_optimizer on the clique as if the clique was a
group of scalar operations. scalar_optimizer can be any optimization
which applies on scalars. If it is None, no optimization is done.
* Replace the clique with a single Op, optimized to perform the
computations properly. If make_composite is False, no such replacement
is done.
Note: it is recommended to run the lift_dimshuffle optimization before
this one.
"""
def
__init__
(
self
,
through_broadcast
=
False
,
scalar_optimizer
=
None
,
make_composite
=
False
):
def
__init__
(
self
,
through_broadcast
=
False
,
scalar_optimizer
=
None
,
make_composite
=
False
):
self
.
through_broadcast
=
through_broadcast
self
.
through_broadcast
=
through_broadcast
...
@@ -152,20 +208,25 @@ class CliqueOptimizer(opt.Optimizer):
...
@@ -152,20 +208,25 @@ class CliqueOptimizer(opt.Optimizer):
if
self
.
scalar_optimizer
is
None
and
not
self
.
make_composite
:
if
self
.
scalar_optimizer
is
None
and
not
self
.
make_composite
:
# there's nothing to do with the cliques...
# there's nothing to do with the cliques...
return
return
cliques
=
find_cliques
(
env
,
self
.
through_broadcast
)
cliques
=
find_cliques
(
env
,
self
.
through_broadcast
)
opt
=
self
.
scalar_optimizer
opt
=
self
.
scalar_optimizer
def
build_scalar_clique
(
r
,
env
,
equiv
):
def
build_scalar_clique
(
r
,
env
,
equiv
):
# Maps a clique of Broadcast Ops to a clique of Scalar Ops with the same
# structure and equivalent operations. equiv contains the mapping.
if
r
in
equiv
:
if
r
in
equiv
:
return
equiv
[
r
]
return
equiv
[
r
]
op
=
r
.
owner
op
=
r
.
owner
if
r
in
env
.
inputs
or
r
in
env
.
orphans
():
if
r
in
env
.
inputs
or
r
in
env
.
orphans
():
# For each leaf we make a Scalar of the corresponding dtype
s
=
scalar
.
Scalar
(
dtype
=
r
.
dtype
)
s
=
scalar
.
Scalar
(
dtype
=
r
.
dtype
)
_r
=
r
_r
=
r
if
isinstance
(
r
.
owner
,
DimShuffle
)
and
all
(
x
==
'x'
for
x
in
r
.
owner
.
new_order
):
if
isinstance
(
r
.
owner
,
DimShuffle
)
and
all
(
x
==
'x'
for
x
in
r
.
owner
.
new_order
):
_r
=
r
.
owner
.
inputs
[
0
]
_r
=
r
.
owner
.
inputs
[
0
]
if
(
getattr
(
r
,
'constant'
,
False
)
or
getattr
(
_r
,
'constant'
,
False
))
\
if
(
getattr
(
r
,
'constant'
,
False
)
or
getattr
(
_r
,
'constant'
,
False
))
\
and
_r
.
broadcastable
==
():
and
_r
.
broadcastable
==
():
# If we have a constant tensor we map it to a constant scalar.
s
.
data
=
_r
.
data
s
.
data
=
_r
.
data
s
.
constant
=
True
s
.
constant
=
True
equiv
[
r
]
=
s
equiv
[
r
]
=
s
...
@@ -184,15 +245,18 @@ class CliqueOptimizer(opt.Optimizer):
...
@@ -184,15 +245,18 @@ class CliqueOptimizer(opt.Optimizer):
s_g
=
Env
([
equiv
[
r
]
for
r
in
g
.
inputs
],
s_g
=
Env
([
equiv
[
r
]
for
r
in
g
.
inputs
],
[
equiv
[
r
]
for
r
in
g
.
outputs
])
[
equiv
[
r
]
for
r
in
g
.
outputs
])
if
opt
is
not
None
:
if
opt
is
not
None
:
equiv2
=
dict
()
equiv2
=
dict
()
# reverse mapping, from Scalar Op to Tensor Op
for
k
,
v
in
equiv
.
items
():
for
k
,
v
in
equiv
.
items
():
equiv2
[
v
]
=
k
equiv2
[
v
]
=
k
def
transform
(
op
,
equiv
):
def
transform
(
op
,
equiv
):
# We get a scalar op and we return an equivalent op on tensors.
return
Broadcast
(
op
.
__class__
,
[
equiv
[
input
]
for
input
in
op
.
inputs
])
return
Broadcast
(
op
.
__class__
,
[
equiv
[
input
]
for
input
in
op
.
inputs
])
s_g
.
add_feature
(
sync_to
(
env
,
equiv2
,
transform
))
s_g
.
add_feature
(
sync_to
(
env
,
equiv2
,
transform
))
# Any change to s_g will now be transferred to g
opt
.
optimize
(
s_g
)
opt
.
optimize
(
s_g
)
if
self
.
make_composite
:
if
self
.
make_composite
:
def
follow_inplace
(
r
):
def
follow_inplace
(
r
):
# Tries to find the earliest r2 in g such that r destroys r2
# If no such r2 is found, returns None
op
=
r
.
owner
op
=
r
.
owner
if
op
is
None
or
r
in
g
.
inputs
or
r
in
g
.
orphans
():
if
op
is
None
or
r
in
g
.
inputs
or
r
in
g
.
orphans
():
return
None
return
None
...
@@ -211,6 +275,8 @@ class CliqueOptimizer(opt.Optimizer):
...
@@ -211,6 +275,8 @@ class CliqueOptimizer(opt.Optimizer):
for
i
,
output
in
enumerate
(
g
.
outputs
):
for
i
,
output
in
enumerate
(
g
.
outputs
):
destroyed
=
follow_inplace
(
output
)
destroyed
=
follow_inplace
(
output
)
if
destroyed
is
not
None
and
destroyed
in
g
.
inputs
:
if
destroyed
is
not
None
and
destroyed
in
g
.
inputs
:
# we transfer the inplace operation only if it is
# an input that is destroyed
inplace_pattern
[
i
]
=
g
.
inputs
.
index
(
destroyed
)
inplace_pattern
[
i
]
=
g
.
inputs
.
index
(
destroyed
)
C
=
scalar
.
composite
(
s_g
.
inputs
,
s_g
.
outputs
)
C
=
scalar
.
composite
(
s_g
.
inputs
,
s_g
.
outputs
)
ec
=
Broadcast
(
C
,
g
.
inputs
,
inplace_pattern
=
inplace_pattern
)
ec
=
Broadcast
(
C
,
g
.
inputs
,
inplace_pattern
=
inplace_pattern
)
...
@@ -218,6 +284,17 @@ class CliqueOptimizer(opt.Optimizer):
...
@@ -218,6 +284,17 @@ class CliqueOptimizer(opt.Optimizer):
def
sync_to
(
target
,
equiv
,
transform
):
def
sync_to
(
target
,
equiv
,
transform
):
"""
Usage: sync_to(target, equiv, transform)
* target: an Env
* equiv: a dictionary that maps results and ops to results and ops
in target
* transform: a function that takes (op, equiv) as inputs and
returns a new op.
Returns a Feature that can be added to an Env and mirrors all
modifications to that env with modifications to the target env.
"""
class
Synchronize
(
gof
.
Listener
,
gof
.
Constraint
):
class
Synchronize
(
gof
.
Listener
,
gof
.
Constraint
):
...
@@ -259,44 +336,3 @@ def sync_to(target, equiv, transform):
...
@@ -259,44 +336,3 @@ def sync_to(target, equiv, transform):
return
Synchronize
return
Synchronize
"""
This variable is used in compile.prog as the optimizer for all programs built
using any of compile.single, compile.to_func, or compile.prog.
Old code::
if 0:
def optimizer(lst):
begin = gof.SeqOptimizer([])
end = gof.SeqOptimizer([gof.DummyRemover])
seq_opt = gof.SeqOptimizer(begin + lst + end)
return gof.PythonOpt(gof.MergeOptMerge(seq_opt))
if 0:
optimizer_begin = gof.SeqOptimizer([opt for name, opt in [
['double_transpose_eliminator', pattern_opt((transpose, (transpose, 'x')), 'x')],
['addxx_to_twice', pattern_opt((add_elemwise, 'x', 'x'), (twice, 'x'))],
['twice_to_itwice', op_sub(twice, itwice)],
['mulxx_to_sqr', pattern_opt((mul_elemwise, 'x', 'x'), (sqr, 'x'))],
['sqr_to_isqr', op_sub(sqr, isqr)],
['add_to_iadd', op_sub(add_elemwise, iadd_elemwise)],
['add_to_iadd_reverse', pattern_opt((add_elemwise, 'x', 'y'),
(iadd_elemwise, 'y', 'x'))]]])
# ['remove_copies', gof.OpRemover(array_copy)],
# [None, gof.DummyRemover] # has to be at the end
"""
scalar_opt.py
浏览文件 @
bf3df169
...
@@ -23,6 +23,39 @@ logpow = Pattern((Log, (Pow, 'x', 'y')),
...
@@ -23,6 +23,39 @@ logpow = Pattern((Log, (Pow, 'x', 'y')),
class
Canonizer
(
gof
.
Optimizer
):
class
Canonizer
(
gof
.
Optimizer
):
"""
Simplification tool.
Usage: Canonizer(main, inverse, reciprocal, mainfn, invfn, recfn, transform)
* main: a suitable Op class that is commutative, associative and takes
one to an arbitrary number of inputs, e.g. Add or Mul
* inverse: an Op class such that inverse(main(x, y), y) == x
e.g. Sub or Div
* reciprocal: an Op class such that main(x, reciprocal(y)) == inverse(x, y)
e.g. Neg or Inv
* mainfn, invfn, recfn: functions that behave just like the previous three
Ops, but on true scalars (e.g. their impl)
* transform: a function that maps (numerator, denominator) where numerator
and denominator are lists of Result instances, to new lists
where further simplifications may have been applied.
Examples:
add_canonizer = Canonizer(Add, Sub, Neg, lambda *inputs: sum(inputs), ...)
mul_canonizer = Canonizer(Mul, Div, Inv, lambda *inputs: product(inputs), ...)
Examples of optimizations mul_canonizer can perform:
x / x -> 1
(x * y) / x -> y
x / y / x -> 1 / y
x / y / z -> x / (y * z)
x / (y / z) -> (x * z) / y
(a / b) * (b / c) * (c / d) -> a / d
(2.0 * x) / (4.0 * y) -> (0.5 * x) / y
2 * x / 2 -> x
"""
def
__init__
(
self
,
main
,
inverse
,
reciprocal
,
mainfn
,
invfn
,
recfn
,
transform
=
None
):
def
__init__
(
self
,
main
,
inverse
,
reciprocal
,
mainfn
,
invfn
,
recfn
,
transform
=
None
):
self
.
main
=
main
self
.
main
=
main
...
@@ -37,10 +70,15 @@ class Canonizer(gof.Optimizer):
...
@@ -37,10 +70,15 @@ class Canonizer(gof.Optimizer):
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
def
canonize
(
r
):
def
canonize
(
r
):
if
r
in
env
.
inputs
or
r
in
env
.
orphans
():
if
r
in
env
.
inputs
or
r
in
env
.
orphans
():
return
return
def
flatten
(
r
,
nclients_check
=
True
):
def
flatten
(
r
,
nclients_check
=
True
):
# Collapses a tree of main/inverse/reciprocal Ops (aka Mul/Div/Inv or Add/Sub/Neg)
# into a list of numerators and a list of denominators
# e.g. (x*(1/y))*(x/(z/a)) aka Mul(Mul(x, Inv(y)), Div(x, Div(z, a))) -> [x, x, a], [z, y]
op
=
r
.
owner
op
=
r
.
owner
if
op
is
None
or
r
in
env
.
inputs
or
r
in
env
.
orphans
():
if
op
is
None
or
r
in
env
.
inputs
or
r
in
env
.
orphans
():
return
[
r
],
[]
return
[
r
],
[]
...
@@ -50,9 +88,11 @@ class Canonizer(gof.Optimizer):
...
@@ -50,9 +88,11 @@ class Canonizer(gof.Optimizer):
nums
=
[
x
[
0
]
for
x
in
results
]
nums
=
[
x
[
0
]
for
x
in
results
]
denums
=
[
x
[
1
]
for
x
in
results
]
denums
=
[
x
[
1
]
for
x
in
results
]
elif
isinstance
(
op
,
self
.
inverse
)
and
(
not
nclients_check
or
env
.
nclients
(
r
)
==
1
):
elif
isinstance
(
op
,
self
.
inverse
)
and
(
not
nclients_check
or
env
.
nclients
(
r
)
==
1
):
# num, denum of the second argument are added to the denum, num respectively
nums
=
[
results
[
0
][
0
],
results
[
1
][
1
]]
nums
=
[
results
[
0
][
0
],
results
[
1
][
1
]]
denums
=
[
results
[
0
][
1
],
results
[
1
][
0
]]
denums
=
[
results
[
0
][
1
],
results
[
1
][
0
]]
elif
isinstance
(
op
,
self
.
reciprocal
)
and
(
not
nclients_check
or
env
.
nclients
(
r
)
==
1
):
elif
isinstance
(
op
,
self
.
reciprocal
)
and
(
not
nclients_check
or
env
.
nclients
(
r
)
==
1
):
# num, denum of the sole argument are added to the denum, num respectively
nums
=
[
results
[
0
][
1
]]
nums
=
[
results
[
0
][
1
]]
denums
=
[
results
[
0
][
0
]]
denums
=
[
results
[
0
][
0
]]
else
:
else
:
...
@@ -70,22 +110,29 @@ class Canonizer(gof.Optimizer):
...
@@ -70,22 +110,29 @@ class Canonizer(gof.Optimizer):
canonize
(
input
)
canonize
(
input
)
return
return
# Terms that are both in the num and denum lists cancel each other
for
d
in
list
(
denum
):
for
d
in
list
(
denum
):
if
d
in
list
(
num
):
if
d
in
list
(
num
):
# list.remove only removes the element once
num
.
remove
(
d
)
num
.
remove
(
d
)
denum
.
remove
(
d
)
denum
.
remove
(
d
)
# We identify the constants in num and denum
numct
,
num
=
utils
.
partition
(
lambda
factor
:
getattr
(
factor
,
'constant'
,
False
)
and
factor
.
data
is
not
None
,
num
)
numct
,
num
=
utils
.
partition
(
lambda
factor
:
getattr
(
factor
,
'constant'
,
False
)
and
factor
.
data
is
not
None
,
num
)
denumct
,
denum
=
utils
.
partition
(
lambda
factor
:
getattr
(
factor
,
'constant'
,
False
)
and
factor
.
data
is
not
None
,
denum
)
denumct
,
denum
=
utils
.
partition
(
lambda
factor
:
getattr
(
factor
,
'constant'
,
False
)
and
factor
.
data
is
not
None
,
denum
)
# All constants in num and denum are combined into a single constant which we add to num (unless it's a neutral constant)
v
=
self
.
invfn
(
self
.
mainfn
(
*
[
x
.
data
for
x
in
numct
]),
self
.
mainfn
(
*
[
x
.
data
for
x
in
denumct
]))
v
=
self
.
invfn
(
self
.
mainfn
(
*
[
x
.
data
for
x
in
numct
]),
self
.
mainfn
(
*
[
x
.
data
for
x
in
denumct
]))
if
v
!=
self
.
neutral
:
if
v
!=
self
.
neutral
:
num
.
insert
(
0
,
C
(
v
))
num
.
insert
(
0
,
C
(
v
))
# We optimize the num and denum lists further if requested
if
self
.
transform
is
not
None
:
if
self
.
transform
is
not
None
:
num
,
denum
=
self
.
transform
(
env
,
num
,
denum
)
num
,
denum
=
self
.
transform
(
env
,
num
,
denum
)
def
make
(
factors
):
def
make
(
factors
):
# Combines the factors using self.main (aka Mul) depending
# on the number of elements.
n
=
len
(
factors
)
n
=
len
(
factors
)
if
n
==
0
:
if
n
==
0
:
return
None
return
None
...
@@ -98,10 +145,13 @@ class Canonizer(gof.Optimizer):
...
@@ -98,10 +145,13 @@ class Canonizer(gof.Optimizer):
if
numr
is
None
:
if
numr
is
None
:
if
denumr
is
None
:
if
denumr
is
None
:
# Everything cancelled each other so we're left with
# the neutral element.
new_r
=
Scalar
(
dtype
=
r
.
dtype
)
new_r
=
Scalar
(
dtype
=
r
.
dtype
)
new_r
.
constant
=
True
new_r
.
constant
=
True
new_r
.
data
=
self
.
neutral
new_r
.
data
=
self
.
neutral
else
:
else
:
# There's no numerator so we use reciprocal
new_r
=
self
.
reciprocal
(
denumr
)
.
out
new_r
=
self
.
reciprocal
(
denumr
)
.
out
else
:
else
:
if
denumr
is
None
:
if
denumr
is
None
:
...
@@ -109,6 +159,7 @@ class Canonizer(gof.Optimizer):
...
@@ -109,6 +159,7 @@ class Canonizer(gof.Optimizer):
else
:
else
:
new_r
=
self
.
inverse
(
numr
,
denumr
)
.
out
new_r
=
self
.
inverse
(
numr
,
denumr
)
.
out
# Hopefully this won't complain!
env
.
replace
(
r
,
new_r
)
env
.
replace
(
r
,
new_r
)
for
factor
in
num
+
denum
:
for
factor
in
num
+
denum
:
...
@@ -119,11 +170,28 @@ class Canonizer(gof.Optimizer):
...
@@ -119,11 +170,28 @@ class Canonizer(gof.Optimizer):
def
group_powers
(
env
,
num
,
denum
):
def
group_powers
(
env
,
num
,
denum
):
"""
Plugin for Canonizer: use as Canonizer(..., transform = group_powers)
Takes num, denum such that mul(*num) / mul(*denum) is in env
and searches for instances of exp(x) or x**y in order to group
together powers of the same variable. Returns num2, denum2 in
which the grouping has been done.
Note: this function does not modify env.
Examples:
group_powers([x, exp(x), exp(y)], [exp(z)]) -> [x, exp(x+y-z)], []
"""
# maps a base to the list of powers it is raised to in the
# numerator/denominator lists.
num_powers
=
{}
num_powers
=
{}
denum_powers
=
{}
denum_powers
=
{}
def
populate
(
d
,
seq
):
def
populate
(
d
,
seq
):
# For each instance of exp or pow in seq, removes it from seq
# and does d[base].append(power).
for
factor
in
list
(
seq
):
for
factor
in
list
(
seq
):
op
=
factor
.
owner
op
=
factor
.
owner
if
op
is
None
or
factor
in
env
.
inputs
or
factor
in
env
.
orphans
():
if
op
is
None
or
factor
in
env
.
inputs
or
factor
in
env
.
orphans
():
...
@@ -139,6 +207,8 @@ def group_powers(env, num, denum):
...
@@ -139,6 +207,8 @@ def group_powers(env, num, denum):
populate
(
denum_powers
,
denum
)
populate
(
denum_powers
,
denum
)
for
x
in
set
(
num_powers
.
keys
()
+
denum_powers
.
keys
()):
for
x
in
set
(
num_powers
.
keys
()
+
denum_powers
.
keys
()):
# we append base ** (num_powers[base] - denum_powers[base])
# to the num list
try
:
num_ys
=
num_powers
.
pop
(
x
)
try
:
num_ys
=
num_powers
.
pop
(
x
)
except
KeyError
:
num_ys
=
[]
except
KeyError
:
num_ys
=
[]
...
@@ -148,6 +218,7 @@ def group_powers(env, num, denum):
...
@@ -148,6 +218,7 @@ def group_powers(env, num, denum):
num_r
=
num_ys
and
add
(
*
num_ys
)
or
C
(
0
)
num_r
=
num_ys
and
add
(
*
num_ys
)
or
C
(
0
)
denum_r
=
denum_ys
and
add
(
*
denum_ys
)
or
C
(
0
)
denum_r
=
denum_ys
and
add
(
*
denum_ys
)
or
C
(
0
)
if
x
==
'e'
:
if
x
==
'e'
:
num
.
append
(
exp
(
num_r
-
denum_r
))
num
.
append
(
exp
(
num_r
-
denum_r
))
else
:
else
:
...
...
tensor.py
浏览文件 @
bf3df169
...
@@ -80,17 +80,14 @@ def astensor(data, broadcastable=None, name=None):
...
@@ -80,17 +80,14 @@ def astensor(data, broadcastable=None, name=None):
if
isinstance
(
data
,
BaseTensor
):
if
isinstance
(
data
,
BaseTensor
):
if
broadcastable
is
not
None
and
list
(
data
.
broadcastable
)
!=
list
(
broadcastable
):
if
broadcastable
is
not
None
and
list
(
data
.
broadcastable
)
!=
list
(
broadcastable
):
raise
TypeError
(
"The data to wrap as a Tensor has the wrong broadcastable pattern. Expected
%
s, got
%
s."
%
(
broadcastable
,
data
.
broadcastable
))
raise
TypeError
(
"The data to wrap as a Tensor has the wrong broadcastable pattern. Expected
%
s, got
%
s."
%
(
broadcastable
,
data
.
broadcastable
))
if
isinstance
(
data
,
Tensor
)
and
(
name
is
None
or
name
==
data
.
name
):
if
name
is
not
None
and
name
!=
data
.
name
:
raise
ValueError
(
"Cannot rename an existing Tensor."
)
return
data
return
data
else
:
t
=
Tensor
(
data
.
dtype
,
data
.
broadcastable
,
name
=
name
)
t
.
data
=
data
return
t
elif
isinstance
(
data
,
Result
):
elif
isinstance
(
data
,
Result
):
data
=
data
.
data
raise
TypeError
(
"Cannot make a Tensor out of a non-Tensor result."
)
if
data
is
None
and
broadcastable
is
None
:
if
data
is
None
and
broadcastable
is
None
:
raise
TypeError
(
"Cannot make a Tensor out of None
or a Result with no data
."
)
raise
TypeError
(
"Cannot make a Tensor out of None."
)
data
=
numpy
.
asarray
(
data
)
data
=
numpy
.
asarray
(
data
)
if
broadcastable
is
None
:
if
broadcastable
is
None
:
...
@@ -107,38 +104,6 @@ s2t.astensor = astensor
...
@@ -107,38 +104,6 @@ s2t.astensor = astensor
# Supporting Ops
# Supporting Ops
############################
############################
def
_scalar_switch
(
normal_f
,
scalar_f
,
scalar_f_reverse
=
None
):
"""a decorator for operators before broadcasting works properly"""
def
f
(
x
,
y
):
def
as_tensor
(
obj
):
if
isinstance
(
obj
,
Tensor
):
return
obj
else
:
return
astensor
(
obj
)
x
,
y
=
as_tensor
(
x
),
as_tensor
(
y
)
if
0
not
in
y
.
broadcastable
:
return
scalar_f
(
x
,
y
)
if
0
not
in
x
.
broadcastable
:
if
scalar_f_reverse
:
return
scalar_f_reverse
(
y
,
x
)
else
:
raise
TypeError
(
"You cannot do this operation on a scalar."
)
return
normal_f
(
x
,
y
)
return
f
def
_assert_same_shapes
(
x
,
*
rest
):
"""Ensure that all inputs to the function impl have the same size (foils numpy's broadcasting)"""
shape
=
x
.
shape
for
other
in
rest
:
if
other
.
shape
!=
shape
:
raise
ValueError
(
_assert_same_shapes
.
E_shape
,
shape
,
other
.
shape
)
_assert_same_shapes
.
E_shape
=
"The dimensions of the inputs do not match."
def
_assert_tensor_scalar
(
x
,
a
):
"""ensure that the second input is a scalar"""
if
numpy
.
product
(
a
.
shape
)
!=
1
:
raise
ValueError
(
"The second argument must be a scalar."
)
# this has a different name, because _as_tensor is the function which ops use
# this has a different name, because _as_tensor is the function which ops use
# to upcast their arguments... this internal-use function is a good place to put debugging stuff, better than the global astensor.
# to upcast their arguments... this internal-use function is a good place to put debugging stuff, better than the global astensor.
_as_tensor
=
astensor
_as_tensor
=
astensor
...
@@ -450,8 +415,6 @@ class Gemm(_Op):
...
@@ -450,8 +415,6 @@ class Gemm(_Op):
return
[
'<iostream>'
]
return
[
'<iostream>'
]
def
c_libraries
(
self
):
def
c_libraries
(
self
):
return
blas
.
ldflags
()
return
blas
.
ldflags
()
#def c_var_names(self):
# return [['_z', '_a', '_x', '_y', '_b'], ['_zout']]
def
c_validate_update
(
self
,
*
args
):
def
c_validate_update
(
self
,
*
args
):
return
""
return
""
def
c_validate_update_cleanup
(
self
,
*
args
):
def
c_validate_update_cleanup
(
self
,
*
args
):
...
@@ -612,125 +575,3 @@ class Gemm(_Op):
...
@@ -612,125 +575,3 @@ class Gemm(_Op):
"""
%
dict
(
locals
(),
**
sub
)
"""
%
dict
(
locals
(),
**
sub
)
gemm
=
gof
.
op
.
constructor
(
Gemm
)
gemm
=
gof
.
op
.
constructor
(
Gemm
)
if
0
:
##########################
# Comparisons
##########################
# Less-than
class
lt_elemwise
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
class
lt_scalar_r
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
# Less-than or equal
class
le_elemwise
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
class
le_scalar_r
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
# Greater-than
class
gt_elemwise
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
class
gt_scalar_r
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
# Greater-than or equal
class
ge_elemwise
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
class
ge_scalar_r
(
_Elemwise
):
def
__init__
(
self
,
*
args
):
raise
NotImplementedError
()
if
0
:
def
_broadcastable_pattern
(
pattern
):
def
factory
(
data
=
None
,
name
=
None
,
dtype
=
None
):
if
data
:
assert
len
(
data
.
shape
)
==
len
(
pattern
)
if
dtype
is
not
None
:
assert
dtype
is
data
.
dtype
dtype
=
data
.
dtype
rval
=
Tensor
(
dtype
,
pattern
,
name
)
rval
.
data
=
data
else
:
rval
=
Tensor
(
dtype
,
pattern
,
name
)
return
rval
return
factory
row
=
_broadcastable_pattern
([
1
,
0
])
col
=
_broadcastable_pattern
([
0
,
1
])
matrix
=
_broadcastable_pattern
([
0
,
0
])
if
0
:
#old __init__ code
"""Create a Tensor
If data is given:
- constant defaults to True
- if dtype is given, it must match data.dtype
- otherwise: default is data.dtype
- if broadcastable is given, len(broadcastable) must match len(data.shape)
- otherwise: if it is constant, it defaults to 1 where shape[i]==1
- if it is not constant, it defaults to 0s
If data is not given:
- constant defaults to False
"""
if
dtype
is
None
or
broadcastable
is
None
:
if
data
is
None
:
raise
TypeError
(
"Provide non-None data to complete the dtype and broadcastable flags."
)
data
=
numpy
.
asarray
(
data
)
if
constant
is
None
:
constant
=
True
dtype
=
data
.
dtype
if
constant
:
broadcastable
=
[
1
*
(
x
==
1
)
for
x
in
data
.
shape
]
else
:
broadcastable
=
[
0
]
*
len
(
data
.
shape
)
if
0
:
def
tensor__new__
(
cls
,
*
args
,
**
kwargs
):
"""__new__ is overloaded to handle the special form Tensor(x) when x is
a Tensor or an Op whose default output is a Tensor. In these cases, the
argument x is returned, and a new Tensor is not created.
"""
if
len
(
args
)
==
1
:
a
=
args
[
0
]
t
=
super
(
Tensor
,
cls
)
.
__new__
(
cls
,
*
args
,
**
kwargs
)
t
.
__init__
(
*
args
,
**
kwargs
)
return
t
# def upcast(dtype, *dtypes):
# z = numpy.zeros((), dtype = dtype)
# for dtype in dtypes:
# z = z + numpy.zeros((), dtype = dtype)
# return str(z.dtype)
# for dtype in i_dtypes:
# if dtype is None:
# raise TypeError("Expected a Tensor.")
# upcasted = upcast(*i_dtypes)
# return [upcasted] * self.nout
# # try:
# # dmap = self.destroy_map()
# # except AttributeError:
# # dmap = {}
# # rval = []
# # for i in xrange(self.nout):
# # if i in dmap:
# # destroyed = dmap[output]
# # if len(destroyed) != 1:
# # raise TypeError("Cannot infer dtype of output %s because it destroys more than one input." % output)
# # rval.append(destroyed[0])
# # else:
# # rval.append(upcasted)
# # return rval
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论