Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
cc8822d3
提交
cc8822d3
authored
7月 22, 2011
作者:
James Bergstra
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
several changes to tensor/blas.py
上级
dfdea75f
显示空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
71 行增加
和
38 行删除
+71
-38
blas.py
theano/tensor/blas.py
+71
-38
没有找到文件。
theano/tensor/blas.py
浏览文件 @
cc8822d3
...
@@ -387,6 +387,12 @@ class GemmRelated(Op):
...
@@ -387,6 +387,12 @@ class GemmRelated(Op):
(long int)Ny[1], (long int)Nz[1]);
(long int)Ny[1], (long int)Nz[1]);
%(fail)
s;
%(fail)
s;
}
}
if (Nx[1] == 0)
{
PyErr_Format(PyExc_ValueError,
"Undefined semantics: x has 0 cols");
%(fail)
s;
}
"""
"""
check_strides
=
"""
check_strides
=
"""
...
@@ -497,6 +503,12 @@ class GemmRelated(Op):
...
@@ -497,6 +503,12 @@ class GemmRelated(Op):
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '
\\
n';
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '
\\
n';
//double t0 = time_time();
//double t0 = time_time();
//fprintf(stderr, "unit=
%%
x N=
%%
i
%%
i
%%
i S =
%%
i
%%
i
%%
i
%%
i
%%
i
%%
i
\\
n", unit,
//Nz1, Nz0, Nx1,
//sy_0, sy_1,
//sx_0, sx_1,
//sz_0, sz_1
//);
switch(unit)
switch(unit)
{
{
case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
...
@@ -540,7 +552,7 @@ class GemmRelated(Op):
...
@@ -540,7 +552,7 @@ class GemmRelated(Op):
self
.
end_switch_typenum
),
''
)
self
.
end_switch_typenum
),
''
)
def
build_gemm_version
(
self
):
def
build_gemm_version
(
self
):
return
(
6
,)
return
(
7
,)
class
Gemm
(
GemmRelated
):
class
Gemm
(
GemmRelated
):
"""In-place version of matrix-matrix multiplication (with accumulation):
"""In-place version of matrix-matrix multiplication (with accumulation):
...
@@ -818,13 +830,6 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
...
@@ -818,13 +830,6 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
#EXPRESSION: (beta * L) + (alpha * M)
#EXPRESSION: (beta * L) + (alpha * M)
if
M
.
type
.
broadcastable
!=
L
.
type
.
broadcastable
:
# GEMM cannot do the broadcasting that add used to be doing
# so abort.
return
assert
L
.
type
.
dtype
==
M
.
type
.
dtype
# because of local_dot_to_dot22
# we've already checked the client counts, now just make the type check.
# we've already checked the client counts, now just make the type check.
####if res_is_a(M, _dot22, 1):
####if res_is_a(M, _dot22, 1):
if
M
.
owner
and
M
.
owner
.
op
==
_dot22
:
if
M
.
owner
and
M
.
owner
.
op
==
_dot22
:
...
@@ -1017,20 +1022,23 @@ def _gemm_from_factored_list(lst):
...
@@ -1017,20 +1022,23 @@ def _gemm_from_factored_list(lst):
"""
"""
# Make every pair in list have matching dtypes
# Make every pair in list have matching dtypes
lst
=
[(
T
.
cast
(
si
,
Mi
.
type
.
dtype
),
Mi
)
for
si
,
Mi
in
lst
]
def
is_pair
(
sM
):
try
:
s
,
M
=
sM
return
True
except
:
return
False
lst
=
[(
T
.
cast
(
sM
[
0
],
sM
[
1
]
.
type
.
dtype
),
sM
[
1
])
for
sM
in
lst
if
is_pair
(
sM
)]
# Try every pair in the sM_list, trying to turn it into a gemm operation
# Try every pair in the sM_list, trying to turn it into a gemm operation
for
i
in
xrange
(
len
(
lst
)
-
1
):
for
i
in
xrange
(
len
(
lst
)
-
1
):
try
:
s_i
,
M_i
=
lst
[
i
]
s_i
,
M_i
=
lst
[
i
]
except
:
continue
for
j
in
xrange
(
i
+
1
,
len
(
lst
)):
for
j
in
xrange
(
i
+
1
,
len
(
lst
)):
try
:
s_j
,
M_j
=
lst
[
j
]
s_j
,
M_j
=
lst
[
j
]
except
:
if
M_i
.
type
!=
M_j
.
type
:
continue
continue
#print 'TRYING', (s_i, M_i, s_j, M_j)
#print 'TRYING', (s_i, M_i, s_j, M_j)
...
@@ -1281,24 +1289,32 @@ class Dot22Scalar(GemmRelated):
...
@@ -1281,24 +1289,32 @@ class Dot22Scalar(GemmRelated):
Also used to generate a gemm later.
Also used to generate a gemm later.
compute scalar*dot(x,y)
compute scalar*dot(x,y)
"""
"""
def
make_node
(
self
,
x
,
y
,
scalar
):
def
make_node
(
self
,
x
,
y
,
a
):
if
not
_is_real_matrix
(
x
):
if
a
.
ndim
!=
0
:
raise
TypeError
(
x
)
raise
TypeError
(
Gemm
.
E_scalar
,
a
)
if
not
_is_real_matrix
(
x
):
if
x
.
ndim
!=
2
:
raise
TypeError
(
y
)
raise
TypeError
(
Gemm
.
E_rank
,
x
)
if
not
_as_scalar
(
scalar
):
if
y
.
ndim
!=
2
:
raise
TypeError
(
scalar
)
raise
TypeError
(
Gemm
.
E_rank
,
y
)
if
y
.
type
.
dtype
!=
x
.
type
.
dtype
and
y
.
type
.
dtype
!=
scalar
.
type
.
dtype
:
raise
TypeError
(
'dtype mismatch to Dot22Scalar'
)
if
not
(
a
.
dtype
==
x
.
dtype
==
y
.
dtype
):
bz
=
[
False
,
False
]
raise
TypeError
(
'Dot22Scalar requires matching dtypes'
,
(
a
.
dtype
,
x
.
dtype
,
y
.
dtype
))
if
(
not
a
.
dtype
.
startswith
(
'float'
)
and
not
a
.
dtype
.
startswith
(
'complex'
)):
raise
TypeError
(
'Dot22Scalar requires float or complex args'
,
a
.
dtype
)
bz
=
[
x
.
type
.
broadcastable
[
0
],
y
.
type
.
broadcastable
[
1
]]
outputs
=
[
T
.
tensor
(
x
.
type
.
dtype
,
bz
)]
outputs
=
[
T
.
tensor
(
x
.
type
.
dtype
,
bz
)]
return
Apply
(
self
,
[
x
,
y
,
scalar
],
outputs
)
return
Apply
(
self
,
[
x
,
y
,
a
],
outputs
)
def
perform
(
self
,
node
,
inp
,
out
):
def
perform
(
self
,
node
,
inp
,
out
):
x
,
y
,
scalar
=
inp
x
,
y
,
scalar
=
inp
z
,
=
out
z
,
=
out
try
:
try
:
z
[
0
]
=
scalar
*
numpy
.
asarray
(
numpy
.
dot
(
x
,
y
))
z
[
0
]
=
numpy
.
asarray
(
scalar
*
numpy
.
dot
(
x
,
y
))
except
ValueError
,
e
:
except
ValueError
,
e
:
# The error raised by numpy has no shape information, we mean to add that
# The error raised by numpy has no shape information, we mean to add that
e
.
args
=
e
.
args
+
(
x
.
shape
,
y
.
shape
)
e
.
args
=
e
.
args
+
(
x
.
shape
,
y
.
shape
)
...
@@ -1360,21 +1376,23 @@ def local_dot22_to_dot22scalar(node):
...
@@ -1360,21 +1376,23 @@ def local_dot22_to_dot22scalar(node):
return
False
return
False
i_dot22
=
[
x
.
owner
and
x
.
owner
.
op
==
_dot22
for
x
in
node
.
inputs
]
i_dot22
=
[
x
.
owner
and
x
.
owner
.
op
==
_dot22
for
x
in
node
.
inputs
]
if
not
any
(
i_dot22
):
return
False
# no dot22
if
not
any
(
i_dot22
):
return
False
# no dot22
if
i_dot22
.
count
(
True
)
>
1
:
return
False
#TODO fix
if
i_dot22
.
count
(
True
)
>
1
:
#we take the first _dot22 found. TODO check others!
#TODO: try each of them.
pass
#return False #TODO fix
dot22_idx
=
i_dot22
.
index
(
True
)
dot22_idx
=
i_dot22
.
index
(
True
)
d
=
node
.
inputs
[
dot22_idx
]
d
=
node
.
inputs
[
dot22_idx
]
i_scalar
=
[
_as_scalar
(
x
)
for
x
in
node
.
inputs
]
i_scalar
=
[
_as_scalar
(
x
)
for
x
in
node
.
inputs
]
if
not
any
(
i_scalar
)
and
not
any
([
x
.
owner
and
x
.
owner
.
op
==
T
.
mul
for
x
in
node
.
inputs
]):
if
not
any
(
i_scalar
):
i_mul
=
[
x
.
owner
and
x
.
owner
.
op
==
T
.
mul
for
x
in
node
.
inputs
]
if
not
any
(
i_mul
):
#no scalar in input and no multiplication
#no scalar in input and no multiplication
#if their was a multiplication we couls reorder the graph by the associativity of the graph.
#if their was a multiplication we couls reorder the graph by the associativity of the graph.
return
False
return
False
if
not
any
(
i_scalar
):
#maybe we can reorder the graph as this mul have a mul in input.
#maybe we can reorder the graph as this mul have a mul in input.
#The canonizer should have merged those mul together.
#The canonizer should have merged those mul together.
#We support only 1 additional level of mul.
#We support only 1 additional level of mul.
i_mul
=
[
x
.
owner
and
x
.
owner
.
op
==
T
.
mul
for
x
in
node
.
inputs
]
mul_idx
=
i_mul
.
index
(
True
)
#we take the first mul!
mul_idx
=
i_mul
.
index
(
True
)
#we take the first mul!
m
=
node
.
inputs
[
mul_idx
]
m
=
node
.
inputs
[
mul_idx
]
...
@@ -1384,7 +1402,17 @@ def local_dot22_to_dot22scalar(node):
...
@@ -1384,7 +1402,17 @@ def local_dot22_to_dot22scalar(node):
if
_as_scalar
(
x
):
if
_as_scalar
(
x
):
scalar_idx
=
i
scalar_idx
=
i
break
break
dot
=
_dot22scalar
(
d
.
owner
.
inputs
[
0
],
d
.
owner
.
inputs
[
1
],
m
.
owner
.
inputs
[
scalar_idx
])
a
=
T
.
cast
(
_as_scalar
(
m
.
owner
.
inputs
[
scalar_idx
]),
d
.
type
.
dtype
)
assert
not
a
.
type
.
ndim
dot
=
_dot22scalar
(
d
.
owner
.
inputs
[
0
],
d
.
owner
.
inputs
[
1
],
a
)
# What about the other inputs to the original node that were
# neither part of the dot22 or this mul?
# I'm asserting there are no such inputs here:
assert
dot22_idx
!=
mul_idx
assert
all
((
i
in
(
dot22_idx
,
mul_idx
))
for
i
in
range
(
len
(
node
.
inputs
)))
return
[
T
.
mul
(
m
.
owner
.
inputs
[
1
-
i
],
dot
)]
return
[
T
.
mul
(
m
.
owner
.
inputs
[
1
-
i
],
dot
)]
elif
m
.
owner
and
m
.
owner
.
op
==
T
.
mul
:
elif
m
.
owner
and
m
.
owner
.
op
==
T
.
mul
:
...
@@ -1397,7 +1425,9 @@ def local_dot22_to_dot22scalar(node):
...
@@ -1397,7 +1425,9 @@ def local_dot22_to_dot22scalar(node):
scalar_idx
=
-
1
scalar_idx
=
-
1
for
i
,
x
in
enumerate
(
node
.
inputs
):
for
i
,
x
in
enumerate
(
node
.
inputs
):
if
i_scalar
[
i
]
and
theano
.
scalar
.
upcast
(
x
.
type
.
dtype
,
d
.
type
.
dtype
)
==
d
.
type
.
dtype
:
if
(
i_scalar
[
i
]
is
not
None
and
(
theano
.
scalar
.
upcast
(
x
.
type
.
dtype
,
d
.
type
.
dtype
)
==
d
.
type
.
dtype
)):
scalar_idx
=
i
scalar_idx
=
i
break
break
if
scalar_idx
<
0
:
if
scalar_idx
<
0
:
...
@@ -1405,15 +1435,18 @@ def local_dot22_to_dot22scalar(node):
...
@@ -1405,15 +1435,18 @@ def local_dot22_to_dot22scalar(node):
'of the scalar cannot be upcasted to the matrix type'
,
'of the scalar cannot be upcasted to the matrix type'
,
node
.
inputs
,
[
x
.
type
for
x
in
node
.
inputs
])
node
.
inputs
,
[
x
.
type
for
x
in
node
.
inputs
])
return
False
return
False
assert
scalar_idx
<
len
(
node
.
inputs
)
assert
scalar_idx
<
len
(
node
.
inputs
)
s
=
node
.
inputs
[
scalar_idx
]
s
=
node
.
inputs
[
scalar_idx
]
o
=
copy
.
copy
(
node
.
inputs
)
o
=
copy
.
copy
(
node
.
inputs
)
o
.
remove
(
d
)
o
.
remove
(
d
)
o
.
remove
(
s
)
o
.
remove
(
s
)
if
len
(
o
)
==
0
:
return
[
_dot22scalar
(
d
.
owner
.
inputs
[
0
],
d
.
owner
.
inputs
[
1
],
s
)]
a
=
T
.
cast
(
i_scalar
[
scalar_idx
],
d
.
type
.
dtype
)
assert
not
a
.
type
.
ndim
if
len
(
o
)
==
0
:
return
[
_dot22scalar
(
d
.
owner
.
inputs
[
0
],
d
.
owner
.
inputs
[
1
],
a
)]
else
:
else
:
return
[
T
.
mul
(
_dot22scalar
(
d
.
owner
.
inputs
[
0
],
d
.
owner
.
inputs
[
1
],
s
),
*
o
)]
return
[
T
.
mul
(
_dot22scalar
(
d
.
owner
.
inputs
[
0
],
d
.
owner
.
inputs
[
1
],
a
),
*
o
)]
#must happen after gemm as the gemm optimizer don't understant dot22scalar and gemm give more speed up then dot22scalar
#must happen after gemm as the gemm optimizer don't understant dot22scalar and gemm give more speed up then dot22scalar
blas_optdb
.
register
(
'local_dot22_to_dot22scalar'
,
blas_optdb
.
register
(
'local_dot22_to_dot22scalar'
,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论