Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
76a6cd53
提交
76a6cd53
authored
3月 09, 2010
作者:
James Bergstra
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
new GEMM optimization algorithm
上级
3128a44c
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
305 行增加
和
51 行删除
+305
-51
blas.py
theano/tensor/blas.py
+224
-44
test_blas.py
theano/tensor/tests/test_blas.py
+81
-7
没有找到文件。
theano/tensor/blas.py
浏览文件 @
76a6cd53
...
@@ -525,6 +525,9 @@ def _is_real_matrix(res):
...
@@ -525,6 +525,9 @@ def _is_real_matrix(res):
and
res
.
type
.
broadcastable
[
1
]
==
False
#cope with tuple vs. list
and
res
.
type
.
broadcastable
[
1
]
==
False
#cope with tuple vs. list
def
_as_isolated_scalar_times_matrix
(
res
):
def
_as_isolated_scalar_times_matrix
(
res
):
"""Returns (scalar_var, matrix_var) on success else None
"""
# isolated means that there is only one client of the result 'res'
if
res_is_a
(
res
,
T
.
mul
,
1
):
if
res_is_a
(
res
,
T
.
mul
,
1
):
if
len
(
res
.
owner
.
inputs
)
==
2
:
if
len
(
res
.
owner
.
inputs
)
==
2
:
L
,
R
=
res
.
owner
.
inputs
L
,
R
=
res
.
owner
.
inputs
...
@@ -546,6 +549,11 @@ def _as_isolated_scalar_times_matrix(res):
...
@@ -546,6 +549,11 @@ def _as_isolated_scalar_times_matrix(res):
else
:
else
:
return
None
return
None
if
len
(
matrices
)
==
1
:
if
len
(
matrices
)
==
1
:
if
len
(
scalars
)
==
0
:
rval
=
(
1.0
,
matrices
[
0
])
elif
len
(
scalars
)
==
1
:
rval
=
(
scalars
[
0
],
matrices
[
0
])
else
:
rval
=
(
T
.
mul
(
*
scalars
),
matrices
[
0
])
rval
=
(
T
.
mul
(
*
scalars
),
matrices
[
0
])
return
rval
return
rval
...
@@ -553,7 +561,9 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
...
@@ -553,7 +561,9 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
#EXPRESSION: (beta * L) + (alpha * M)
#EXPRESSION: (beta * L) + (alpha * M)
if
res_is_a
(
M
,
_dot22
,
1
):
# we've already checked the client counts, now just make the type check.
####if res_is_a(M, _dot22, 1):
if
M
.
owner
and
M
.
owner
.
op
==
_dot22
:
Ml
,
Mr
=
M
.
owner
.
inputs
Ml
,
Mr
=
M
.
owner
.
inputs
rval
=
[
gemm_no_inplace
(
L
,
alpha
,
Ml
,
Mr
,
beta
)]
rval
=
[
gemm_no_inplace
(
L
,
alpha
,
Ml
,
Mr
,
beta
)]
#print 'GEMM 0', rval, beta, L, alpha, M
#print 'GEMM 0', rval, beta, L, alpha, M
...
@@ -574,17 +584,14 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
...
@@ -574,17 +584,14 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
#EXPRESSION: (beta * L) + (alpha * ((b*dot(x,y) + (a * dot(u, v)))))
#EXPRESSION: (beta * L) + (alpha * ((b*dot(x,y) + (a * dot(u, v)))))
#EXPRESSION: (beta * L) + (alpha*b*dot(x,y)) + (alpha * a * dot(u, v))
#EXPRESSION: (beta * L) + (alpha*b*dot(x,y)) + (alpha * a * dot(u, v))
rval
=
[
gemm_no_inplace
(
gemm_no_inplace
(
L
,
alpha
*
b
,
x
,
y
,
beta
),
alpha
*
a
,
u
,
v
,
1.0
)]
rval
=
[
gemm_no_inplace
(
gemm_no_inplace
(
L
,
alpha
*
b
,
x
,
y
,
beta
),
alpha
*
a
,
u
,
v
,
1.0
)]
print
'GEMM 1'
,
rval
return
rval
return
rval
if
(
G
is
L
):
if
(
G
is
L
):
#EXPRESSION: (beta * L) + (alpha*b*L) + (alpha * a * dot(u, v))
#EXPRESSION: (beta * L) + (alpha*b*L) + (alpha * a * dot(u, v))
rval
=
[
gemm_no_inplace
(
L
,
alpha
*
a
,
u
,
v
,
alpha
*
b
+
beta
)]
rval
=
[
gemm_no_inplace
(
L
,
alpha
*
a
,
u
,
v
,
alpha
*
b
+
beta
)]
print
'GEMM 2'
,
rval
return
rval
return
rval
if
(
1.0
!=
alpha
):
if
(
1.0
!=
alpha
):
#at the very least, move the alpha inside the gemm_no_inplace
#at the very least, move the alpha inside the gemm_no_inplace
rval
=
[
beta
*
L
+
gemm_no_inplace
(
G
,
alpha
*
a
,
u
,
v
,
alpha
*
b
)]
rval
=
[
beta
*
L
+
gemm_no_inplace
(
G
,
alpha
*
a
,
u
,
v
,
alpha
*
b
)]
print
'GEMM 3'
,
rval
return
rval
return
rval
if
recurse_flip
:
if
recurse_flip
:
...
@@ -592,43 +599,174 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
...
@@ -592,43 +599,174 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
else
:
else
:
return
False
return
False
def
_gemm_from_node
(
node
):
"""
:todo: In many expressions, there are many ways to turn it into a gemm. For example
dot(a,b) + c + d. This function should return all of them, so that if one version of gemm
causes a cycle in the graph, then another application of gemm can be tried.
"""
def
_gemm_canonicalize
(
r
,
scale
,
rval
,
maxclients
):
if
node
.
op
==
T
.
sub
:
# Tries to interpret node as a sum of scalars * matrices
L
,
R
=
node
.
inputs
def
scaled
(
thing
):
if
not
_is_real_matrix
(
L
):
if
scale
==
1
:
return
False
return
thing
if
not
_is_real_matrix
(
R
):
if
scale
==
-
1
:
return
False
return
-
thing
else
:
return
scale
*
thing
if
(
tuple
(
r
.
type
.
broadcastable
)
!=
(
False
,
False
)
or
r
.
type
.
dtype
not
in
(
'float32'
,
'float64'
,
'complex64'
,
'complex128'
)):
rval
.
append
(
scaled
(
r
))
return
rval
if
maxclients
and
len
(
getattr
(
r
,
'clients'
,[]))
>
maxclients
:
rval
.
append
((
scale
,
r
))
return
rval
if
r
.
owner
and
r
.
owner
.
op
==
T
.
sub
:
_gemm_canonicalize
(
r
.
owner
.
inputs
[
0
],
scale
,
rval
,
1
)
_gemm_canonicalize
(
r
.
owner
.
inputs
[
1
],
-
scale
,
rval
,
1
)
elif
r
.
owner
and
r
.
owner
.
op
==
T
.
add
:
for
i
in
r
.
owner
.
inputs
:
_gemm_canonicalize
(
i
,
scale
,
rval
,
1
)
elif
r
.
owner
and
r
.
owner
.
op
==
T
.
neg
:
_gemm_canonicalize
(
r
.
owner
.
inputs
[
0
],
-
scale
,
rval
,
1
)
elif
r
.
owner
and
r
.
owner
.
op
==
T
.
mul
:
scalars
=
[]
matrices
=
[]
for
i
in
r
.
owner
.
inputs
:
if
numpy
.
all
(
i
.
type
.
broadcastable
):
while
i
.
owner
and
isinstance
(
i
.
owner
.
op
,
T
.
DimShuffle
):
i
=
i
.
owner
.
inputs
[
0
]
if
i
.
type
.
broadcastable
:
scalars
.
append
(
i
.
dimshuffle
())
else
:
scalars
.
append
(
i
)
elif
_is_real_matrix
(
i
):
matrices
.
append
(
i
)
else
:
# just put the original arguments as in the base case
rval
.
append
((
scale
,
r
))
return
rval
if
len
(
matrices
)
==
1
:
m
=
matrices
[
0
]
if
len
(
scalars
)
==
0
:
_gemm_canonicalize
(
m
,
scale
,
rval
,
1
)
elif
len
(
scalars
)
==
1
:
_gemm_canonicalize
(
m
,
scaled
(
scalars
[
0
]),
rval
,
1
)
else
:
_gemm_canonicalize
(
m
,
T
.
mul
(
scaled
(
scalars
[
0
]),
*
scalars
[
1
:]),
rval
,
1
)
else
:
#there are many matrices... lets not open this up
rval
.
append
((
scale
,
r
))
else
:
rval
.
append
((
scale
,
r
))
return
rval
tmp
=
_as_isolated_scalar_times_matrix
(
L
)
def
_factor_canonicalized
(
lst
):
# remove duplicates from canonicalized list
# we only delete out of the right end of the list,
# once i has touched a list element, it is permantent
lst
=
list
(
lst
)
#print 'FACTOR', lst
#for (a,b) in lst:
#theano.printing.debugprint(a)
#theano.printing.debugprint(b)
i
=
0
while
i
<
len
(
lst
)
-
1
:
try
:
try
:
s
L
,
mL
=
tmp
s
_i
,
M_i
=
lst
[
i
]
except
:
except
:
sL
,
mL
=
1.0
,
L
i
+=
1
continue
j
=
i
+
1
while
j
<
len
(
lst
):
try
:
s_j
,
M_j
=
lst
[
j
]
except
:
j
+=
1
continue
if
M_i
is
M_j
:
s_i
=
s_i
+
s_j
lst
[
i
]
=
(
s_i
,
M_i
)
del
lst
[
j
]
else
:
j
+=
1
i
+=
1
return
lst
def
_gemm_from_factored_list
(
lst
):
"""Returns None, or a list to replace node.outputs
"""
# Try every pair in the sM_list, trying to turn it into a gemm operation
for
i
in
xrange
(
len
(
lst
)
-
1
):
try
:
s_i
,
M_i
=
lst
[
i
]
except
:
continue
for
j
in
xrange
(
i
+
1
,
len
(
lst
)):
tmp
=
_as_isolated_scalar_times_matrix
(
R
)
try
:
try
:
sR
,
mR
=
tmp
s_j
,
M_j
=
lst
[
j
]
except
:
except
:
sR
,
mR
=
1.0
,
R
continue
rval
=
_beta_L_plus_alpha_M
(
sL
,
mL
,
-
sR
,
mR
)
#print 'TRYING', (s_i, M_i, s_j, M_j)
gemm_of_sM_list
=
_beta_L_plus_alpha_M
(
s_i
,
M_i
,
s_j
,
M_j
)
if
gemm_of_sM_list
:
#print 'GOT IT', gemm_of_sM_list
def
item_to_var
(
t
):
try
:
s
,
M
=
t
except
:
return
t
if
s
==
1
:
return
M
if
s
==
-
1
:
return
-
M
return
s
*
M
assert
len
(
gemm_of_sM_list
)
==
1
add_inputs
=
[
item_to_var
(
input
)
for
k
,
input
in
enumerate
(
lst
)
if
k
not
in
(
i
,
j
)]
add_inputs
.
extend
(
gemm_of_sM_list
)
if
len
(
add_inputs
)
>
1
:
return
[
T
.
add
(
*
add_inputs
)]
else
:
return
add_inputs
def
_gemm_from_node2
(
node
):
"""
:todo: In many expressions, there are many ways to turn it into a gemm. For example
dot(a,b) + c + d. This function should return all of them, so that if one version of gemm
causes a cycle in the graph, then another application of gemm can be tried.
"""
lst
=
[]
_gemm_canonicalize
(
node
.
outputs
[
0
],
1.0
,
lst
,
0
)
if
len
(
lst
)
>
1
:
lst
=
_factor_canonicalized
(
lst
)
rval
=
_gemm_from_factored_list
(
lst
)
return
rval
return
rval
if
node
.
op
==
T
.
add
:
# arguments of the form scalar * matrix
sM_list
=
[]
# arguments that can be interpreted as scalar * matrix
sM_orig
=
[]
# arguments not of the form scalar * matrix (i.e., vectors, scalars)
def
inputs_as_scalar_times_matrix
(
node
):
other_inputs
=
[]
# try to interpret an expression as a sum of scalar * matrix terms plus an 'other' term.
# This function *could* recurse and flatten sub and add hierarchies, but it doesn't.
# Reason being - if we didn't need intermediate results, the canonizer should already done
# that.
# returns three lists: sM_list, sM_orig, other
# - sM_list is a list of pairs: the interpretation of some terms as scalar,matrix products
# - sM_orig is a list of variables: the originals before interpretation into sM_list
# - other is a list of terms that are not float matrices
op
=
None
sM_list
=
[]
sM_orig
=
[]
other
=
[]
if
node
.
op
==
T
.
add
or
node
.
op
==
T
.
sub
:
op
=
node
.
op
for
input
in
node
.
inputs
:
for
input
in
node
.
inputs
:
tmp
=
_as_isolated_scalar_times_matrix
(
input
)
tmp
=
_as_isolated_scalar_times_matrix
(
input
)
if
tmp
:
if
tmp
:
...
@@ -638,10 +776,16 @@ def _gemm_from_node(node):
...
@@ -638,10 +776,16 @@ def _gemm_from_node(node):
sM_list
.
append
((
1.0
,
input
))
sM_list
.
append
((
1.0
,
input
))
sM_orig
.
append
(
input
)
sM_orig
.
append
(
input
)
else
:
else
:
other
_inputs
.
append
(
input
)
other
.
append
(
input
)
assert
len
(
sM_list
)
==
len
(
sM_orig
)
assert
len
(
sM_list
)
==
len
(
sM_orig
)
assert
len
(
sM_list
)
+
len
(
other_inputs
)
==
len
(
node
.
inputs
)
assert
len
(
sM_list
)
+
len
(
other
)
==
len
(
node
.
inputs
)
return
op
,
sM_list
,
sM_orig
,
other
def
_gemm_from_sM_list
(
node
,
sM_list
,
sM_orig
,
other_inputs
):
"""Returns None, or a list to replace node.outputs
"""
if
len
(
sM_list
)
==
2
:
if
len
(
sM_list
)
==
2
:
(
sL
,
mL
),
(
sR
,
mR
)
=
sM_list
(
sL
,
mL
),
(
sR
,
mR
)
=
sM_list
gemm_of_sM_list
=
_beta_L_plus_alpha_M
(
sL
,
mL
,
sR
,
mR
)
gemm_of_sM_list
=
_beta_L_plus_alpha_M
(
sL
,
mL
,
sR
,
mR
)
...
@@ -666,21 +810,56 @@ def _gemm_from_node(node):
...
@@ -666,21 +810,56 @@ def _gemm_from_node(node):
new_add_inputs
=
(
inputs_without_ij
+
gemm_of_sM_list
+
other_inputs
)
new_add_inputs
=
(
inputs_without_ij
+
gemm_of_sM_list
+
other_inputs
)
if
False
:
#SUPER DEBUG MODE :(
if
len
(
new_add_inputs
)
+
1
!=
len
(
node
.
inputs
):
print
'inputs'
,
node
.
inputs
print
'sM, other'
,
sM_list
,
other_inputs
print
'i,j'
,
i
,
j
print
'gemm'
,
gemm_of_sM_list
print
'without ij'
,
inputs_without_ij
print
'new inputs'
,
new_add_inputs
sys
.
exit
(
1
)
# this should be True because we've combined a pair of arguments
# this should be True because we've combined a pair of arguments
# into a single GEMM
# into a single GEMM
assert
len
(
new_add_inputs
)
+
1
==
len
(
node
.
inputs
)
assert
len
(
new_add_inputs
)
+
1
==
len
(
node
.
inputs
)
return
[
T
.
add
(
*
new_add_inputs
)]
return
[
T
.
add
(
*
new_add_inputs
)]
return
False
def
_gemm_from_node
(
node
):
"""
:todo: In many expressions, there are many ways to turn it into a gemm. For example
dot(a,b) + c + d. This function should return all of them, so that if one version of gemm
causes a cycle in the graph, then another application of gemm can be tried.
"""
op
,
sM_list
,
sM_orig
,
other_inputs
=
inputs_as_scalar_times_matrix
(
node
)
if
op
==
T
.
sub
and
len
(
sM_list
)
==
2
:
(
sL
,
mL
),
(
sR
,
mR
)
=
sM_list
rval
=
_gemm_from_sM_list
([(
sL
,
mL
),
(
-
sR
,
mR
)],
None
,
None
)
if
rval
:
return
rval
#theano.printing.debugprint(node.outputs[0], depth=6)
if
len
(
sM_orig
[
1
]
.
clients
)
==
1
:
# Canonicalize this subgraph
# There is a form of Gemm that escapes the approach above
# g*W - (a * (e*dot(b,c) + d * W + X))
#
# -> gemm(W, -a*e, b, c, g-a*d) - a*X
#
# In this case g=sL W=mL, and a=sR. We must see if mR is a add() or a sub, in which
# one of the arguments is a scaled version of W a.k.a mL
Rop
,
RsM_list
,
RsM_orig
,
Rother_inputs
=
inputs_as_scalar_times_matrix
(
mR
.
owner
)
RsM_list_that_is_mL
=
[
s
for
(
s
,
m
)
in
RsM_list
if
m
is
mL
]
if
RsM_list_that_is_mL
and
Rop
==
T
.
add
:
pass
#g= sL - T.mul(sR,*RsM_list_that_is_mL)
#rval = _gemm_from_sM_list(
#[(g,mL)] + []]
#]
#)
#if Rop == T.add:
#rval = _beta_L_plus_alpha_M(
#L=mL,
#alpha=sR,
#R=T.)
return
rval
if
op
==
T
.
add
:
return
_gemm_from_sM_list
(
sM_list
,
sM_orig
,
other_inputs
)
class
GemmOptimizer
(
Optimizer
):
class
GemmOptimizer
(
Optimizer
):
"""Graph optimizer for inserting Gemm operations"""
"""Graph optimizer for inserting Gemm operations"""
...
@@ -698,7 +877,8 @@ class GemmOptimizer(Optimizer):
...
@@ -698,7 +877,8 @@ class GemmOptimizer(Optimizer):
did_something
=
False
did_something
=
False
nodelist
.
reverse
()
nodelist
.
reverse
()
for
node
in
nodelist
:
for
node
in
nodelist
:
new_outputs
=
_gemm_from_node
(
node
)
#new_outputs = _gemm_from_node(node)
new_outputs
=
_gemm_from_node2
(
node
)
if
new_outputs
:
if
new_outputs
:
assert
len
(
new_outputs
)
==
len
(
node
.
outputs
)
assert
len
(
new_outputs
)
==
len
(
node
.
outputs
)
try
:
try
:
...
...
theano/tensor/tests/test_blas.py
浏览文件 @
76a6cd53
from
nose.plugins.skip
import
SkipTest
import
traceback
import
traceback
import
theano.tensor
as
T
import
theano.tensor
as
T
from
theano.gof
import
Env
from
theano.gof
import
Env
from
theano.printing
import
pp
from
theano.printing
import
pp
import
numpy
,
theano
import
numpy
,
theano
from
theano.tensor.blas
import
*
from
theano.tensor.blas
import
*
from
theano.tensor.blas
import
_dot22
,
_dot22scalar
,
res_is_a
,
_as_scalar
,
_is_real_matrix
from
theano.tensor.blas
import
(
_dot22
,
_dot22scalar
,
res_is_a
,
_as_scalar
,
_is_real_matrix
,
_gemm_canonicalize
,
_factor_canonicalized
)
from
unittest
import
TestCase
from
unittest
import
TestCase
from
theano.tests
import
unittest_tools
from
theano.tests
import
unittest_tools
from
copy
import
copy
from
copy
import
copy
...
@@ -267,16 +269,24 @@ class Failure(Exception):
...
@@ -267,16 +269,24 @@ class Failure(Exception):
class
Warning
(
Exception
):
class
Warning
(
Exception
):
pass
pass
def
just_gemm
(
i
,
o
,
ishapes
=
[(
4
,
3
),
(
3
,
5
),
(
4
,
5
),
(),
()]):
def
just_gemm
(
i
,
o
,
ishapes
=
[(
4
,
3
),
(
3
,
5
),
(
4
,
5
),
(),
()]
,
max_graphlen
=
0
):
try
:
try
:
f
=
inplace_func
([
Param
(
ii
,
mutable
=
True
)
for
ii
in
i
],
o
,
mode
=
'FAST_RUN'
)
f
=
inplace_func
([
Param
(
ii
,
mutable
=
True
)
for
ii
in
i
],
o
,
mode
=
'FAST_RUN'
)
at_least_one_gemm
=
False
for
node
in
f
.
maker
.
env
.
nodes
:
for
node
in
f
.
maker
.
env
.
nodes
:
if
node
.
op
==
T
.
dot
:
raise
Warning
(
'dot not changed to gemm_inplace in graph'
)
if
node
.
op
==
T
.
dot
:
raise
Warning
(
'dot not changed to gemm_inplace in graph'
)
if
node
.
op
==
_dot22
:
raise
Warning
(
'_dot22 not changed to gemm_inplace in graph'
)
if
node
.
op
==
_dot22
:
raise
Warning
(
'_dot22 not changed to gemm_inplace in graph'
)
if
node
.
op
==
gemm_inplace
:
at_least_one_gemm
=
True
assert
at_least_one_gemm
g
=
inplace_func
(
i
,
o
,
mode
=
compile
.
Mode
(
linker
=
'py'
,
optimizer
=
None
))
g
=
inplace_func
(
i
,
o
,
mode
=
compile
.
Mode
(
linker
=
'py'
,
optimizer
=
None
))
for
node
in
g
.
maker
.
env
.
nodes
:
for
node
in
g
.
maker
.
env
.
nodes
:
if
node
.
op
==
gemm_inplace
:
raise
Exception
(
'gemm_inplace in original graph'
)
if
node
.
op
==
gemm_inplace
:
raise
Exception
(
'gemm_inplace in original graph'
)
graphlen
=
len
(
f
.
maker
.
env
.
toposort
())
if
max_graphlen
and
(
graphlen
<=
max_graphlen
):
theano
.
printing
.
debugprint
(
f
)
assert
False
,
'graphlen=
%
i>
%
i'
%
(
graphlen
,
max_graphlen
)
rng
=
numpy
.
random
.
RandomState
(
unittest_tools
.
fetch_seed
(
234
))
rng
=
numpy
.
random
.
RandomState
(
unittest_tools
.
fetch_seed
(
234
))
r0
=
f
(
*
[
rng
.
randn
(
*
sh
)
for
sh
in
ishapes
])
r0
=
f
(
*
[
rng
.
randn
(
*
sh
)
for
sh
in
ishapes
])
rng
=
numpy
.
random
.
RandomState
(
unittest_tools
.
fetch_seed
(
234
))
rng
=
numpy
.
random
.
RandomState
(
unittest_tools
.
fetch_seed
(
234
))
...
@@ -353,12 +363,76 @@ def test_gemm_opt_double_gemm():
...
@@ -353,12 +363,76 @@ def test_gemm_opt_double_gemm():
print
'GRAPH'
,
node
print
'GRAPH'
,
node
raise
raise
def
wishlist_gemm_opt
():
def
test_gemm_canonicalize
():
X
,
Y
,
Z
,
a
,
b
=
T
.
dmatrix
(
'X'
),
T
.
dmatrix
(
'Y'
),
T
.
dmatrix
(
'Z'
),
T
.
dscalar
(
'a'
),
T
.
dscalar
(
'b'
)
R
,
S
,
U
,
c
,
d
=
T
.
dmatrix
(
'R'
),
T
.
dmatrix
(
'S'
),
T
.
dmatrix
(
'U'
),
T
.
dscalar
(
'c'
),
T
.
dscalar
(
'd'
)
u
=
T
.
row
(
'u'
)
can
=
[]
_gemm_canonicalize
(
X
+
Y
+
Z
,
1.0
,
can
,
0
)
assert
can
==
[(
1.0
,
X
),
(
1.0
,
Y
),
(
1.0
,
Z
)]
can
=
[]
_gemm_canonicalize
(
X
+
Y
+
u
,
1.0
,
can
,
0
)
assert
can
==
[(
1.0
,
X
),
(
1.0
,
Y
),
u
]
can
=
[]
_gemm_canonicalize
(
a
*
X
+
Y
-
b
*
Z
*
c
,
1.0
,
can
,
0
)
assert
can
[
0
]
==
(
a
,
X
)
assert
can
[
1
]
==
(
1.0
,
Y
)
assert
can
[
2
][
0
]
.
owner
.
op
==
T
.
mul
assert
can
[
2
][
0
]
.
owner
.
inputs
[
0
]
.
owner
.
op
==
T
.
neg
assert
can
[
2
][
0
]
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
0
]
==
c
assert
can
[
2
][
0
]
.
owner
.
inputs
[
1
]
==
b
can
=
[]
_gemm_canonicalize
((
-
d
)
*
X
-
(
a
*
X
+
Y
-
b
*
Z
*
c
),
1.0
,
can
,
0
)
print
can
assert
can
[
0
][
0
]
.
owner
.
op
==
T
.
neg
assert
can
[
0
][
0
]
.
owner
.
inputs
[
0
]
==
d
assert
can
[
0
][
1
]
==
X
assert
can
[
1
][
0
]
.
owner
.
op
==
T
.
neg
assert
can
[
1
][
0
]
.
owner
.
inputs
[
0
]
==
a
assert
can
[
2
]
==
(
-
1.0
,
Y
)
assert
can
[
3
][
0
]
.
owner
.
op
==
T
.
mul
assert
can
[
3
][
0
]
.
owner
.
inputs
==
[
c
,
b
]
def
test_gemm_factor
():
X
,
Y
,
Z
,
a
,
b
=
T
.
dmatrix
(
'X'
),
T
.
dmatrix
(
'Y'
),
T
.
dmatrix
(
'Z'
),
T
.
dscalar
(
'a'
),
T
.
dscalar
(
'b'
)
R
,
S
,
U
,
c
,
d
=
T
.
dmatrix
(
'R'
),
T
.
dmatrix
(
'S'
),
T
.
dmatrix
(
'U'
),
T
.
dscalar
(
'c'
),
T
.
dscalar
(
'd'
)
u
=
T
.
row
(
'u'
)
assert
[(
1.0
,
X
),
(
1.0
,
Y
),
u
]
==
_factor_canonicalized
([(
1.0
,
X
),
(
1.0
,
Y
),
u
])
assert
[(
2.0
,
X
),
u
]
==
_factor_canonicalized
([(
1.0
,
X
),(
1.0
,
X
),
u
])
def
test_gemm_nested
():
X
,
Y
,
Z
,
a
,
b
=
T
.
dmatrix
(
'X'
),
T
.
dmatrix
(
'Y'
),
T
.
dmatrix
(
'Z'
),
T
.
dscalar
(
'a'
),
T
.
dscalar
(
'b'
)
R
,
S
,
U
,
c
,
d
=
T
.
dmatrix
(
'R'
),
T
.
dmatrix
(
'S'
),
T
.
dmatrix
(
'U'
),
T
.
dscalar
(
'c'
),
T
.
dscalar
(
'd'
)
u
=
T
.
row
(
'u'
)
just_gemm
([
X
,
Y
,
Z
,
R
,
S
,
U
,
a
,
b
,
c
,
d
],
[
a
*
Z
-
b
*
(
c
*
T
.
dot
(
X
,
Y
)
+
d
*
Z
)],
ishapes
=
[(
2
,
3
),(
3
,
4
),(
2
,
4
),(
2
,
3
),(
3
,
4
),(
2
,
4
),(),(),(),()],
max_graphlen
=
1
)
print
"---------------------"
just_gemm
([
X
,
Y
,
Z
,
R
,
S
,
U
,
a
,
b
,
c
,
d
],
[
a
*
Z
-
b
*
(
c
*
T
.
dot
(
X
,
Y
)
+
d
*
Z
+
c
*
Z
)],
ishapes
=
[(
2
,
3
),(
3
,
4
),(
2
,
4
),(
2
,
3
),(
3
,
4
),(
2
,
4
),(),(),(),()],
max_graphlen
=
1
)
print
"---------------------"
just_gemm
([
X
,
Y
,
Z
,
R
,
S
,
U
,
a
,
b
,
c
,
d
],
[
a
*
Z
-
b
*
(
c
*
T
.
dot
(
X
,
Y
)
+
d
*
Z
+
c
*
U
)],
ishapes
=
[(
2
,
3
),(
3
,
4
),(
2
,
4
),(
2
,
3
),(
3
,
4
),(
2
,
4
),(),(),(),()],
max_graphlen
=
3
)
def
test_gemm_opt_wishlist
():
X
,
Y
,
Z
,
a
,
b
=
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dscalar
(),
T
.
dscalar
()
X
,
Y
,
Z
,
a
,
b
=
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dmatrix
(),
T
.
dscalar
(),
T
.
dscalar
()
#with >2 additions of the same T.dot(X,Y term
#with >2 additions of the same T.dot(X,Y term
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
Z
+
T
.
dot
(
X
,
Y
)
+
T
.
dot
(
X
,
Y
)])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[(
b
*
b
)
*
Z
*
a
+
(
a
*
a
)
*
T
.
dot
(
X
,
Y
)
+
b
*
T
.
dot
(
X
,
Y
)])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[(
b
*
b
)
*
Z
*
a
+
(
a
*
a
)
*
T
.
dot
(
X
,
Y
)
+
b
*
T
.
dot
(
X
,
Y
)])
just_gemm
([
X
,
Y
,
Z
,
a
,
b
],
[
Z
+
T
.
dot
(
X
,
Y
)
+
T
.
dot
(
X
,
Y
)])
def
test_gemm_with_vector
():
def
test_gemm_with_vector
():
"""Many subgraphs whose dots can be eliminated.
"""Many subgraphs whose dots can be eliminated.
...
@@ -423,9 +497,9 @@ def test_inplace1():
...
@@ -423,9 +497,9 @@ def test_inplace1():
# with > 2 terms in the overall addition
# with > 2 terms in the overall addition
f
=
inplace_func
([
X
,
Y
,
Z
,
a
,
b
],
f
=
inplace_func
([
X
,
Y
,
Z
,
a
,
b
],
[
Z
+
Z
+
T
.
dot
(
X
,
Y
)],
mode
=
'FAST_RUN'
)
[
Z
+
Z
+
T
.
dot
(
X
,
Y
)],
mode
=
'FAST_RUN'
)
# gemm_inplace should operate in-place on (Z+Z
)
theano
.
printing
.
debugprint
(
f
)
if
(
not
gemm_inplace
in
[
n
.
op
for
n
in
f
.
maker
.
env
.
nodes
]):
# it doesn't work inplace because we didn't mark Z as mutable input
raise
Failure
(
'no gemm_inplace in graph'
)
assert
[
n
.
op
for
n
in
f
.
maker
.
env
.
nodes
]
==
[
gemm_no_inplace
]
def
test_dot22
():
def
test_dot22
():
if
config
.
mode
==
'FAST_COMPILE'
:
if
config
.
mode
==
'FAST_COMPILE'
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论