testgroup / pytensor / Commits / 6f5a1844

Commit 6f5a1844 authored July 20, 2011 by James Bergstra

documentation and revisions to blas optimization pipeline

Parent: d36528cb
Showing 2 changed files with 237 additions and 66 deletions:

blas.py (theano/tensor/blas.py): +219 -52
test_blas.py (theano/tensor/tests/test_blas.py): +18 -14
theano/tensor/blas.py @ 6f5a1844
"""Ops and optimizations for using BLAS function calls to evaluate linear algebra expressions"""
"""Ops and optimizations for using BLAS calls
BLAS = Basic Linear Algebra Subroutines
Learn more about BLAS here:
http://www.netlib.org/blas/blast-forum/
The standard BLAS libraries implement what is called "legacy BLAS" in that
document.
This documentation section describes Theano's BLAS optimization
pipeline.
Where there is a discrepancy between how things do work and how they *should*
work, both aspects should be documented. It helps keep a broader agenda in view
even while fixing little bugs etc. from day to day.
Ops
===
There are two BLAS calls wrapped in this module: GEMM and GEMV.
GEMM: Dot22, Dot22Scalar, GemmRelated, Gemm
-------------------------------------------
The BLAS GEMM operation implements Z <- a X Y + b Z,
where Z, X and Y are matrices, and a and b are scalars.
Dot22 is a GEMM where a=1, b=0, and Z is allocated every time.
Dot22Scalar is a GEMM where b=0 and Z is allocated every time.
Gemm is a GEMM in all its generality.
In the future we can refactor the GemmRelated, Gemm, Dot22 and
Dot22Scalar Ops into a single Op. That new Op (Gemm2) is basically a normal Gemm, but
with an additional configuration variable that says to ignore the input Z.
Setting that configuration variable to True would make Gemm2 equivalent to the
current Dot22 and Dot22Scalar. This would make the file a lot easier to read,
and save a few hundred lines of library code, to say nothing of testing and
documentation.
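The relationships among these operations can be illustrated with a minimal NumPy sketch (illustrative helper functions only, not the actual Theano Ops):

```python
import numpy as np

def gemm(Z, a, X, Y, b):
    # General GEMM update: Z <- a * X @ Y + b * Z
    return a * np.dot(X, Y) + b * Z

def dot22(X, Y):
    # Dot22: a GEMM with a=1, b=0, and a freshly allocated Z
    Z = np.zeros((X.shape[0], Y.shape[1]))
    return gemm(Z, 1.0, X, Y, 0.0)

def dot22scalar(a, X, Y):
    # Dot22Scalar: a GEMM with b=0 and a freshly allocated Z
    Z = np.zeros((X.shape[0], Y.shape[1]))
    return gemm(Z, a, X, Y, 0.0)
```

This is exactly why a single Gemm2 with an "ignore Z" flag could subsume Dot22 and Dot22Scalar: both are the general update with particular scalar values and a fresh output buffer.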
GEMV: Gemv
----------
The BLAS GEMV operation implements Z <- a X Y + b Z,
where X is a matrix, Y and Z are vectors, and a and b are scalars.
Gemv implements the GEMV call in all its generality.
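A NumPy sketch of the GEMV update (again, just an illustration of the arithmetic, not the Op itself):

```python
import numpy as np

def gemv(z, a, X, y, b):
    # GEMV update: z <- a * X @ y + b * z,
    # with X a matrix and y, z vectors.
    return a * np.dot(X, y) + b * z
```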
Other Notable BLAS-related Ops
------------------------------
GpuOuter is currently a wrapper around GER. GER is a useful special case of
GEMM, and in the future it would be good to have a GER Op. With a GER Op here,
the GpuOuter could be turned into a GpuGER.
SYRK is another useful special case of GEMM. In particular, SYRK preserves
symmetry in the matrix that it updates. Before implementing this Op, see how
the linear-algebra module uses symmetry hints, so that this Op is compatible
with that system.
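The symmetry-preservation property is easy to see in a NumPy sketch (a hypothetical stand-in for a SYRK Op):

```python
import numpy as np

def syrk_like(C, a, A, b):
    # SYRK-style update: C <- a * A @ A.T + b * C.
    # A @ A.T is always symmetric, so when C is symmetric the result
    # stays symmetric -- a dedicated SYRK Op could exploit this by
    # computing only one triangle of the output.
    return a * np.dot(A, A.T) + b * C
```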
Optimizations
=============
The current optimization pipeline is not exactly clear to me. Instead I will
describe how it should work.
The high level pipeline is:
1. identify dot22 from dot
2. identify gemm from dot22
3. identify dot22scalar from dot22 that are not gemm
4. specialize gemm to gemv where applicable
Identify Dot22
--------------
Numpy's dot supports arguments that are of any rank, and we should support that
too (just for compatibility). The BLAS optimizations work with Dot Ops whose
inputs are each either vector or matrix. So the first part of the optimization
pipeline is to transform qualifying Dot Ops to Dot22 Ops. Dot22 Ops may be
transformed further, but they will get implemented by a BLAS call.
More precisely, Dot nodes whose two inputs are each a vector or matrix, have
the same dtype, and whose dtype is float or complex, become Dot22. This is
implemented in `local_dot_to_dot22`.
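The qualification test just described can be sketched as a plain predicate (hypothetical helper name, not the actual optimizer code):

```python
def qualifies_for_dot22(x_ndim, y_ndim, x_dtype, y_dtype):
    # Each input must be a vector (ndim 1) or matrix (ndim 2),
    # both inputs must share one dtype,
    # and that dtype must be float or complex.
    ranks_ok = x_ndim in (1, 2) and y_ndim in (1, 2)
    same_dtype = x_dtype == y_dtype
    dtype_ok = x_dtype.startswith('float') or x_dtype.startswith('complex')
    return ranks_ok and same_dtype and dtype_ok
```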
Identify Gemm from Dot22
------------------------
This is complicated, done in GemmOptimizer.
Identify Dot22Scalar from Dot22
-------------------------------
Dot22 Ops that remain after the GemmOptimizer is done have not qualified as GEMM
Ops. Still they might be scaled by a factor, in which case we use Dot22Scalar
which is like Gemm, but without the b and the Z. In the future it would be good
to merge this into the GemmOptimizer.
Specialize Gemm to Gemv
-----------------------
If arguments to GEMM are dimshuffled vectors, then we can use GEMV instead. This
optimization is `local_gemm_to_gemv`.
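The equivalence behind this rewrite can be checked numerically (a NumPy sketch under the assumption that the "matrices" are 1-row wrappers around vectors):

```python
import numpy as np

# GEMM computes z <- a * x @ y + b * z on matrices.  When z and x are
# really dimshuffled vectors (here: 1-row matrices), the same result can
# be computed with a single GEMV on the underlying vectors:
#     z_row <- a * (y.T @ x_row) + b * z_row
a, b = 2.0, 3.0
x = np.array([[1., 2., 3.]])        # 1 x 3 "row matrix" wrapping a vector
y = np.arange(12.0).reshape(3, 4)
z = np.ones((1, 4))

gemm_out = a * (x @ y) + b * z              # matrix-level GEMM form
gemv_out = a * (y.T @ x[0]) + b * z[0]      # vector-level GEMV form
```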
"""
import logging, copy, os

...
@@ -456,7 +565,10 @@ class Gemm(GemmRelated):
    """
    E_rank = 'gemm only works for rank 2'
    E_scalar = 'gemm requires scalar argument'
    E_z_uniq = 'argument z aliased to x or y'
    # TODO: justify / delete this
    E_mixed = 'gemm requires matching dtypes'
    E_float = 'gemm requires floating-point dtypes'

    def __init__(self, inplace):
        self.__setstate__({'inplace': inplace})

...
@@ -480,28 +592,45 @@ class Gemm(GemmRelated):
        else:
            self.setup_z_Nz_Sz = self.setup_z_Nz_Sz_outplace
        self.inplace = inplace

    def __getstate__(self):
        return dict(inplace=self.inplace)

    def make_node(self, *inputs):
        inputs = map(T.as_tensor_variable, inputs)
        if len(inputs) != 5:
            raise TypeError(
                "Wrong number of inputs for %s (expected 5, got %s)"
                % (self, len(inputs)))
        z, a, x, y, b = inputs

        # TODO: justify / delete
        zr, xr, yr = [set(view_roots(i)) for i in z, x, y]
        if zr.intersection(xr):
            raise InconsistencyError(Gemm.E_z_uniq, (z, x))
        if zr.intersection(yr):
            raise InconsistencyError(Gemm.E_z_uniq, (z, y))

        bz, ba, bx, by, bb = [r.type.broadcastable for r in inputs]
        if bz != (False, False):
            raise ValueError(Gemm.E_rank, bz)
        if bx != (False, False):
            raise ValueError(Gemm.E_rank, bx)
        if by != (False, False):
            raise ValueError(Gemm.E_rank, by)
        if len(ba):
            raise ValueError(Gemm.E_scalar, ba)
        if len(bb):
            raise ValueError(Gemm.E_scalar, bb)
        if z.ndim != 2:
            raise TypeError(Gemm.E_rank, z)
        if a.ndim != 0:
            raise TypeError(Gemm.E_scalar, a)
        if x.ndim != 2:
            raise TypeError(Gemm.E_rank, x)
        if y.ndim != 2:
            raise TypeError(Gemm.E_rank, y)
        if b.ndim != 0:
            raise TypeError(Gemm.E_scalar, b)
        if not (z.dtype == a.dtype == x.dtype == y.dtype == b.dtype):
            raise TypeError(Gemm.E_mixed,
                            (z.dtype, a.dtype, x.dtype, y.dtype, b.dtype))
        if (not z.dtype.startswith('float')
                and not z.dtype.startswith('complex')):
            raise TypeError(Gemm.E_float, (z.dtype))
        output = z.type()
        return Apply(self, inputs, [output])

    def perform(self, node, inp, out):
        z, a, x, y, b = inp
        zout, = out

...
@@ -689,46 +818,51 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
    #print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
    #EXPRESSION: (beta * L) + (alpha * M)
    if M.type.broadcastable != L.type.broadcastable:
        # GEMM cannot do the broadcasting that add used to be doing
        # so abort.
        return
    assert L.type.dtype == M.type.dtype  # because of local_dot_to_dot22

    # we've already checked the client counts, now just make the type check.
    ####if res_is_a(M, _dot22, 1):
    if M.owner and M.owner.op == _dot22:
        if M.broadcastable == L.broadcastable:
            Ml, Mr = M.owner.inputs
            rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
            #print 'GEMM 0', rval, beta, L, alpha, M
            return rval

    if M.owner and M.owner.op == T.dot \
            and L.broadcastable == (False,) \
            and M.broadcastable == (False,):
        Ml, Mr = M.owner.inputs
        rval = None
        if Ml.ndim == 1:
            if Mr.ndim == 1:
                #TODO: insert a BLAS ddot Op
                pass
            if Mr.ndim == 2:
                #print "RETURNING GEMV (case 2)"
                if Mr.dtype == Ml.dtype:
                    rval = [gemv_no_inplace(L, alpha, Mr.T, Ml, beta)]
                    assert L.type == rval[0].type, (L.type, rval[0].type)
                else:
                    # TODO
                    pass
        if Ml.ndim == 2:
            if Mr.ndim == 1:
                #print "RETURNING GEMV (case 3)"
                if Mr.dtype == Ml.dtype:
                    rval = [gemv_no_inplace(L, alpha, Ml, Mr, beta)]
                    assert L.type == rval[0].type, (L.type, rval[0].type)
                else:
                    # TODO
                    pass
            if Mr.ndim == 2:
                # should have already got this case with a _dot22
                pass
        return rval

    # it also might be the case that there is a dimshuffle between the +
    # and the dot22. local_dot_to_dot22 in particular will put in such things.
    if M.owner and isinstance(M.owner.op, T.DimShuffle):
        MM = M.owner.inputs[0]
        if tuple(M.owner.op.new_order) == (0,):
            # it is making a column MM into a vector
            if MM.owner and MM.owner.op == _dot22:
                MMl, MMr = MM.owner.inputs
                g = gemm_no_inplace(L.dimshuffle(0, 'x'),
                        alpha, MMl, MMr, beta)
                rval = [g.dimshuffle(0)]
                return rval
        if tuple(M.owner.op.new_order) == (1,):
            # it is making a row MM into a vector
            if MM.owner and MM.owner.op == _dot22:
                MMl, MMr = MM.owner.inputs
                g = gemm_no_inplace(L.dimshuffle('x', 0),
                        alpha, MMl, MMr, beta)
                rval = [g.dimshuffle(1)]
                return rval
        if tuple(M.owner.op.new_order) == ():
            # it is making a 1x1 MM into a scalar
            if MM.owner and MM.owner.op == _dot22:
                MMl, MMr = MM.owner.inputs
                g = gemm_no_inplace(L.dimshuffle('x', 'x'),
                        alpha, MMl, MMr, beta)
                rval = [g.dimshuffle()]
                return rval

    # this is False'd out because of inadequate testing.
    # TODO see ticket #237
    if False and res_is_a(M, gemm_no_inplace, 1):

...
@@ -881,6 +1015,10 @@ def _factor_canonicalized(lst):
def _gemm_from_factored_list(lst):
    """Returns None, or a list to replace node.outputs
    """
    # Make every pair in the list have matching dtypes
    lst = [(T.cast(si, Mi.type.dtype), Mi) for si, Mi in lst]

    # Try every pair in the sM_list, trying to turn it into a gemm operation
    for i in xrange(len(lst) - 1):
        try:

...
@@ -1052,6 +1190,8 @@ _dot22 = Dot22()
@local_optimizer([T.dot])
def local_dot_to_dot22(node):
    # This works for tensor.outer too because basic.outer is a macro that
    # produces a dot(dimshuffle, dimshuffle) of form 4 below
    if node.op != T.dot:
        return

...
@@ -1060,16 +1200,20 @@ def local_dot_to_dot22(node):
        # TODO: upcast one so the types match
        _logger.info('Not optimizing dot with inputs %s %s %s %s',
                x, y, x.type, y.type)
        return
    if y.type.dtype.startswith('float'):
        if _is_real_matrix(x) and _is_real_matrix(y):
            if x.ndim == 2 and y.ndim == 2:
                #print "local_dot_to_dot22: MM"
                return [_dot22(*node.inputs)]
        if 0:
            if _is_real_matrix(x) and _is_real_vector(y):
                return [_dot22(x, y.dimshuffle(0, 'x')).dimshuffle(0)]
            if _is_real_vector(x) and _is_real_matrix(y):
                return [_dot22(x.dimshuffle('x', 0), y).dimshuffle(1)]
            if _is_real_vector(x) and _is_real_vector(y):
                return [_dot22(x.dimshuffle('x', 0),
                    y.dimshuffle(0, 'x')).dimshuffle()]
        if x.ndim == 2 and y.ndim == 1:
            #print "local_dot_to_dot22: MV"
            return [_dot22(x, y.dimshuffle(0, 'x')).dimshuffle(0)]
        if x.ndim == 1 and y.ndim == 2:
            #print "local_dot_to_dot22: VM"
            return [_dot22(x.dimshuffle('x', 0), y).dimshuffle(1)]
        if x.ndim == 1 and y.ndim == 1:
            #print "local_dot_to_dot22: VV"
            return [_dot22(x.dimshuffle('x', 0),
                y.dimshuffle(0, 'x')).dimshuffle()]
    _logger.info('Not optimizing dot with inputs %s %s %s %s',
            x, y, x.type, y.type)

...
@@ -1077,11 +1221,28 @@ def local_dot_to_dot22(node):
def local_inplace_gemm(node):
    if node.op == gemm_no_inplace:
        return [gemm_inplace(*node.inputs)]

@local_optimizer([gemv_no_inplace])
def local_inplace_gemv(node):
    if node.op == gemv_no_inplace:
        return [gemv_inplace(*node.inputs)]

@local_optimizer([gemm_no_inplace])
def local_gemm_to_gemv(node):
    """GEMM acting on row or column matrices -> GEMV
    """
    if node.op == gemm_no_inplace:
        z, a, x, y, b = node.inputs
        if z.broadcastable == x.broadcastable == (True, False):
            r = gemv_no_inplace(z.dimshuffle(1), a, y.T, x.dimshuffle(1), b)
            return [r.dimshuffle('x', 0)]
        if z.broadcastable == y.broadcastable == (False, True):
            r = gemv_no_inplace(z.dimshuffle(0), a, x, y.dimshuffle(0), b)
            return [r.dimshuffle(0, 'x')]

#################################
#
# Set up the BlasOpt optimizer

...
@@ -1098,7 +1259,12 @@ optdb.register('BlasOpt', blas_optdb, 1.7, 'fast_run')
blas_optdb.register('local_dot_to_dot22',
        EquilibriumOptimizer([local_dot_to_dot22], max_use_ratio=5),
        0, 'fast_run')
blas_optdb.register('local_dot_to_gemm', GemmOptimizer(), 10, 'fast_run')
blas_optdb.register('local_gemm_to_gemv',
        EquilibriumOptimizer([local_gemm_to_gemv], max_use_ratio=5),
        15, 'fast_run')

# After destroyhandler is in but before we try to make elemwise things inplace
# Try to make gemm inplace

...
@@ -1261,3 +1427,4 @@ from opt import register_specialize, register_canonicalize
def local_print_as_we_go_along(node):
    if node.op in (T.sub, T.add):
        debugprint(node)
theano/tensor/tests/test_blas.py @ 6f5a1844
...
...
@@ -83,7 +83,7 @@ class t_gemm(TestCase):
        Gemm.debug = True
        try:
            g = gemm_inplace([1.], 1., [1.], [1.], 1.)
-       except ValueError, e:
+       except TypeError, e:
            if e[0] is Gemm.E_rank:
                return
        self.fail()

...
@@ -91,7 +91,7 @@ class t_gemm(TestCase):
    def test0(self):
        try:
            self.cmp(1., 0., 1.0, 1.0, 1.0)
-       except ValueError, e:
+       except TypeError, e:
            if e[0] is Gemm.E_rank:
                return
        self.fail()

...
@@ -99,7 +99,7 @@ class t_gemm(TestCase):
    def test2(self):
        try:
            self.cmp(2., 1.0, [3, 2, 1.], [[1], [2], [3.]], 1.0)
-       except ValueError, e:
+       except TypeError, e:
            self.assertTrue(e[0] == Gemm.E_rank)
            return
        self.fail()

...
@@ -124,14 +124,14 @@ class t_gemm(TestCase):
                self.rand(3, 5), self.rand(5, 4), -1.0)

    def test_factorised_scalar(self):
-       a = T.matrix()
-       b = T.matrix()
-       c = T.matrix()
+       a = T.dmatrix()
+       b = T.dmatrix()
+       c = T.dmatrix()
        s = theano.shared(numpy.zeros((5, 5)))
-       lr1 = T.constant(0.01)
-       lr2 = T.constant(2)
-       l2_reg = T.constant(0.0001)
+       lr1 = T.constant(0.01).astype('float64')
+       lr2 = T.constant(2).astype('float64')
+       l2_reg = T.constant(0.0001).astype('float64')

        #test constant merge with gemm
        f = theano.function([a, b],
                updates={s: lr1 * T.dot(a, b) + l2_reg * lr2 * s},
                mode=mode_not_fast_compile).maker.env.toposort()

...
@@ -195,9 +195,10 @@ class t_gemm(TestCase):
        """test that dot args can be aliased"""
        Z = shared(self.rand(2, 2))
        A = shared(self.rand(2, 2))
-       f = inplace_func([], gemm_inplace(Z, 1.0, A, A, 1.0))
+       one = T.constant(1.0).astype(Z.dtype)
+       f = inplace_func([], gemm_inplace(Z, one, A, A, one))
        f()
-       f = inplace_func([], gemm_inplace(Z, 1.0, A, A.T, 1.0))
+       f = inplace_func([], gemm_inplace(Z, one, A, A.T, one))
        f()

    def test_transposes(self):

...
@@ -451,7 +452,8 @@ def test_gemm_opt_double_gemm():
    ishapes = [(4, 3), (3, 5), (4, 5), (), (), (5, 9), (9, 4), ()]
    i = [X, Y, Z, a, b, R, S, c]
-   o = [a * T.dot(X, Y) + gemm_inplace(Z, b, S.T, R.T, 1.0)]
+   o = [(a * T.dot(X, Y)
+       + gemm_inplace(Z, b, S.T, R.T, T.constant(1.0).astype('float64')))]
    try:
        f = inplace_func([Param(ii, mutable=True) for ii in i], o,
                mode='FAST_RUN')

...
@@ -765,8 +767,9 @@ def test_dot_vm():
    # Assert they produce the same output
    assert numpy.allclose(f(), numpy.dot(v.get_value(), m.get_value()))
    # Assert that the dot was optimized somehow
    assert sum([isinstance(node.op, T.Dot) for node in
-       f.maker.env.toposort()]) == 1
+       f.maker.env.toposort()]) == 0

def test_dot_mv():
    ''' Test matrix dot vector '''

...
...
@@ -779,8 +782,9 @@ def test_dot_mv():
# Assert they produce the same output
assert
numpy
.
allclose
(
f
(),
numpy
.
dot
(
m
.
get_value
(),
v
.
get_value
()))
# Assert that the dot was optimized somehow
assert
sum
([
isinstance
(
node
.
op
,
T
.
Dot
)
for
node
in
f
.
maker
.
env
.
toposort
()
])
==
1
f
.
maker
.
env
.
toposort
()
])
==
0
class
TestGemv
(
TestCase
):
def
test_gemv1
(
self
):
...
...