Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
f2743791
提交
f2743791
authored
1月 25, 2012
作者:
James Bergstra
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #375 from nouiz/c_blas
additions and improvements to blas on both CPU and GPU (including GEMV, GER in blas_c)
上级
d9c41f58
03e7a048
隐藏空白字符变更
内嵌
并排
正在显示
14 个修改的文件
包含
1468 行增加
和
296 行删除
+1468
-296
blas.py
theano/sandbox/cuda/blas.py
+182
-0
cuda_ndarray.cu
theano/sandbox/cuda/cuda_ndarray.cu
+127
-8
cuda_ndarray.cuh
theano/sandbox/cuda/cuda_ndarray.cuh
+1
-0
opt.py
theano/sandbox/cuda/opt.py
+80
-20
test_blas.py
theano/sandbox/cuda/tests/test_blas.py
+41
-0
__init__.py
theano/tensor/__init__.py
+1
-0
basic.py
theano/tensor/basic.py
+35
-0
blas.py
theano/tensor/blas.py
+75
-52
blas_c.py
theano/tensor/blas_c.py
+466
-0
blas_scipy.py
theano/tensor/blas_scipy.py
+1
-4
test_blas.py
theano/tensor/tests/test_blas.py
+157
-209
test_blas_c.py
theano/tensor/tests/test_blas_c.py
+219
-0
test_blas_scipy.py
theano/tensor/tests/test_blas_scipy.py
+7
-3
unittest_tools.py
theano/tests/unittest_tools.py
+76
-0
没有找到文件。
theano/sandbox/cuda/blas.py
浏览文件 @
f2743791
...
@@ -251,6 +251,188 @@ class GpuGemm(Op):
...
@@ -251,6 +251,188 @@ class GpuGemm(Op):
gpu_gemm_no_inplace
=
GpuGemm
(
inplace
=
False
)
gpu_gemm_no_inplace
=
GpuGemm
(
inplace
=
False
)
gpu_gemm_inplace
=
GpuGemm
(
inplace
=
True
)
gpu_gemm_inplace
=
GpuGemm
(
inplace
=
True
)
class
GpuGemv
(
Op
):
"""
implement gemv on the gpu.
"""
def
__init__
(
self
,
inplace
):
self
.
__setstate__
({
'inplace'
:
inplace
})
def
__str__
(
self
):
if
self
.
inplace
:
return
'GpuGemv{inplace}'
else
:
return
'GpuGemv{no_inplace}'
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
)
\
and
self
.
inplace
==
other
.
inplace
)
def
__hash__
(
self
):
return
hash
(
type
(
self
))
^
hash
(
self
.
inplace
)
def
__setstate__
(
self
,
dct
):
inplace
=
dct
.
get
(
'inplace'
,
True
)
if
inplace
:
self
.
destroy_map
=
{
0
:
[
0
]}
self
.
inplace
=
inplace
def
__getstate__
(
self
):
return
dict
(
inplace
=
self
.
inplace
)
def
make_node
(
self
,
z
,
a
,
x
,
y
,
b
):
# the more complicated error checking performed by tensor.gemv is assumed to already
# have been done
return
Apply
(
self
,
[
z
,
a
,
x
,
y
,
b
],
[
z
.
type
()])
def
c_code_cache_version
(
self
):
return
(
1
,)
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
#z_out = alpha * dot(x,y) + beta * z_in
#inplace version, set set z_out = z_in
#not inplace version, we copy z_in to z_out.
z_in
,
a
,
x
,
y
,
b
=
inputs
z_out
,
=
outputs
fail
=
sub
[
'fail'
]
sio
=
StringIO
.
StringIO
()
print
>>
sio
,
"""
float
%(name)
s_alpha = ((dtype_
%(a)
s*)(
%(a)
s->data))[0];
float
%(name)
s_beta = ((dtype_
%(b)
s*)(
%(b)
s->data))[0];
"""
if
self
.
inplace
:
print
>>
sio
,
"""
Py_XDECREF(
%(z_out)
s);
%(z_out)
s =
%(z_in)
s;
Py_INCREF(
%(z_out)
s);
"""
else
:
print
>>
sio
,
"""
if (!
%(z_out)
s
|| (
%(z_out)
s->nd != 1)
|| (CudaNdarray_HOST_DIMS(
%(z_out)
s)[0] != CudaNdarray_HOST_DIMS(
%(z_in)
s)[0])
)
{
Py_XDECREF(
%(z_out)
s);
%(z_out)
s = (CudaNdarray*)CudaNdarray_Copy(
%(z_in)
s);
if (!
%(z_out)
s)
{
%(fail)
s;
}
}
else
{
if (CudaNdarray_CopyFromCudaNdarray(
%(z_out)
s,
%(z_in)
s))
{
%(fail)
s;
}
}
"""
print
>>
sio
,
"""
if (CudaNdarray_sgemv(
%(name)
s_alpha,
%(x)
s,
%(y)
s,
%(name)
s_beta,
%(z_out)
s))
{
%(fail)
s;
}
"""
return
sio
.
getvalue
()
%
locals
()
gpu_gemv_no_inplace
=
GpuGemv
(
inplace
=
False
)
gpu_gemv_inplace
=
GpuGemv
(
inplace
=
True
)
class
GpuGer
(
Op
):
"""
implement ger on the gpu.
"""
def
__init__
(
self
,
inplace
):
self
.
__setstate__
({
'inplace'
:
inplace
})
def
__str__
(
self
):
if
self
.
inplace
:
return
'GpuGer{inplace}'
else
:
return
'GpuGer{no_inplace}'
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
)
\
and
self
.
inplace
==
other
.
inplace
)
def
__hash__
(
self
):
return
hash
(
type
(
self
))
^
hash
(
self
.
inplace
)
def
__setstate__
(
self
,
dct
):
inplace
=
dct
.
get
(
'inplace'
,
True
)
if
inplace
:
self
.
destroy_map
=
{
0
:
[
0
]}
self
.
inplace
=
inplace
def
__getstate__
(
self
):
return
dict
(
inplace
=
self
.
inplace
)
def
make_node
(
self
,
z
,
a
,
x
,
y
):
# the more complicated error checking performed by tensor.ger is
# assumed to already have been done
return
Apply
(
self
,
[
z
,
a
,
x
,
y
],
[
z
.
type
()])
def
c_code_cache_version
(
self
):
return
(
1
,)
def
c_code
(
self
,
node
,
name
,
inputs
,
outputs
,
sub
):
#z_out = alpha * dot(x,y) + beta * z_in
#inplace version, set set z_out = z_in
#not inplace version, we copy z_in to z_out.
z_in
,
a
,
x
,
y
=
inputs
z_out
,
=
outputs
fail
=
sub
[
'fail'
]
sio
=
StringIO
.
StringIO
()
print
>>
sio
,
"""
float
%(name)
s_alpha = ((dtype_
%(a)
s*)(
%(a)
s->data))[0];
"""
if
self
.
inplace
:
print
>>
sio
,
"""
Py_XDECREF(
%(z_out)
s);
%(z_out)
s =
%(z_in)
s;
Py_INCREF(
%(z_out)
s);
"""
else
:
print
>>
sio
,
"""
if (!
%(z_out)
s
|| (
%(z_out)
s->nd != 2)
|| (CudaNdarray_HOST_DIMS(
%(z_out)
s)[0] != CudaNdarray_HOST_DIMS(
%(z_in)
s)[0])
|| (CudaNdarray_HOST_DIMS(
%(z_out)
s)[1] != CudaNdarray_HOST_DIMS(
%(z_in)
s)[1])
)
{
Py_XDECREF(
%(z_out)
s);
%(z_out)
s = (CudaNdarray*)CudaNdarray_Copy(
%(z_in)
s);
if (!
%(z_out)
s)
{
%(fail)
s;
}
}
else
{
if (CudaNdarray_CopyFromCudaNdarray(
%(z_out)
s,
%(z_in)
s))
{
%(fail)
s;
}
}
"""
print
>>
sio
,
"""
if (CudaNdarray_sger(
%(name)
s_alpha,
%(x)
s,
%(y)
s,
%(z_out)
s))
{
%(fail)
s;
}
"""
return
sio
.
getvalue
()
%
locals
()
gpu_ger_no_inplace
=
GpuGer
(
inplace
=
False
)
gpu_ger_inplace
=
GpuGer
(
inplace
=
True
)
class
GpuOuter
(
Op
):
class
GpuOuter
(
Op
):
def
make_node
(
self
,
x
,
y
):
def
make_node
(
self
,
x
,
y
):
# we suppose type checking has been done, but make sure.
# we suppose type checking has been done, but make sure.
...
...
theano/sandbox/cuda/cuda_ndarray.cu
浏览文件 @
f2743791
...
@@ -3012,6 +3012,92 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
...
@@ -3012,6 +3012,92 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
return
0
;
return
0
;
}
}
int
CudaNdarray_sgemv
(
float
alpha
,
const
CudaNdarray
*
A
,
const
CudaNdarray
*
B
,
float
beta
,
CudaNdarray
*
C
)
{
/**
* C <- alpha A B + beta C
* A : matrix
* B, C: vector
* alpha, beta: scalars
*/
if
(
A
->
nd
!=
2
)
{
PyErr_SetString
(
PyExc_ValueError
,
"non-matrix arg to gemv"
);
return
-
1
;
}
if
(
B
->
nd
!=
1
)
{
PyErr_SetString
(
PyExc_ValueError
,
"non-vector arg to gemv"
);
return
-
1
;
}
if
(
C
->
nd
!=
1
)
{
PyErr_SetString
(
PyExc_ValueError
,
"non-vector arg to gemv"
);
return
-
1
;
}
// We must allow dimensions to be zeros.
if
((
CudaNdarray_HOST_DIMS
(
A
)[
1
]
!=
CudaNdarray_HOST_DIMS
(
B
)[
0
])
||
(
CudaNdarray_HOST_DIMS
(
A
)[
0
]
!=
CudaNdarray_HOST_DIMS
(
C
)[
0
]))
{
PyErr_Format
(
PyExc_ValueError
,
"dimension mismatch in args to gemv (%i,%i)x(%i)->(%i)"
,
CudaNdarray_HOST_DIMS
(
A
)[
0
],
CudaNdarray_HOST_DIMS
(
A
)[
1
],
CudaNdarray_HOST_DIMS
(
B
)[
0
],
CudaNdarray_HOST_DIMS
(
C
)[
0
]);
return
-
1
;
}
// a matrix has non-unit size and non-unit stride in both directions, we can't operate in-place
// TODO: make a copy instead of returning in error
if
(((
CudaNdarray_HOST_DIMS
(
A
)[
0
]
>
1
)
&&
(
CudaNdarray_HOST_STRIDES
(
A
)[
0
]
!=
1
))
&&
((
CudaNdarray_HOST_DIMS
(
A
)[
1
]
>
1
)
&&
(
CudaNdarray_HOST_STRIDES
(
A
)[
1
]
!=
1
)))
{
PyErr_SetString
(
PyExc_NotImplementedError
,
"non-unit stride in gemv arg"
);
return
-
1
;
}
// I don't know if cudablas handles negative strides
if
(
(
CudaNdarray_HOST_STRIDES
(
A
)[
0
]
<
0
)
||
(
CudaNdarray_HOST_STRIDES
(
A
)[
1
]
<
0
)
||
(
CudaNdarray_HOST_STRIDES
(
B
)[
0
]
<
0
)
||
(
CudaNdarray_HOST_STRIDES
(
C
)[
0
]
<
0
))
{
PyErr_Format
(
PyExc_ValueError
,
"illegal strides in args to gemv (%i,%i)x(%i)->(%i)"
,
CudaNdarray_HOST_STRIDES
(
A
)[
0
],
CudaNdarray_HOST_STRIDES
(
A
)[
1
],
CudaNdarray_HOST_STRIDES
(
B
)[
0
],
CudaNdarray_HOST_STRIDES
(
C
)[
0
]);
return
-
1
;
}
/* create appropriate strides for malformed matrices that are row or column
* vectors
*/
int
sa_0
=
(
CudaNdarray_HOST_DIMS
(
A
)[
0
]
>
1
)
?
CudaNdarray_HOST_STRIDES
(
A
)[
0
]
:
CudaNdarray_HOST_DIMS
(
A
)[
1
];
int
sa_1
=
(
CudaNdarray_HOST_DIMS
(
A
)[
1
]
>
1
)
?
CudaNdarray_HOST_STRIDES
(
A
)[
1
]
:
CudaNdarray_HOST_DIMS
(
A
)[
0
];
int
sb_0
=
(
CudaNdarray_HOST_DIMS
(
B
)[
0
]
>
1
)
?
CudaNdarray_HOST_STRIDES
(
B
)[
0
]
:
1
;
int
sc_0
=
(
CudaNdarray_HOST_DIMS
(
C
)[
0
]
>
1
)
?
CudaNdarray_HOST_STRIDES
(
C
)[
0
]
:
1
;
if
(
sa_0
==
1
)
{
cublasSgemv
(
'N'
,
CudaNdarray_HOST_DIMS
(
A
)[
0
],
CudaNdarray_HOST_DIMS
(
A
)[
1
],
alpha
,
CudaNdarray_DEV_DATA
(
A
),
sa_1
,
CudaNdarray_DEV_DATA
(
B
),
sb_0
,
beta
,
CudaNdarray_DEV_DATA
(
C
),
sc_0
);
}
else
if
(
sa_1
==
1
)
{
cublasSgemv
(
'T'
,
CudaNdarray_HOST_DIMS
(
A
)[
1
],
CudaNdarray_HOST_DIMS
(
A
)[
0
],
alpha
,
CudaNdarray_DEV_DATA
(
A
),
sa_0
,
CudaNdarray_DEV_DATA
(
B
),
sb_0
,
beta
,
CudaNdarray_DEV_DATA
(
C
),
sc_0
);
}
else
{
PyErr_SetString
(
PyExc_NotImplementedError
,
"too many strides strides in sgemv"
);
return
-
1
;
}
CNDA_THREAD_SYNC
;
cudaError_t
err
=
cudaGetLastError
();
if
(
CUBLAS_STATUS_SUCCESS
!=
err
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"cublassGemv failed (%s)"
,
cudaGetErrorString
(
err
));
return
-
1
;
}
return
0
;
}
int
CudaNdarray_sger
(
float
alpha
,
CudaNdarray
*
x
,
CudaNdarray
*
y
,
CudaNdarray
*
A
)
{
int
CudaNdarray_sger
(
float
alpha
,
CudaNdarray
*
x
,
CudaNdarray
*
y
,
CudaNdarray
*
A
)
{
if
(
x
->
nd
!=
1
)
{
PyErr_SetString
(
PyExc_ValueError
,
"non-vector arg x to sger"
);
return
-
1
;
}
if
(
x
->
nd
!=
1
)
{
PyErr_SetString
(
PyExc_ValueError
,
"non-vector arg x to sger"
);
return
-
1
;
}
if
(
y
->
nd
!=
1
)
{
PyErr_SetString
(
PyExc_ValueError
,
"non-vector arg y to sger"
);
return
-
1
;
}
if
(
y
->
nd
!=
1
)
{
PyErr_SetString
(
PyExc_ValueError
,
"non-vector arg y to sger"
);
return
-
1
;
}
...
@@ -3033,17 +3119,50 @@ int CudaNdarray_sger(float alpha, CudaNdarray * x, CudaNdarray * y, CudaNdarray
...
@@ -3033,17 +3119,50 @@ int CudaNdarray_sger(float alpha, CudaNdarray * x, CudaNdarray * y, CudaNdarray
PyErr_SetString
(
PyExc_NotImplementedError
,
"non-c continugous A in sger"
);
PyErr_SetString
(
PyExc_NotImplementedError
,
"non-c continugous A in sger"
);
return
-
1
;
return
-
1
;
}
}
// Since Sger expects A in col-major, we invert x and y to fake this.
int
x_strides
=
CudaNdarray_HOST_STRIDES
(
x
)[
0
];
CudaNdarray
*
x_
=
x
;
if
(
x_strides
==
0
){
if
(
CudaNdarray_HOST_DIMS
(
x
)[
0
]
!=
1
){
PyErr_Format
(
PyExc_RuntimeError
,
"CudaNdarray_sger: Invalid input x(should not happen)."
" We received an CudaNdarray vector with a stride of 0"
" that have more then 1 elements!"
);
return
-
1
;
}
x_strides
=
4
;
}
else
if
(
x_strides
<
0
){
x_
=
(
CudaNdarray
*
)
CudaNdarray_Copy
(
x
);
x_strides
=
CudaNdarray_HOST_STRIDES
(
x_
)[
0
];
}
// Same for this, be safe
int
y_strides
=
CudaNdarray_HOST_STRIDES
(
y
)[
0
];
assert
(
CudaNdarray_HOST_STRIDES
(
x
)[
0
]
>=
0
);
CudaNdarray
*
y_
=
y
;
assert
(
CudaNdarray_HOST_STRIDES
(
y
)[
0
]
>=
0
);
if
(
y_strides
==
0
){
if
(
CudaNdarray_HOST_DIMS
(
y
)[
0
]
!=
1
){
PyErr_Format
(
PyExc_RuntimeError
,
"CudaNdarray_sger: Invalid input y(should not happen)."
" We received an CudaNdarray vector with a stride of 0"
" that have more then 1 elements!"
);
return
-
1
;
}
y_strides
=
4
;
}
else
if
(
y_strides
<
0
){
y_
=
(
CudaNdarray
*
)
CudaNdarray_Copy
(
y
);
y_strides
=
CudaNdarray_HOST_STRIDES
(
y_
)[
0
];
}
// Since Sger expects A in col-major, we invert x and y to fake this.
if
(
CudaNdarray_SIZE
(
A
)){
cublasSger
(
CudaNdarray_HOST_DIMS
(
y
)[
0
],
CudaNdarray_HOST_DIMS
(
x
)[
0
],
alpha
,
cublasSger
(
CudaNdarray_HOST_DIMS
(
y
)[
0
],
CudaNdarray_HOST_DIMS
(
x
)[
0
],
alpha
,
CudaNdarray_DEV_DATA
(
y
),
CudaNdarray_HOST_STRIDES
(
y
)[
0
],
CudaNdarray_DEV_DATA
(
y_
),
y_strides
,
CudaNdarray_DEV_DATA
(
x
),
CudaNdarray_HOST_STRIDES
(
x
)[
0
],
CudaNdarray_DEV_DATA
(
x_
),
x_strides
,
CudaNdarray_DEV_DATA
(
A
),
CudaNdarray_HOST_DIMS
(
A
)[
1
]);
CudaNdarray_DEV_DATA
(
A
),
CudaNdarray_HOST_DIMS
(
A
)[
1
]);
}
CNDA_THREAD_SYNC
;
CNDA_THREAD_SYNC
;
if
(
x_
!=
x
)
Py_DECREF
(
x_
);
if
(
y_
!=
y
)
Py_DECREF
(
y_
);
cudaError_t
err
=
cudaGetLastError
();
cudaError_t
err
=
cudaGetLastError
();
if
(
CUBLAS_STATUS_SUCCESS
!=
err
)
if
(
CUBLAS_STATUS_SUCCESS
!=
err
)
...
...
theano/sandbox/cuda/cuda_ndarray.cuh
浏览文件 @
f2743791
...
@@ -320,6 +320,7 @@ DllExport bool CudaNdarray_is_c_contiguous(const CudaNdarray * self);
...
@@ -320,6 +320,7 @@ DllExport bool CudaNdarray_is_c_contiguous(const CudaNdarray * self);
DllExport
PyObject
*
CudaNdarray_IS_C_Contiguous
(
CudaNdarray
*
self
);
DllExport
PyObject
*
CudaNdarray_IS_C_Contiguous
(
CudaNdarray
*
self
);
DllExport
int
CudaNdarray_gemm
(
float
alpha
,
const
CudaNdarray
*
A
,
const
CudaNdarray
*
B
,
float
beta
,
CudaNdarray
*
C
);
DllExport
int
CudaNdarray_gemm
(
float
alpha
,
const
CudaNdarray
*
A
,
const
CudaNdarray
*
B
,
float
beta
,
CudaNdarray
*
C
);
DllExport
int
CudaNdarray_sgemv
(
float
alpha
,
const
CudaNdarray
*
A
,
const
CudaNdarray
*
B
,
float
beta
,
CudaNdarray
*
C
);
DllExport
int
CudaNdarray_sger
(
float
alpha
,
CudaNdarray
*
x
,
CudaNdarray
*
y
,
CudaNdarray
*
A
);
DllExport
int
CudaNdarray_sger
(
float
alpha
,
CudaNdarray
*
x
,
CudaNdarray
*
y
,
CudaNdarray
*
A
);
DllExport
int
CudaNdarray_reduce_sum
(
CudaNdarray
*
self
,
CudaNdarray
*
A
);
DllExport
int
CudaNdarray_reduce_sum
(
CudaNdarray
*
self
,
CudaNdarray
*
A
);
...
...
theano/sandbox/cuda/opt.py
浏览文件 @
f2743791
...
@@ -16,6 +16,10 @@ from theano.sandbox.cuda.basic_ops import *
...
@@ -16,6 +16,10 @@ from theano.sandbox.cuda.basic_ops import *
from
theano.sandbox.cuda.type
import
CudaNdarrayType
from
theano.sandbox.cuda.type
import
CudaNdarrayType
from
theano.sandbox.cuda.blas
import
(
gpu_dot22
,
gpu_dot22scalar
,
from
theano.sandbox.cuda.blas
import
(
gpu_dot22
,
gpu_dot22scalar
,
gpu_gemm_inplace
,
gpu_gemm_no_inplace
,
gpu_outer
,
GpuConv
)
gpu_gemm_inplace
,
gpu_gemm_no_inplace
,
gpu_outer
,
GpuConv
)
from
theano.sandbox.cuda.blas
import
gpu_gemv_inplace
from
theano.sandbox.cuda.blas
import
gpu_gemv_no_inplace
from
theano.sandbox.cuda.blas
import
gpu_ger_inplace
from
theano.sandbox.cuda.blas
import
gpu_ger_no_inplace
from
theano.sandbox.cuda.blas
import
(
GpuDownsampleFactorMax
,
from
theano.sandbox.cuda.blas
import
(
GpuDownsampleFactorMax
,
GpuDownsampleFactorMaxGrad
)
GpuDownsampleFactorMaxGrad
)
from
theano.sandbox.cuda.nnet
import
(
from
theano.sandbox.cuda.nnet
import
(
...
@@ -375,47 +379,85 @@ def local_gpu_dot22scalar(node):
...
@@ -375,47 +379,85 @@ def local_gpu_dot22scalar(node):
@register_opt
()
@register_opt
()
@local_optimizer
([])
@local_optimizer
([])
def
local_gpu_gemv
_as_gemm
(
node
):
def
local_gpu_gemv
(
node
):
"""
"""
gpu_from_host(gemv) -> gpu_gemv(gpu_from_host)
gpu_from_host(gemv) -> gpu_gemv(gpu_from_host)
gem
m
(host_from_gpu) -> host_from_gpu(gpu_gemv)
gem
v
(host_from_gpu) -> host_from_gpu(gpu_gemv)
This optimization solves the vector-matrix multiplication issue by
transforming the vector into a matrix, apply gpudot22 and reshaping
the output.
A more suitable solution would be to use the right cublas call
"""
"""
gemvs
=
{
tensor
.
blas
.
gemv_inplace
:
gpu_gemm_inplace
,
gemvs
=
{
tensor
.
blas
.
gemv_no_inplace
:
gpu_gemm_no_inplace
}
tensor
.
blas
.
gemv_inplace
:
gpu_gemv_no_inplace
,
tensor
.
blas
.
gemv_no_inplace
:
gpu_gemv_no_inplace
,
tensor
.
blas_c
.
CGemv
(
inplace
=
True
):
gpu_gemv_no_inplace
,
tensor
.
blas_c
.
CGemv
(
inplace
=
False
):
gpu_gemv_no_inplace
,
}
if
node
.
op
==
gpu_from_host
:
if
node
.
op
==
gpu_from_host
:
host_input
=
node
.
inputs
[
0
]
host_input
=
node
.
inputs
[
0
]
if
host_input
.
owner
and
host_input
.
owner
.
op
in
gemvs
:
if
host_input
.
owner
and
host_input
.
owner
.
op
in
gemvs
:
op
=
host_input
.
owner
.
op
op
=
host_input
.
owner
.
op
z
,
a
,
x
,
y
,
b
=
host_input
.
owner
.
inputs
z
,
a
,
x
,
y
,
b
=
host_input
.
owner
.
inputs
return
[
return
[
gemvs
[
op
](
GpuDimShuffle
((
False
,
True
),[
0
])(
gemvs
[
op
](
gpu_from_host
(
z
)
GpuDimShuffle
((
False
,),[
0
,
'x'
])(
gpu_from_host
(
z
))
,
a
,
a
,
gpu_from_host
(
x
)
,
gpu_from_host
(
x
)
,
GpuDimShuffle
((
False
,),[
0
,
'x'
])(
gpu_from_host
(
y
)
)
,
gpu_from_host
(
y
)
,
b
)
)
]
,
b
)]
if
node
.
op
in
gemvs
:
if
node
.
op
in
gemvs
:
z
,
a
,
x
,
y
,
b
=
node
.
inputs
z
,
a
,
x
,
y
,
b
=
node
.
inputs
x_on_gpu
=
(
x
.
owner
and
x
.
owner
.
op
==
host_from_gpu
)
x_on_gpu
=
(
x
.
owner
and
x
.
owner
.
op
==
host_from_gpu
)
y_on_gpu
=
(
y
.
owner
and
y
.
owner
.
op
==
host_from_gpu
)
y_on_gpu
=
(
y
.
owner
and
y
.
owner
.
op
==
host_from_gpu
)
z_on_gpu
=
(
z
.
owner
and
z
.
owner
.
op
==
host_from_gpu
)
z_on_gpu
=
(
z
.
owner
and
z
.
owner
.
op
==
host_from_gpu
)
if
x_on_gpu
or
y_on_gpu
or
z_on_gpu
:
if
x_on_gpu
or
y_on_gpu
or
z_on_gpu
:
return
[
host_from_gpu
(
GpuDimShuffle
((
False
,
True
),[
0
])(
return
[
host_from_gpu
(
gemvs
[
node
.
op
](
gemvs
[
node
.
op
](
GpuDimShuffle
((
False
,),[
0
,
'x'
])(
gpu_from_host
(
z
)
)
gpu_from_host
(
z
)
,
a
,
a
,
gpu_from_host
(
x
)
,
gpu_from_host
(
x
)
,
GpuDimShuffle
((
False
,),[
0
,
'x'
])(
gpu_from_host
(
y
)
)
,
gpu_from_host
(
y
)
,
b
))
)
]
,
b
))]
return
False
return
False
@register_opt
()
@local_optimizer
([])
def
local_gpu_ger
(
node
):
"""
gpu_from_host(ger) -> gpu_ger(gpu_from_host)
ger(host_from_gpu) -> host_from_gpu(gpu_ger)
"""
gers
=
{
tensor
.
blas_c
.
CGer
(
destructive
=
True
):
gpu_ger_no_inplace
,
tensor
.
blas_c
.
CGer
(
destructive
=
False
):
gpu_ger_no_inplace
,
tensor
.
blas
.
Ger
(
destructive
=
True
):
gpu_ger_no_inplace
,
tensor
.
blas
.
Ger
(
destructive
=
False
):
gpu_ger_no_inplace
,
}
if
node
.
op
==
gpu_from_host
:
host_input
=
node
.
inputs
[
0
]
if
host_input
.
owner
and
host_input
.
owner
.
op
in
gers
:
op
=
host_input
.
owner
.
op
z
,
a
,
x
,
y
=
host_input
.
owner
.
inputs
return
[
gers
[
op
](
gpu_from_host
(
z
)
,
a
,
gpu_from_host
(
x
)
,
gpu_from_host
(
y
)
)]
if
node
.
op
in
gers
:
z
,
a
,
x
,
y
=
node
.
inputs
x_on_gpu
=
(
x
.
owner
and
x
.
owner
.
op
==
host_from_gpu
)
y_on_gpu
=
(
y
.
owner
and
y
.
owner
.
op
==
host_from_gpu
)
z_on_gpu
=
(
z
.
owner
and
z
.
owner
.
op
==
host_from_gpu
)
if
x_on_gpu
or
y_on_gpu
or
z_on_gpu
:
return
[
host_from_gpu
(
gers
[
node
.
op
](
gpu_from_host
(
z
)
,
a
,
gpu_from_host
(
x
)
,
gpu_from_host
(
y
)
))]
return
False
@register_opt
()
@register_opt
()
@local_optimizer
([])
@local_optimizer
([])
def
local_gpu_gemm
(
node
):
def
local_gpu_gemm
(
node
):
...
@@ -424,7 +466,8 @@ def local_gpu_gemm(node):
...
@@ -424,7 +466,8 @@ def local_gpu_gemm(node):
gemm(host_from_gpu) -> host_from_gpu(gpu_gemm)
gemm(host_from_gpu) -> host_from_gpu(gpu_gemm)
"""
"""
gemms
=
{
tensor
.
blas
.
gemm_inplace
:
gpu_gemm_inplace
,
gemms
=
{
#tensor.blas.gemm_inplace: gpu_gemm_inplace,
tensor
.
blas
.
gemm_no_inplace
:
gpu_gemm_no_inplace
}
tensor
.
blas
.
gemm_no_inplace
:
gpu_gemm_no_inplace
}
if
node
.
op
==
gpu_from_host
:
if
node
.
op
==
gpu_from_host
:
host_input
=
node
.
inputs
[
0
]
host_input
=
node
.
inputs
[
0
]
...
@@ -924,15 +967,32 @@ def local_inplace_gemm(node):
...
@@ -924,15 +967,32 @@ def local_inplace_gemm(node):
if
node
.
op
==
gpu_gemm_no_inplace
:
if
node
.
op
==
gpu_gemm_no_inplace
:
return
[
gpu_gemm_inplace
(
*
node
.
inputs
)]
return
[
gpu_gemm_inplace
(
*
node
.
inputs
)]
@local_optimizer
([
gpu_gemv_no_inplace
])
def
local_inplace_gemv
(
node
):
if
node
.
op
==
gpu_gemv_no_inplace
:
return
[
gpu_gemv_inplace
(
*
node
.
inputs
)]
@local_optimizer
([
gpu_gemm_no_inplace
])
def
local_inplace_ger
(
node
):
if
node
.
op
==
gpu_ger_no_inplace
:
return
[
gpu_ger_inplace
(
*
node
.
inputs
)]
# After destroyhandler is in but before we try to make elemwise things inplace
# After destroyhandler is in but before we try to make elemwise things inplace
# Try to make gpu gemm inplace
# Try to make gpu gemm inplace
# Also, need to make the gemm optimisation(step 70) happen before the fusion of
# Also, need to make the gemm optimisation(step 70) happen before the fusion of
# elemwise(step 71)
# elemwise(step 71)
optdb
.
register
(
'InplaceGpuBlasOpt'
,
optdb
.
register
(
'InplaceGpuBlasOpt'
,
EquilibriumOptimizer
([
local_inplace_gemm
],
failure_callback
=
EquilibriumOptimizer
.
warn_inplace
,
EquilibriumOptimizer
([
local_inplace_gemm
,
local_inplace_gemv
,
local_inplace_ger
,
],
failure_callback
=
EquilibriumOptimizer
.
warn_inplace
,
max_use_ratio
=
5
),
max_use_ratio
=
5
),
70.0
,
'fast_run'
,
'inplace'
,
'gpu'
)
70.0
,
'fast_run'
,
'inplace'
,
'gpu'
)
def
get_device_type_sizes
():
def
get_device_type_sizes
():
"""
"""
:return:(gpu ptr size, cpu ptr size, int sizes(gpu and cpu))
:return:(gpu ptr size, cpu ptr size, int sizes(gpu and cpu))
...
...
theano/sandbox/cuda/tests/test_blas.py
浏览文件 @
f2743791
from
unittest
import
TestCase
from
theano.compile.pfunc
import
pfunc
from
theano.compile.pfunc
import
pfunc
from
theano
import
tensor
from
theano
import
tensor
from
theano.tests
import
unittest_tools
from
theano.tests
import
unittest_tools
...
@@ -15,6 +17,9 @@ import theano.sandbox.cuda as tcn
...
@@ -15,6 +17,9 @@ import theano.sandbox.cuda as tcn
from
theano.tensor.signal.downsample
import
DownsampleFactorMax
,
DownsampleFactorMaxGrad
from
theano.tensor.signal.downsample
import
DownsampleFactorMax
,
DownsampleFactorMaxGrad
import
theano.compile.mode
import
theano.compile.mode
from
theano.tensor.tests.test_blas
import
BaseGemv
,
TestGer
from
theano.sandbox.cuda.blas
import
gpu_gemv_no_inplace
,
gpu_gemv_inplace
from
theano.sandbox.cuda.blas
import
gpu_ger_inplace
,
gpu_ger_no_inplace
if
theano
.
config
.
mode
==
'FAST_COMPILE'
:
if
theano
.
config
.
mode
==
'FAST_COMPILE'
:
...
@@ -243,3 +248,39 @@ def test_downsample():
...
@@ -243,3 +248,39 @@ def test_downsample():
#We already check that the gpu version return the same value as the gpu version
#We already check that the gpu version return the same value as the gpu version
#for GpuDownsampleFactorMaxGrad. So no need to call verify_grad here.
#for GpuDownsampleFactorMaxGrad. So no need to call verify_grad here.
class
TestGpuGemv
(
TestCase
,
BaseGemv
,
unittest_tools
.
TestOptimizationMixin
):
mode
=
mode_with_gpu
dtype
=
'float32'
# As all input are transfered to the gpu, this allow to make all
# the gemv inplace.
gemv
=
gpu_gemv_inplace
gemv_inplace
=
gpu_gemv_inplace
class
TestGpuGer
(
TestGer
):
def
setUp
(
self
):
self
.
mode
=
mode_with_gpu
dtype
=
self
.
dtype
=
'float32'
# optimization isn't dtype-dependent
self
.
A
=
tensor
.
tensor
(
dtype
=
dtype
,
broadcastable
=
(
False
,
False
))
self
.
a
=
tensor
.
tensor
(
dtype
=
dtype
,
broadcastable
=
())
self
.
x
=
tensor
.
tensor
(
dtype
=
dtype
,
broadcastable
=
(
False
,))
self
.
y
=
tensor
.
tensor
(
dtype
=
dtype
,
broadcastable
=
(
False
,))
self
.
ger
=
gpu_ger_no_inplace
self
.
ger_destructive
=
gpu_ger_inplace
self
.
gemm
=
tcn
.
blas
.
gpu_gemm_no_inplace
# data on the gpu make the op always inplace
self
.
ger
=
gpu_ger_inplace
self
.
gemm
=
tcn
.
blas
.
gpu_gemm_inplace
class
TestGpuGer_OpContract
(
TestCase
,
unittest_tools
.
T_OpContractMixin
):
def
setUp
(
self
):
self
.
ops
=
[
gpu_ger_no_inplace
,
gpu_ger_inplace
]
def
clone
(
self
,
op
):
return
tcn
.
blas
.
GpuGer
(
op
.
inplace
)
theano/tensor/__init__.py
浏览文件 @
f2743791
...
@@ -10,6 +10,7 @@ import opt
...
@@ -10,6 +10,7 @@ import opt
import
opt_uncanonicalize
import
opt_uncanonicalize
import
blas
import
blas
import
blas_scipy
import
blas_scipy
import
blas_c
import
xlogx
import
xlogx
import
raw_random
import
raw_random
...
...
theano/tensor/basic.py
浏览文件 @
f2743791
...
@@ -2540,6 +2540,41 @@ class Alloc(gof.Op):
...
@@ -2540,6 +2540,41 @@ class Alloc(gof.Op):
#reuse the allocated memory.
#reuse the allocated memory.
out
[
0
][
...
]
=
v
# broadcast v to fill us up
out
[
0
][
...
]
=
v
# broadcast v to fill us up
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
# TODO: use the elemwise code generator here
if
node
.
inputs
[
0
]
.
ndim
==
0
:
# filling with a scalar is a common use of alloc
# that we can implement relatively easily
vv
=
inp
[
0
]
zz
,
=
out
fail
=
sub
[
'fail'
]
if
node
.
outputs
[
0
]
.
ndim
==
1
:
N0
=
inp
[
1
]
return
"""
npy_intp N0 = ((dtype_
%(N0)
s*)
%(N0)
s->data)[0];
dtype_
%(vv)
s vv;
dtype_
%(zz)
s* zz;
if ((NULL ==
%(zz)
s) || (
%(zz)
s->dimensions[0] != N0))
{
if (
%(zz)
s) Py_XDECREF(
%(zz)
s);
%(zz)
s = (PyArrayObject*)PyArray_SimpleNew(1,
&N0, type_num_
%(vv)
s);
if(!
%(zz)
s) {
PyErr_SetString(PyExc_MemoryError, "alloc failed");
%(fail)
s
}
}
vv = ((dtype_
%(vv)
s*)
%(vv)
s->data)[0];
zz = ((dtype_
%(zz)
s*)
%(zz)
s->data);
assert (
%(zz)
s->strides[0] == sizeof(dtype_
%(zz)
s));
for (int i = 0; i < N0; ++i)
{
zz[i] = vv;
}
"""
%
locals
()
# else pretend this never happened
return
super
(
Alloc
,
self
)
.
c_code
(
node
,
name
,
inp
,
out
,
sub
)
def
infer_shape
(
self
,
node
,
input_shapes
):
def
infer_shape
(
self
,
node
,
input_shapes
):
return
[
node
.
inputs
[
1
:]]
return
[
node
.
inputs
[
1
:]]
...
...
theano/tensor/blas.py
浏览文件 @
f2743791
...
@@ -6,18 +6,26 @@ Learn more about BLAS here:
...
@@ -6,18 +6,26 @@ Learn more about BLAS here:
The standard BLAS libraries implement what is called "legacy BLAS" in that
The standard BLAS libraries implement what is called "legacy BLAS" in that
document.
document.
This documentation section describes Theano's BLAS optimization
This documentation describes Theano's BLAS optimization pipeline.
pipeline.
Where there is a discrepancy between how things do work and how they *should*
Where there is a discrepancy between how things do work and how they *should*
work, both aspects should be documented. It helps keep a broader agenda in view
work, both aspects should be documented.
even while fixing little bugs etc. from day to day.
There are four kinds of BLAS Ops in Theano:
- Python implementations (this file)
- SciPy-based (blas_scipy)
- C-based (blas_c)
- CUDA-based (theano.sandbox.cuda.blas)
:note: Unfortunately (because it's confusing) this file currently contains Ops
that contain both Python and C versions. I think it would be better to
move the C implementations to blas_c so that this file is pure Python.
-JB
Ops
Ops
===
===
There are two BLAS calls wrapped in this module: GEMM and GEMV.
GEMM: Dot22, Dot22Scalar, GemmRelated, Gemm
GEMM: Dot22, Dot22Scalar, GemmRelated, Gemm
-------------------------------------------
-------------------------------------------
...
@@ -43,18 +51,19 @@ GEMV: Gemv
...
@@ -43,18 +51,19 @@ GEMV: Gemv
----------
----------
The BLAS GEMV operation implements Z <- a X Y + b Z,
The BLAS GEMV operation implements Z <- a X Y + b Z,
where Z is a matrix, Y, and Z are vectors, and a and b are scalars.
where X is a matrix, Y, and Z are vectors, and a and b are scalars.
Gemv implements the GEMV call in all its generality.
GER: Ger
--------
The BLAS GER operation implements Z <- a X' Y + Z,
where X and Y are vectors, and matrix Z gets a rank-1 update.
Other Notable BLAS-related Ops
Other Notable BLAS-related Ops
------------------------------
------------------------------
GpuOuter is currently a wrapper around GER. GER is a useful special case of
GEMM, and in the future it would be good to have a GER Op. With a GER Op here,
the GpuOuter could be turned into a GpuGER.
SYRK is another useful special case of GEMM. Particularly SYRK preserves
SYRK is another useful special case of GEMM. Particularly SYRK preserves
symmetry in the matrix that it updates. See how the linear-algebra module uses
symmetry in the matrix that it updates. See how the linear-algebra module uses
symmetry hints before implementing this Op, so that this Op is compatible with
symmetry hints before implementing this Op, so that this Op is compatible with
...
@@ -64,14 +73,19 @@ that system.
...
@@ -64,14 +73,19 @@ that system.
Optimizations
Optimizations
=============
=============
The current optimization pipeline is not exactly clear to me. Instead I will
The optimization pipeline works something like this:
describe how it should work.
The high level pipeline is:
1. identify dot22 from dot
1. identify dot22 from dot
2. identify gemm from dot22
2. identify gemm from dot22
3. identify dot22scalar from dot22 that are not gemm
3. identify dot22scalar from dot22 that are not gemm
4. specialize gemm to gemv where applicable
4. specialize gemm to gemv where applicable
5. specialize gemm to ger where applicable
6. specialize dot22 -> gemv or ger where applicable
:note: GEMM is the most canonical BLAS signature that we deal with so far, it
would be good to turn most things into GEMM (dot, inner, outer, dot22,
dot22scalar), and then to specialize from gemm to the various other L2 and
L3 operations.
Identify Dot22
Identify Dot22
--------------
--------------
...
@@ -161,9 +175,9 @@ class Gemv(Op):
...
@@ -161,9 +175,9 @@ class Gemv(Op):
def
__str__
(
self
):
def
__str__
(
self
):
if
self
.
inplace
:
if
self
.
inplace
:
return
'
Gemv{inplace}'
return
'
%
s{inplace}'
%
self
.
__class__
.
__name__
else
:
else
:
return
'
Gemv{no_inplace}'
return
'
%
s{no_inplace}'
%
self
.
__class__
.
__name__
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
^
hash
(
self
.
inplace
)
return
hash
(
type
(
self
))
^
hash
(
self
.
inplace
)
...
@@ -268,31 +282,19 @@ class Ger(Op):
...
@@ -268,31 +282,19 @@ class Ger(Op):
raise
TypeError
(
'only float and complex types supported'
,
x
.
dtype
)
raise
TypeError
(
'only float and complex types supported'
,
x
.
dtype
)
return
Apply
(
self
,
[
A
,
alpha
,
x
,
y
],
[
A
.
type
()])
return
Apply
(
self
,
[
A
,
alpha
,
x
,
y
],
[
A
.
type
()])
def
make_thunk
(
self
,
node
,
storage_map
,
compute_map
,
no_recycling
):
def
perform
(
self
,
node
,
inp
,
out
):
node_input_storage
=
[
storage_map
[
r
]
for
r
in
node
.
inputs
]
cA
,
calpha
,
cx
,
cy
=
inp
node_output_storage
=
[
storage_map
[
r
]
for
r
in
node
.
outputs
]
cZ
,
=
out
if
self
.
destructive
:
# get vars for containers
A
=
cA
cA
,
calpha
,
cx
,
cy
=
node_input_storage
else
:
cZ
,
=
node_output_storage
A
=
cA
.
copy
()
if
calpha
!=
1
:
A
+=
calpha
*
numpy
.
outer
(
cx
,
cy
)
else
:
A
+=
numpy
.
outer
(
cx
,
cy
)
cZ
[
0
]
=
A
def
rval
():
if
self
.
destructive
:
A
=
cA
[
0
]
else
:
A
=
cA
[
0
]
.
copy
()
if
calpha
[
0
]
!=
1
:
A
+=
calpha
[
0
]
*
numpy
.
outer
(
cx
[
0
],
cy
[
0
])
else
:
A
+=
numpy
.
outer
(
cx
[
0
],
cy
[
0
])
cZ
[
0
]
=
A
#TODO: If this is currently an unofficial part of the thunk API,
# then maybe it should be documented and made official?
rval
.
inputs
=
node_input_storage
rval
.
outputs
=
node_output_storage
rval
.
lazy
=
False
return
rval
ger
=
Ger
(
destructive
=
False
)
ger
=
Ger
(
destructive
=
False
)
ger_destructive
=
Ger
(
destructive
=
True
)
ger_destructive
=
Ger
(
destructive
=
True
)
...
@@ -1148,7 +1150,7 @@ def _gemm_from_factored_list(lst):
...
@@ -1148,7 +1150,7 @@ def _gemm_from_factored_list(lst):
# Try every pair in the sM_list, trying to turn it into a gemm operation
# Try every pair in the sM_list, trying to turn it into a gemm operation
for
i
in
xrange
(
len
(
lst
)
-
1
):
for
i
in
xrange
(
len
(
lst
)
-
1
):
s_i
,
M_i
=
lst
[
i
]
s_i
,
M_i
=
lst
[
i
]
for
j
in
xrange
(
i
+
1
,
len
(
lst
)):
for
j
in
xrange
(
i
+
1
,
len
(
lst
)):
s_j
,
M_j
=
lst
[
j
]
s_j
,
M_j
=
lst
[
j
]
...
@@ -1400,9 +1402,7 @@ def local_gemm_to_ger(node):
...
@@ -1400,9 +1402,7 @@ def local_gemm_to_ger(node):
rval
=
ger
(
z
,
a
,
xv
,
yv
)
rval
=
ger
(
z
,
a
,
xv
,
yv
)
return
[
rval
]
return
[
rval
]
elif
bval
==
0
:
# GER on zeros_like should be faster than GEMM
elif
bval
==
0
:
# GER on zeros_like should be faster than GEMM
zeros
=
T
.
alloc
(
zeros
=
T
.
zeros
([
x
.
shape
[
0
],
y
.
shape
[
1
]],
x
.
dtype
)
numpy
.
asarray
(
0
,
dtype
=
x
.
dtype
),
x
.
shape
[
0
],
y
.
shape
[
1
])
rval
=
ger
(
zeros
,
a
,
xv
,
yv
)
rval
=
ger
(
zeros
,
a
,
xv
,
yv
)
return
[
rval
]
return
[
rval
]
else
:
else
:
...
@@ -1414,20 +1414,43 @@ def local_gemm_to_ger(node):
...
@@ -1414,20 +1414,43 @@ def local_gemm_to_ger(node):
#TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
#TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
# working
# working
@local_optimizer
([
_dot22
])
@local_optimizer
([
_dot22
])
def
local_dot22_to_ger
(
node
):
def
local_dot22_to_ger
_or_gemv
(
node
):
"""
GEMM
computing an outer-product -> GER
"""
dot22
computing an outer-product -> GER
"""
"""
if
node
.
op
==
_dot22
:
if
node
.
op
==
_dot22
:
x
,
y
=
node
.
inputs
x
,
y
=
node
.
inputs
if
x
.
broadcastable
[
1
]
and
y
.
broadcastable
[
0
]:
xb
=
x
.
broadcastable
yb
=
y
.
broadcastable
one
=
T
.
as_tensor_variable
(
numpy
.
asarray
(
1
,
dtype
=
x
.
dtype
))
zero
=
T
.
as_tensor_variable
(
numpy
.
asarray
(
0
,
dtype
=
x
.
dtype
))
if
xb
[
1
]
and
yb
[
0
]:
# x and y are both vectors so this might qualifies for a GER
# x and y are both vectors so this might qualifies for a GER
xv
=
x
.
dimshuffle
(
0
)
xv
=
x
.
dimshuffle
(
0
)
yv
=
y
.
dimshuffle
(
1
)
yv
=
y
.
dimshuffle
(
1
)
one
=
T
.
as_tensor_variable
(
numpy
.
asarray
(
1
,
dtype
=
x
.
dtype
))
zeros
=
T
.
zeros
([
x
.
shape
[
0
],
y
.
shape
[
1
]],
dtype
=
x
.
dtype
)
zeros
=
T
.
alloc
(
numpy
.
asarray
(
0
,
dtype
=
x
.
dtype
),
x
.
shape
[
0
],
y
.
shape
[
1
])
rval
=
ger
(
zeros
,
one
,
xv
,
yv
)
rval
=
ger
(
zeros
,
one
,
xv
,
yv
)
return
[
rval
]
return
[
rval
]
if
xb
[
0
]
and
yb
[
1
]:
# x and y are both vectors so this qualifies for a sdot / ddot
# TODO: Theano doesn't have a sdot, but gemv is better than _dot22
xv
=
x
.
dimshuffle
(
1
)
zeros
=
T
.
zeros
([
1
],
x
.
dtype
)
rval
=
gemv_no_inplace
(
zeros
,
one
,
y
.
T
,
xv
,
one
)
return
[
rval
.
dimshuffle
(
'x'
,
0
)]
if
xb
[
0
]
and
not
yb
[
0
]
and
not
yb
[
1
]:
# x is vector, y is matrix so try gemv
xv
=
x
.
dimshuffle
(
1
)
zeros
=
T
.
zeros
([
y
.
shape
[
1
]],
x
.
dtype
)
rval
=
gemv_no_inplace
(
zeros
,
one
,
y
.
T
,
xv
,
one
)
return
[
rval
.
dimshuffle
(
'x'
,
0
)]
if
not
xb
[
0
]
and
not
xb
[
1
]
and
yb
[
1
]:
# x is matrix, y is vector, try gemv
yv
=
y
.
dimshuffle
(
0
)
zeros
=
T
.
zeros
([
x
.
shape
[
0
]],
dtype
=
x
.
dtype
)
rval
=
gemv_no_inplace
(
zeros
,
one
,
x
,
yv
,
one
)
return
[
rval
.
dimshuffle
(
0
,
'x'
)]
#################################
#################################
#
#
...
@@ -1445,14 +1468,14 @@ optdb.register('BlasOpt', blas_optdb, 1.7, 'fast_run')
...
@@ -1445,14 +1468,14 @@ optdb.register('BlasOpt', blas_optdb, 1.7, 'fast_run')
blas_optdb
.
register
(
'local_dot_to_dot22'
,
blas_optdb
.
register
(
'local_dot_to_dot22'
,
EquilibriumOptimizer
([
local_dot_to_dot22
],
max_use_ratio
=
5
),
EquilibriumOptimizer
([
local_dot_to_dot22
],
max_use_ratio
=
5
),
0
,
'fast_run'
)
0
,
'fast_run'
)
blas_optdb
.
register
(
'
local_dot_to_gemm
'
,
blas_optdb
.
register
(
'
gemm_optimizer
'
,
GemmOptimizer
(),
GemmOptimizer
(),
10
,
'fast_run'
)
10
,
'fast_run'
)
blas_optdb
.
register
(
'local_gemm_to_gemv'
,
blas_optdb
.
register
(
'local_gemm_to_gemv'
,
EquilibriumOptimizer
([
EquilibriumOptimizer
([
local_gemm_to_gemv
,
local_gemm_to_gemv
,
local_gemm_to_ger
,
local_gemm_to_ger
,
local_dot22_to_ger
,
local_dot22_to_ger
_or_gemv
,
local_dimshuffle_lift
],
local_dimshuffle_lift
],
max_use_ratio
=
5
),
max_use_ratio
=
5
),
15
,
'fast_run'
)
15
,
'fast_run'
)
...
...
theano/tensor/blas_c.py
0 → 100644
浏览文件 @
f2743791
from
theano.gof
import
Op
from
blas
import
ldflags
,
blas_header_text
from
blas
import
blas_optdb
,
optdb
,
local_optimizer
,
EquilibriumOptimizer
from
blas
import
Ger
,
ger
,
ger_destructive
from
blas
import
Gemv
,
gemv_inplace
,
gemv_no_inplace
class
BaseBLAS
(
object
):
def
c_libraries
(
self
):
return
ldflags
()
def
c_compile_args
(
self
):
return
ldflags
(
libs
=
False
,
flags
=
True
)
def
c_lib_dirs
(
self
):
return
ldflags
(
libs
=
False
,
libs_dir
=
True
)
def
c_header_dirs
(
self
):
return
ldflags
(
libs
=
False
,
include_dir
=
True
)
def
c_support_code
(
self
):
return
blas_header_text
()
####### ####### #######
# GER
####### ####### #######
def
ger_c_code
(
A
,
a
,
x
,
y
,
Z
,
destructive
,
fail
):
return
"""
int elemsize ;
if (
%(A)
s->nd != 2)
{PyErr_SetString(PyExc_NotImplementedError, "rank(A) != 2");
%(fail)
s;}
if (
%(x)
s->nd != 1)
{PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 1");
%(fail)
s;}
if (
%(y)
s->nd != 1)
{PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 1");
%(fail)
s;}
if (
%(a)
s->nd != 0)
{PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 0");
%(fail)
s;}
if (
%(A)
s->descr->type_num !=
%(x)
s->descr->type_num)
{ PyErr_SetString(PyExc_TypeError, "A vs. x");
%(fail)
s; }
if (
%(A)
s->descr->type_num !=
%(y)
s->descr->type_num)
{ PyErr_SetString(PyExc_TypeError, "A vs. y");
%(fail)
s; }
if (
%(A)
s->dimensions[0] !=
%(x)
s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "A.shape[0] != x.shape[0]");
%(fail)
s;}
if (
%(A)
s->dimensions[1] !=
%(y)
s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "A.shape[1] != y.shape[0]");
%(fail)
s;}
if (
%(A)
s->descr->type_num == PyArray_DOUBLE) { elemsize = 8; }
else if (
%(A)
s->descr->type_num == PyArray_FLOAT) { elemsize = 4;}
else {PyErr_SetString(PyExc_NotImplementedError, "complex CGer");
%(fail)
s;}
// copy A if !self.destructive or A is fully strided
if (!
%(destructive)
s
|| ((
%(A)
s->strides[0] != elemsize)
&&
(
%(A)
s->strides[1] != elemsize)))
{
npy_intp dims[2];
dims[0] =
%(A)
s->dimensions[0];
dims[1] =
%(A)
s->dimensions[1];
if ((NULL ==
%(Z)
s)
|| (
%(Z)
s->dimensions[0] !=
%(A)
s->dimensions[0])
|| (
%(Z)
s->dimensions[1] !=
%(A)
s->dimensions[1]))
{
if (
%(Z)
s) Py_XDECREF(
%(Z)
s);
%(Z)
s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(
%(A)
s));
if(!
%(Z)
s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc ger output");
%(fail)
s
}
}
assert (
%(Z)
s !=
%(A)
s);
if (
%(Z)
s->descr->type_num == PyArray_FLOAT)
{
float * zoutdata = (float*)
%(Z)
s->data;
const float * zdata = (float*)
%(A)
s->data;
int Ai =
%(A)
s->strides[0]/sizeof(float);
int Aj =
%(A)
s->strides[1]/sizeof(float);
int Zi =
%(Z)
s->strides[0]/sizeof(float);
int Zj =
%(Z)
s->strides[1]/sizeof(float);
for (int i = 0; i < dims[0]; ++i)
{
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] = zdata[Ai*i+Aj*j];
}
}
}
else if (
%(Z)
s->descr->type_num == PyArray_DOUBLE)
{
double * zoutdata = (double*)
%(Z)
s->data;
const double * zdata = (double*)
%(A)
s->data;
int Ai =
%(A)
s->strides[0]/sizeof(double);
int Aj =
%(A)
s->strides[1]/sizeof(double);
int Zi =
%(Z)
s->strides[0]/sizeof(double);
int Zj =
%(Z)
s->strides[1]/sizeof(double);
for (int i = 0; i < dims[0]; ++i)
{
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] = zdata[Ai*i+Aj*j];
}
}
}
else
{
PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype");
%(fail)
s
}
}
else
{
//fprintf(stderr, "USING A
\\
n");
if (
%(Z)
s !=
%(A)
s)
{
if (
%(Z)
s) { Py_DECREF(
%(Z)
s); }
%(Z)
s =
%(A)
s;
Py_INCREF(
%(Z)
s);
}
}
{
int Nz0 =
%(Z)
s->dimensions[0];
int Nz1 =
%(Z)
s->dimensions[1];
int Sz0 =
%(Z)
s->strides[0] / elemsize;
int Sz1 =
%(Z)
s->strides[1] / elemsize;
int Sx =
%(x)
s->strides[0] / elemsize;
int Sy =
%(y)
s->strides[0] / elemsize;
if (1)
{
if (
%(Z)
s->strides[0] == elemsize)
{
if (
%(Z)
s->descr->type_num == PyArray_FLOAT)
{
//fprintf(stderr, "A
\\
n");
float alpha = ((dtype_
%(a)
s*)
%(a)
s->data)[0];
sger_(&Nz0, &Nz1, &alpha,
(float*)(
%(x)
s->data), &Sx,
(float*)(
%(y)
s->data), &Sy,
(float*)(
%(Z)
s->data), &Sz1);
}
else if (
%(Z)
s->descr->type_num == PyArray_DOUBLE)
{
double alpha = ((dtype_
%(a)
s*)
%(a)
s->data)[0];
dger_(&Nz0, &Nz1, &alpha,
(double*)(
%(x)
s->data), &Sx,
(double*)(
%(y)
s->data), &Sy,
(double*)(
%(Z)
s->data), &Sz1);
}
else { assert(0); }
}
else if (
%(Z)
s->strides[1] == elemsize)
{
if (
%(Z)
s->descr->type_num == PyArray_FLOAT)
{
//fprintf(stderr, "B
%%
i
%%
i
%%
i
%%
i
\\
n", Nz0, Nz1, Sz0, Sz1);
float alpha = ((dtype_
%(a)
s*)(
%(a)
s->data))[0];
//fprintf(stderr, "alpha=
%%
f
\\
n", alpha);
//fprintf(stderr, "sx sy
%%
i
%%
i
\\
n", Sx, Sy);
sger_(&Nz1, &Nz0, &alpha,
(float*)(
%(y)
s->data), &Sy,
(float*)(
%(x)
s->data), &Sx,
(float*)(
%(Z)
s->data), &Sz0);
}
else if (
%(Z)
s->descr->type_num == PyArray_DOUBLE)
{
double alpha = ((dtype_
%(a)
s*)
%(a)
s->data)[0];
dger_(&Nz1, &Nz0, &alpha,
(double*)(
%(y)
s->data), &Sy,
(double*)(
%(x)
s->data), &Sx,
(double*)(
%(Z)
s->data), &Sz0);
}
else { assert(0); }
}
else { assert(0); }
}
}
"""
%
locals
()
class
CGer
(
BaseBLAS
,
Ger
):
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
A
,
a
,
x
,
y
=
inp
Z
,
=
out
code
=
ger_c_code
(
A
,
a
,
x
,
y
,
Z
,
destructive
=
int
(
self
.
destructive
),
fail
=
sub
[
'fail'
])
return
code
def
c_code_cache_version
(
self
):
return
(
2
,)
@local_optimizer
([
ger
,
ger_destructive
])
def
use_c_ger
(
node
):
if
node
.
op
==
ger
:
print
"inserting C_GER"
return
[
CGer
(
False
)(
*
node
.
inputs
)]
if
node
.
op
==
ger_destructive
:
print
"inserting dstruc C_GER"
return
[
CGer
(
True
)(
*
node
.
inputs
)]
@local_optimizer
([
CGer
(
False
)])
def
make_c_ger_destructive
(
node
):
if
node
.
op
==
CGer
(
False
):
print
"inserting destructive C_GER"
return
[
CGer
(
True
)(
*
node
.
inputs
)]
####### ####### #######
# GEMV
####### ####### #######
def
gemv_c_code
(
aa
,
xx
,
yy
,
zz
,
alpha
,
beta
,
destructive
,
fail
):
"""
zz <- beta * aa + alpha * dot(xx, yy)
where xx is a matrix, yy and aa are vectors (ergo zz is vector)
"""
return
"""
int elemsize ;
float fbeta;
double dbeta;
if (
%(aa)
s->nd != 1)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(aa) != 1");
%(fail)
s;}
if (
%(xx)
s->nd != 2)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(xx) != 2");
%(fail)
s;}
if (
%(yy)
s->nd != 1)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(yy) != 1");
%(fail)
s;}
if (
%(alpha)
s->nd != 0)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(alpha) != 0");
%(fail)
s;}
if (
%(beta)
s->nd != 0)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(beta) != 0");
%(fail)
s;}
if (
%(aa)
s->descr->type_num !=
%(xx)
s->descr->type_num)
{ PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. xx");
%(fail)
s; }
if (
%(aa)
s->descr->type_num !=
%(yy)
s->descr->type_num)
{ PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. yy");
%(fail)
s; }
if (
%(xx)
s->dimensions[0] !=
%(aa)
s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "A.shape[0] != x.shape[0]");
%(fail)
s;}
if (
%(xx)
s->dimensions[1] !=
%(yy)
s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "A.shape[1] != y.shape[0]");
%(fail)
s;}
if (
%(aa)
s->descr->type_num == PyArray_DOUBLE) { elemsize = 8; }
else if (
%(aa)
s->descr->type_num == PyArray_FLOAT) { elemsize = 4;}
else {PyErr_SetString(PyExc_NotImplementedError, "complex Gemv");
%(fail)
s;}
fbeta = dbeta = ((dtype_
%(beta)
s*)
%(beta)
s->data)[0];
// copy aa if not destructive
if (!
%(destructive)
s)
{
if ((NULL ==
%(zz)
s)
|| (
%(zz)
s->dimensions[0] !=
%(aa)
s->dimensions[0]))
{
if (
%(zz)
s) Py_XDECREF(
%(zz)
s);
%(zz)
s = (PyArrayObject*)PyArray_SimpleNew(1,
%(aa)
s->dimensions, type_num_
%(aa)
s);
if(!
%(zz)
s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc gemv output");
%(fail)
s
}
}
assert (
%(zz)
s !=
%(aa)
s);
if (dbeta != 0)
{
if (
%(zz)
s->descr->type_num == PyArray_FLOAT)
{
float * zoutdata = (float*)
%(zz)
s->data;
const float * zdata = (float*)
%(aa)
s->data;
int Ai =
%(aa)
s->strides[0]/sizeof(float);
int Zi =
%(zz)
s->strides[0]/sizeof(float);
for (int i = 0; i <
%(aa)
s->dimensions[0]; ++i)
{
zoutdata[Zi*i] = fbeta * zdata[Ai*i];
}
}
else if (
%(xx)
s->descr->type_num == PyArray_DOUBLE)
{
double * zoutdata = (double*)
%(zz)
s->data;
const double * zdata = (double*)
%(aa)
s->data;
int Ai =
%(aa)
s->strides[0]/sizeof(double);
int Zi =
%(zz)
s->strides[0]/sizeof(double);
for (int i = 0; i <
%(aa)
s->dimensions[0]; ++i)
{
zoutdata[Zi*i] = dbeta * zdata[Ai*i];
}
}
else
{
PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype");
%(fail)
s
}
fbeta = dbeta = 1.0;
}
}
else
{
//fprintf(stderr, "Gemv working in-place
\\
n");
if (
%(zz)
s !=
%(aa)
s)
{
if (
%(zz)
s) { Py_DECREF(
%(zz)
s); }
%(zz)
s =
%(aa)
s;
Py_INCREF(
%(zz)
s);
}
}
{
char TRANS = 'T';
char NOTRANS = 'N';
int Nx0 =
%(xx)
s->dimensions[0];
int Nx1 =
%(xx)
s->dimensions[1];
int Sx0 =
%(xx)
s->strides[0] / elemsize;
int Sx1 =
%(xx)
s->strides[1] / elemsize;
int Sz =
%(zz)
s->strides[0] / elemsize;
int Sy =
%(yy)
s->strides[0] / elemsize;
if (Nx0 * Nx1)
{
if (
%(xx)
s->strides[0] == elemsize)
{
if (
%(xx)
s->descr->type_num == PyArray_FLOAT)
{
//fprintf(stderr, "A
\\
n");
float alpha = ((dtype_
%(alpha)
s*)
%(alpha)
s->data)[0];
sgemv_(&NOTRANS, &Nx0, &Nx1,
&alpha,
(float*)(
%(xx)
s->data), &Sx1,
(float*)(
%(yy)
s->data), &Sy,
&fbeta,
(float*)(
%(zz)
s->data), &Sz);
}
else if (
%(xx)
s->descr->type_num == PyArray_DOUBLE)
{
double alpha = ((dtype_
%(alpha)
s*)
%(alpha)
s->data)[0];
dgemv_(&NOTRANS, &Nx0, &Nx1,
&alpha,
(double*)(
%(xx)
s->data), &Sx1,
(double*)(
%(yy)
s->data), &Sy,
&dbeta,
(double*)(
%(zz)
s->data), &Sz);
}
else
{
assert(0);
}
}
else if (
%(xx)
s->strides[1] == elemsize)
{
if (
%(xx)
s->descr->type_num == PyArray_FLOAT)
{
//fprintf(stderr, "B
%%
i
%%
i
%%
i
%%
i
\\
n", Nz0, Nz1, Sz0, Sz1);
float alpha = ((dtype_
%(alpha)
s*)
%(alpha)
s->data)[0];
//fprintf(stderr, "alpha=
%%
f
\\
n", alpha);
//fprintf(stderr, "sx sy
%%
i
%%
i
\\
n", Sx, Sy);
sgemv_(&TRANS, &Nx1, &Nx0,
&alpha,
(float*)(
%(xx)
s->data), &Sx0,
(float*)(
%(yy)
s->data), &Sy,
&fbeta,
(float*)(
%(zz)
s->data), &Sz);
}
else if (
%(xx)
s->descr->type_num == PyArray_DOUBLE)
{
double alpha = ((dtype_
%(alpha)
s*)
%(alpha)
s->data)[0];
dgemv_(&TRANS, &Nx1, &Nx0,
&alpha,
(double*)(
%(xx)
s->data), &Sx0,
(double*)(
%(yy)
s->data), &Sy,
&dbeta,
(double*)(
%(zz)
s->data), &Sz);
}
else
{
assert(0);
}
}
else
{
// if xx is strided in both directions, then just do the gemv with a
// pair of for loops.
assert (0);
}
}
else if (dbeta != 1.0)
{
// the matrix has at least one dim of length 0
// so we do this loop, which either iterates over 0 elements
// or else it does the right thing for length-0 x.
dtype_
%(zz)
s * zptr = (dtype_
%(zz)
s*)(
%(zz)
s->data);
for (int i = 0; i < Nx0; ++i)
{
zptr[i * Sz] *= dbeta;
}
}
}
"""
%
locals
()
class
CGemv
(
BaseBLAS
,
Gemv
):
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
aa
,
alpha
,
xx
,
yy
,
beta
=
inp
zz
,
=
out
code
=
gemv_c_code
(
aa
,
xx
,
yy
,
zz
,
alpha
,
beta
,
destructive
=
int
(
self
.
inplace
),
fail
=
sub
[
'fail'
])
return
code
def
c_code_cache_version
(
self
):
return
(
1
,)
@local_optimizer
([
gemv_inplace
,
gemv_no_inplace
])
def
use_c_gemv
(
node
):
if
node
.
op
==
gemv_no_inplace
:
return
[
CGemv
(
inplace
=
False
)(
*
node
.
inputs
)]
if
node
.
op
==
gemv_inplace
:
return
[
CGemv
(
inplace
=
True
)(
*
node
.
inputs
)]
@local_optimizer
([
CGemv
(
inplace
=
False
)])
def
make_c_gemv_destructive
(
node
):
if
node
.
op
==
CGemv
(
inplace
=
False
):
return
[
CGemv
(
inplace
=
True
)(
*
node
.
inputs
)]
####### ####### #######
# Optimizers
####### ####### #######
blas_optdb
.
register
(
'use_c_blas'
,
EquilibriumOptimizer
([
use_c_ger
,
use_c_gemv
,
],
max_use_ratio
=
5
),
20
,
'fast_run'
,
'c_blas'
)
#print 'BLAS_OPTDB'
#print blas_optdb
# this matches the InplaceBlasOpt defined in blas.py
optdb
.
register
(
'c_blas_destructive'
,
EquilibriumOptimizer
([
make_c_ger_destructive
,
make_c_gemv_destructive
,
],
failure_callback
=
EquilibriumOptimizer
.
warn_inplace
,
max_use_ratio
=
5
),
70.0
,
'fast_run'
,
'inplace'
,
'c_blas'
)
theano/tensor/blas_scipy.py
浏览文件 @
f2743791
...
@@ -23,10 +23,9 @@ try:
...
@@ -23,10 +23,9 @@ try:
numpy
.
dtype
(
'complex64'
):
scipy
.
linalg
.
blas
.
fblas
.
cgeru
,
numpy
.
dtype
(
'complex64'
):
scipy
.
linalg
.
blas
.
fblas
.
cgeru
,
numpy
.
dtype
(
'complex128'
):
scipy
.
linalg
.
blas
.
fblas
.
zgeru
,
numpy
.
dtype
(
'complex128'
):
scipy
.
linalg
.
blas
.
fblas
.
zgeru
,
}
}
optimizations_enabled
=
True
except
ImportError
,
e
:
except
ImportError
,
e
:
have_fblas
=
False
have_fblas
=
False
optimizations_enabled
=
False
class
ScipyGer
(
Ger
):
class
ScipyGer
(
Ger
):
...
@@ -62,13 +61,11 @@ class ScipyGer(Ger):
...
@@ -62,13 +61,11 @@ class ScipyGer(Ger):
@local_optimizer
([
ger
,
ger_destructive
])
@local_optimizer
([
ger
,
ger_destructive
])
def
use_scipy_ger
(
node
):
def
use_scipy_ger
(
node
):
if
not
optimizations_enabled
:
return
if
node
.
op
==
ger
:
if
node
.
op
==
ger
:
return
[
ScipyGer
(
False
)(
*
node
.
inputs
)]
return
[
ScipyGer
(
False
)(
*
node
.
inputs
)]
@local_optimizer
([
ScipyGer
(
False
)])
@local_optimizer
([
ScipyGer
(
False
)])
def
make_ger_destructive
(
node
):
def
make_ger_destructive
(
node
):
if
not
optimizations_enabled
:
return
if
node
.
op
==
ScipyGer
(
False
):
if
node
.
op
==
ScipyGer
(
False
):
return
[
ScipyGer
(
True
)(
*
node
.
inputs
)]
return
[
ScipyGer
(
True
)(
*
node
.
inputs
)]
...
...
theano/tensor/tests/test_blas.py
浏览文件 @
f2743791
...
@@ -29,12 +29,15 @@ from test_basic import (_approx_eq, as_tensor_variable, inplace_func,
...
@@ -29,12 +29,15 @@ from test_basic import (_approx_eq, as_tensor_variable, inplace_func,
#, constant, eval_outputs)
#, constant, eval_outputs)
import
theano.tensor.blas_scipy
import
theano.tensor.blas_scipy
if
config
.
mode
==
'FAST_COMPILE'
:
if
config
.
mode
==
'FAST_COMPILE'
:
mode_not_fast_compile
=
'FAST_RUN'
mode_not_fast_compile
=
'FAST_RUN'
else
:
else
:
mode_not_fast_compile
=
config
.
mode
mode_not_fast_compile
=
config
.
mode
mode_blas_opt
=
theano
.
compile
.
get_default_mode
()
.
including
(
'BlasOpt'
,
'specialize'
)
mode_blas_opt
=
theano
.
compile
.
get_default_mode
()
.
including
(
'BlasOpt'
,
'specialize'
,
'InplaceBlasOpt'
)
mode_blas_opt
=
mode_blas_opt
.
excluding
(
'c_blas'
)
def
test_dot_eq
():
def
test_dot_eq
():
assert
T
.
Dot
()
==
T
.
Dot
()
assert
T
.
Dot
()
==
T
.
Dot
()
...
@@ -550,21 +553,6 @@ def test_gemm_factor():
...
@@ -550,21 +553,6 @@ def test_gemm_factor():
assert
[(
1.0
,
X
),
(
1.0
,
Y
)]
==
_factor_canonicalized
([(
1.0
,
X
),
(
1.0
,
Y
)])
assert
[(
1.0
,
X
),
(
1.0
,
Y
)]
==
_factor_canonicalized
([(
1.0
,
X
),
(
1.0
,
Y
)])
assert
[(
2.0
,
X
)]
==
_factor_canonicalized
([(
1.0
,
X
),(
1.0
,
X
)])
assert
[(
2.0
,
X
)]
==
_factor_canonicalized
([(
1.0
,
X
),(
1.0
,
X
)])
def
test_upcasting_scalar_nogemv
():
# Test that the optimization does not crash when the scale has an incorrect
# dtype, and forces upcasting of the result
v
=
T
.
fvector
(
'v'
)
w
=
T
.
fmatrix
(
'w'
)
t
=
T
.
fvector
(
't'
)
alpha
=
T
.
dscalar
(
'a'
)
rval
=
T
.
dot
(
w
,
v
)
*
alpha
+
t
f
=
theano
.
function
([
w
,
v
,
t
,
alpha
],
rval
)
t
=
f
.
maker
.
env
.
toposort
()
assert
numpy
.
sum
([
isinstance
(
n
.
op
,
Gemv
)
for
n
in
t
])
==
0
theano
.
printing
.
debugprint
(
f
,
print_type
=
True
)
def
test_upcasting_scalar_nogemm
():
def
test_upcasting_scalar_nogemm
():
# Test that the optimization does not crash when the scale has an incorrect
# Test that the optimization does not crash when the scale has an incorrect
# dtype, and forces upcasting of the result
# dtype, and forces upcasting of the result
...
@@ -862,22 +850,35 @@ def test_dot_w_self():
...
@@ -862,22 +850,35 @@ def test_dot_w_self():
## Tests for Gemv
## Tests for Gemv
###############################################################################
###############################################################################
class
TestGemv
(
TestCase
):
class
TestGemv
(
TestCase
,
unittest_tools
.
TestOptimizationMixin
):
def
test_dot_vv
(
self
):
''' Currently we generate a gemv for that case'''
rng
=
numpy
.
random
.
RandomState
(
unittest_tools
.
fetch_seed
())
v
=
theano
.
shared
(
numpy
.
array
(
rng
.
uniform
(
size
=
(
2
,)),
dtype
=
'float32'
))
w
=
theano
.
shared
(
numpy
.
array
(
rng
.
uniform
(
size
=
(
2
,)),
dtype
=
'float32'
))
f
=
theano
.
function
([],
theano
.
dot
(
v
,
w
),
mode
=
mode_blas_opt
)
# Assert that the dot was optimized somehow
self
.
assertFunctionContains0
(
f
,
T
.
dot
)
self
.
assertFunctionContains1
(
f
,
Gemv
(
False
))
# Assert they produce the same output
assert
numpy
.
allclose
(
f
(),
numpy
.
dot
(
v
.
get_value
(),
w
.
get_value
()))
def
test_dot_vm
(
self
):
def
test_dot_vm
(
self
):
''' Test vector dot matrix '''
''' Test vector dot matrix '''
rng
=
numpy
.
random
.
RandomState
(
unittest_tools
.
fetch_seed
())
rng
=
numpy
.
random
.
RandomState
(
unittest_tools
.
fetch_seed
())
v
=
theano
.
shared
(
numpy
.
array
(
rng
.
uniform
(
size
=
(
2
,)),
dtype
=
'float32'
))
v
=
theano
.
shared
(
numpy
.
array
(
rng
.
uniform
(
size
=
(
2
,)),
dtype
=
'float32'
))
m
=
theano
.
shared
(
numpy
.
array
(
rng
.
uniform
(
size
=
(
2
,
3
)),
dtype
=
'float32'
))
m
=
theano
.
shared
(
numpy
.
array
(
rng
.
uniform
(
size
=
(
2
,
3
)),
dtype
=
'float32'
))
f
=
theano
.
function
([],
theano
.
dot
(
v
,
m
),
mode
=
mode_blas_opt
)
f
=
theano
.
function
([],
theano
.
dot
(
v
,
m
),
mode
=
mode_blas_opt
)
# Assert that the dot was optimized somehow
self
.
assertFunctionContains0
(
f
,
T
.
dot
)
self
.
assertFunctionContains1
(
f
,
Gemv
(
True
))
# Assert they produce the same output
# Assert they produce the same output
assert
numpy
.
allclose
(
f
(),
numpy
.
dot
(
v
.
get_value
(),
m
.
get_value
()))
assert
numpy
.
allclose
(
f
(),
numpy
.
dot
(
v
.
get_value
(),
m
.
get_value
()))
# Assert that the dot was optimized somehow
assert
sum
([
isinstance
(
node
.
op
,
T
.
Dot
)
for
node
in
f
.
maker
.
env
.
toposort
()
])
==
0
assert
sum
([
isinstance
(
node
.
op
,
T
.
blas
.
Dot22
)
for
node
in
f
.
maker
.
env
.
toposort
()
])
==
1
def
test_dot_mv
(
self
):
def
test_dot_mv
(
self
):
''' Test matrix dot vector '''
''' Test matrix dot vector '''
...
@@ -885,17 +886,15 @@ class TestGemv(TestCase):
...
@@ -885,17 +886,15 @@ class TestGemv(TestCase):
v
=
theano
.
shared
(
numpy
.
array
(
rng
.
uniform
(
size
=
(
2
,)),
dtype
=
'float32'
))
v
=
theano
.
shared
(
numpy
.
array
(
rng
.
uniform
(
size
=
(
2
,)),
dtype
=
'float32'
))
m
=
theano
.
shared
(
numpy
.
array
(
rng
.
uniform
(
size
=
(
3
,
2
)),
m
=
theano
.
shared
(
numpy
.
array
(
rng
.
uniform
(
size
=
(
3
,
2
)),
dtype
=
'float32'
))
dtype
=
'float32'
))
f
=
theano
.
function
([],
theano
.
dot
(
m
,
v
),
mode
=
mode_blas_opt
)
f
=
theano
.
function
([],
theano
.
dot
(
m
,
v
),
mode
=
mode_blas_opt
)
# Assert that the dot was optimized somehow
self
.
assertFunctionContains0
(
f
,
T
.
dot
)
self
.
assertFunctionContains1
(
f
,
Gemv
(
True
))
# Assert they produce the same output
# Assert they produce the same output
assert
numpy
.
allclose
(
f
(),
numpy
.
dot
(
m
.
get_value
(),
v
.
get_value
()))
assert
numpy
.
allclose
(
f
(),
numpy
.
dot
(
m
.
get_value
(),
v
.
get_value
()))
# Assert that the dot was optimized somehow
assert
sum
([
isinstance
(
node
.
op
,
T
.
Dot
)
for
node
in
f
.
maker
.
env
.
toposort
()
])
==
0
assert
sum
([
isinstance
(
node
.
op
,
T
.
blas
.
Dot22
)
for
node
in
f
.
maker
.
env
.
toposort
()
])
==
1
@staticmethod
@staticmethod
def
t_gemv1
(
m_shp
):
def
t_gemv1
(
m_shp
):
''' test vector2+dot(matrix,vector1) '''
''' test vector2+dot(matrix,vector1) '''
...
@@ -1017,6 +1016,8 @@ def matrixmultiply(a, b):
...
@@ -1017,6 +1016,8 @@ def matrixmultiply(a, b):
class
BaseGemv
(
object
):
class
BaseGemv
(
object
):
mode
=
mode_blas_opt
# can be overridden with self.mode
def
get_data
(
self
,
x_stride
=
1
,
y_stride
=
1
):
def
get_data
(
self
,
x_stride
=
1
,
y_stride
=
1
):
rng
=
numpy
.
random
.
RandomState
(
unittest_tools
.
fetch_seed
())
rng
=
numpy
.
random
.
RandomState
(
unittest_tools
.
fetch_seed
())
mult
=
array
(
1
,
dtype
=
self
.
dtype
)
mult
=
array
(
1
,
dtype
=
self
.
dtype
)
...
@@ -1035,10 +1036,10 @@ class BaseGemv(object):
...
@@ -1035,10 +1036,10 @@ class BaseGemv(object):
oy
=
alpha
*
T
.
dot
(
a
,
x
)
+
beta
*
y
oy
=
alpha
*
T
.
dot
(
a
,
x
)
+
beta
*
y
oy_func
=
theano
.
function
([],
oy
,
mode
=
mode_blas_opt
)
oy_func
=
theano
.
function
([],
oy
,
mode
=
self
.
mode
)
topo
=
oy_func
.
maker
.
env
.
toposort
()
topo
=
oy_func
.
maker
.
env
.
toposort
()
assert
sum
([
isinstance
(
node
.
op
,
theano
.
tensor
.
blas
.
Gemv
)
for
node
in
topo
])
==
1
self
.
assertFunctionContains1
(
oy_func
,
self
.
gemv
)
oy_val
=
oy_func
()
oy_val
=
oy_func
()
...
@@ -1055,22 +1056,9 @@ class BaseGemv(object):
...
@@ -1055,22 +1056,9 @@ class BaseGemv(object):
oy
=
T
.
dot
(
a
,
x
)
oy
=
T
.
dot
(
a
,
x
)
oy_func
=
theano
.
function
([],
oy
,
mode
=
mode_blas_opt
)
oy_func
=
theano
.
function
([],
oy
,
mode
=
self
.
mode
)
topo
=
oy_func
.
maker
.
env
.
toposort
()
self
.
assertFunctionContains1
(
oy_func
,
self
.
gemv_inplace
)
# The only op in the graph is a dot.
# In the gemm case, we create a dot22 for that case
# There is no dot21.
# Creating one is not useful as this is not faster(in fact it would be slower!
# as more code would be in python, numpy.dot will call gemv itself)
# See ticket 594
"""
>>> t0=time.time();x=scipy.linalg.blas.fblas.dgemv(1,a.T,b,1,z.T);t1=time.time();print t1-t0
0.00192999839783
>>> t0=time.time();x=numpy.dot(a,b);t1=time.time();print t1-t0
0.00158381462097
"""
assert
sum
([
isinstance
(
node
.
op
,
theano
.
tensor
.
blas
.
Gemv
)
for
node
in
topo
])
==
0
oy_v
=
oy_func
()
oy_v
=
oy_func
()
assert_array_almost_equal
(
desired_oy
,
oy_v
)
assert_array_almost_equal
(
desired_oy
,
oy_v
)
...
@@ -1085,10 +1073,9 @@ class BaseGemv(object):
...
@@ -1085,10 +1073,9 @@ class BaseGemv(object):
oy
=
alpha
*
T
.
dot
(
a
.
T
,
x
)
+
beta
*
y
oy
=
alpha
*
T
.
dot
(
a
.
T
,
x
)
+
beta
*
y
oy_func
=
theano
.
function
([],
oy
,
mode
=
mode_blas_opt
)
oy_func
=
theano
.
function
([],
oy
,
mode
=
self
.
mode
)
topo
=
oy_func
.
maker
.
env
.
toposort
()
self
.
assertFunctionContains1
(
oy_func
,
self
.
gemv
)
assert
sum
([
isinstance
(
node
.
op
,
theano
.
tensor
.
blas
.
Gemv
)
for
node
in
topo
])
==
1
oy_v
=
oy_func
()
oy_v
=
oy_func
()
assert_array_almost_equal
(
desired_oy
,
oy_v
)
assert_array_almost_equal
(
desired_oy
,
oy_v
)
...
@@ -1102,10 +1089,9 @@ class BaseGemv(object):
...
@@ -1102,10 +1089,9 @@ class BaseGemv(object):
oy
=
alpha
*
T
.
dot
(
a
,
x
[::
2
])
+
beta
*
y
oy
=
alpha
*
T
.
dot
(
a
,
x
[::
2
])
+
beta
*
y
oy_func
=
theano
.
function
([],
oy
,
mode
=
mode_blas_opt
)
oy_func
=
theano
.
function
([],
oy
,
mode
=
self
.
mode
)
topo
=
oy_func
.
maker
.
env
.
toposort
()
self
.
assertFunctionContains1
(
oy_func
,
self
.
gemv
)
assert
sum
([
isinstance
(
node
.
op
,
theano
.
tensor
.
blas
.
Gemv
)
for
node
in
topo
])
==
1
oy_v
=
oy_func
()
oy_v
=
oy_func
()
assert_array_almost_equal
(
desired_oy
,
oy_v
)
assert_array_almost_equal
(
desired_oy
,
oy_v
)
...
@@ -1119,10 +1105,9 @@ class BaseGemv(object):
...
@@ -1119,10 +1105,9 @@ class BaseGemv(object):
oy
=
alpha
*
T
.
dot
(
a
.
T
,
x
[::
2
])
+
beta
*
y
oy
=
alpha
*
T
.
dot
(
a
.
T
,
x
[::
2
])
+
beta
*
y
oy_func
=
theano
.
function
([],
oy
,
mode
=
mode_blas_opt
)
oy_func
=
theano
.
function
([],
oy
,
mode
=
self
.
mode
)
topo
=
oy_func
.
maker
.
env
.
toposort
()
self
.
assertFunctionContains1
(
oy_func
,
self
.
gemv
)
assert
sum
([
isinstance
(
node
.
op
,
theano
.
tensor
.
blas
.
Gemv
)
for
node
in
topo
])
==
1
oy_v
=
oy_func
()
oy_v
=
oy_func
()
assert_array_almost_equal
(
desired_oy
,
oy_v
)
assert_array_almost_equal
(
desired_oy
,
oy_v
)
...
@@ -1136,10 +1121,9 @@ class BaseGemv(object):
...
@@ -1136,10 +1121,9 @@ class BaseGemv(object):
oy
=
alpha
*
T
.
dot
(
a
,
x
)
+
beta
*
y
[::
2
]
oy
=
alpha
*
T
.
dot
(
a
,
x
)
+
beta
*
y
[::
2
]
oy_func
=
theano
.
function
([],
oy
,
mode
=
mode_blas_opt
)
oy_func
=
theano
.
function
([],
oy
,
mode
=
self
.
mode
)
topo
=
oy_func
.
maker
.
env
.
toposort
()
self
.
assertFunctionContains1
(
oy_func
,
self
.
gemv
)
assert
sum
([
isinstance
(
node
.
op
,
theano
.
tensor
.
blas
.
Gemv
)
for
node
in
topo
])
==
1
oy_v
=
oy_func
()
oy_v
=
oy_func
()
assert_array_almost_equal
(
desired_oy
,
oy_v
)
assert_array_almost_equal
(
desired_oy
,
oy_v
)
...
@@ -1153,21 +1137,56 @@ class BaseGemv(object):
...
@@ -1153,21 +1137,56 @@ class BaseGemv(object):
oy
=
alpha
*
T
.
dot
(
a
.
T
,
x
)
+
beta
*
y
[::
2
]
oy
=
alpha
*
T
.
dot
(
a
.
T
,
x
)
+
beta
*
y
[::
2
]
oy_func
=
theano
.
function
([],
oy
,
mode
=
mode_blas_opt
)
oy_func
=
theano
.
function
([],
oy
,
mode
=
self
.
mode
)
topo
=
oy_func
.
maker
.
env
.
toposort
()
self
.
assertFunctionContains1
(
oy_func
,
self
.
gemv
)
assert
sum
([
isinstance
(
node
.
op
,
theano
.
tensor
.
blas
.
Gemv
)
for
node
in
topo
])
==
1
oy_v
=
oy_func
()
oy_v
=
oy_func
()
assert_array_almost_equal
(
desired_oy
,
oy_v
)
assert_array_almost_equal
(
desired_oy
,
oy_v
)
def
test_upcasting_scalar_nogemv
(
self
):
# Test that the optimization does not crash when the scale has
# an incorrect dtype, and forces upcasting of the result
# We put this test in this class to test it on the gpu too.
vs
=
self
.
get_data
()
alpha_v
,
beta_v
,
a_v
,
x_v
,
y_v
=
vs
alpha_v
=
alpha_v
.
astype
(
"float64"
)
a_v
=
a_v
.
astype
(
"float32"
)
x_v
=
x_v
.
astype
(
"float32"
)
y_v
=
y_v
.
astype
(
"float32"
)
alpha
=
T
.
dscalar
(
'a'
)
a
=
T
.
fmatrix
(
'w'
)
x
=
T
.
fvector
(
'v'
)
y
=
T
.
fvector
(
't'
)
rval
=
T
.
dot
(
a
,
x
)
*
alpha
+
y
f
=
theano
.
function
([
a
,
x
,
y
,
alpha
],
rval
,
mode
=
self
.
mode
)
# this function is currently optimized so that the gemv is
# done inplace on a temporarily allocated-buffer, which is
# then scaled by alpha and to t with a fused elemwise.
n_gemvs
=
0
#theano.printing.debugprint(f, print_type=True)
for
node
in
f
.
maker
.
env
.
toposort
():
if
node
.
op
==
self
.
gemv_inplace
:
n_gemvs
+=
1
assert
node
.
outputs
[
0
]
.
dtype
==
'float32'
assert
n_gemvs
==
1
,
n_gemvs
self
.
assertFunctionContains1
(
f
,
self
.
gemv_inplace
)
f
(
a_v
,
x_v
,
y_v
,
alpha_v
)
class
TestSgemv
(
TestCase
,
BaseGemv
):
class
TestSgemv
(
TestCase
,
BaseGemv
,
unittest_tools
.
TestOptimizationMixin
):
dtype
=
float32
dtype
=
float32
gemv
=
theano
.
tensor
.
blas
.
gemv_no_inplace
gemv_inplace
=
theano
.
tensor
.
blas
.
gemv_inplace
class
TestDgemv
(
TestCase
,
BaseGemv
):
class
TestDgemv
(
TestCase
,
BaseGemv
,
unittest_tools
.
TestOptimizationMixin
):
dtype
=
float64
dtype
=
float64
gemv
=
theano
.
tensor
.
blas
.
gemv_no_inplace
gemv_inplace
=
theano
.
tensor
.
blas
.
gemv_inplace
#The optimization to put Gemv don't work for complex type for now.
#The optimization to put Gemv don't work for complex type for now.
# See ticket 653.
# See ticket 653.
...
@@ -1252,178 +1271,49 @@ class TestGer_make_node(TestCase):
...
@@ -1252,178 +1271,49 @@ class TestGer_make_node(TestCase):
self
.
assertRaises
(
TypeError
,
ger
,
self
.
cm
,
self
.
fa
,
self
.
fv
,
self
.
dv_2
)
self
.
assertRaises
(
TypeError
,
ger
,
self
.
cm
,
self
.
fa
,
self
.
fv
,
self
.
dv_2
)
self
.
assertRaises
(
TypeError
,
ger
,
self
.
cm
,
self
.
fa
,
self
.
fv
,
self
.
zv_2
)
self
.
assertRaises
(
TypeError
,
ger
,
self
.
cm
,
self
.
fa
,
self
.
fv
,
self
.
zv_2
)
# TODO: refactor this into some place where all OpTesters could use it.
# This object name should not start with Test.
# Otherwise nosetests will execute it!
class
T_OpContractMixin
(
object
):
# self.ops should be a list of instantiations of an Op class to test.
# self.other_op should be an op which is different from every op
other_op
=
T
.
add
def
copy
(
self
,
x
):
return
copy
(
x
)
def
deepcopy
(
self
,
x
):
return
deepcopy
(
x
)
def
clone
(
self
,
op
):
class
TestGer_OpContract
(
TestCase
,
unittest_tools
.
T_OpContractMixin
):
raise
NotImplementedError
(
'return new instance like `op`'
)
def
test_eq
(
self
):
for
i
,
op_i
in
enumerate
(
self
.
ops
):
assert
op_i
==
op_i
assert
op_i
==
self
.
copy
(
op_i
)
assert
op_i
==
self
.
deepcopy
(
op_i
)
assert
op_i
==
self
.
clone
(
op_i
)
assert
op_i
!=
self
.
other_op
for
j
,
op_j
in
enumerate
(
self
.
ops
):
if
i
==
j
:
continue
assert
op_i
!=
op_j
def
test_hash
(
self
):
for
i
,
op_i
in
enumerate
(
self
.
ops
):
h_i
=
hash
(
op_i
)
assert
h_i
==
hash
(
op_i
)
assert
h_i
==
hash
(
self
.
copy
(
op_i
))
assert
h_i
==
hash
(
self
.
deepcopy
(
op_i
))
assert
h_i
==
hash
(
self
.
clone
(
op_i
))
assert
h_i
!=
hash
(
self
.
other_op
)
for
j
,
op_j
in
enumerate
(
self
.
ops
):
if
i
==
j
:
continue
assert
op_i
!=
hash
(
op_j
)
def
test_name
(
self
):
for
op
in
self
.
ops
:
s
=
str
(
op
)
# show that str works
assert
s
# names should not be empty
class
TestGer_OpContract
(
TestCase
,
T_OpContractMixin
):
#TODO: These tests could be factored into a generic Op-testing base-class
def
setUp
(
self
):
def
setUp
(
self
):
self
.
ops
=
[
ger
,
ger_destructive
]
self
.
ops
=
[
ger
,
ger_destructive
]
def
clone
(
self
,
op
):
def
clone
(
self
,
op
):
return
Ger
(
op
.
destructive
)
return
Ger
(
op
.
destructive
)
class
TestGer_make_thunk
(
TestCase
):
def
setUp
(
self
):
self
.
rng
=
numpy
.
random
.
RandomState
(
unittest_tools
.
fetch_seed
())
def
given_dtype
(
self
,
dtype
,
M
,
N
):
sA
=
T
.
tensor
(
dtype
=
dtype
,
broadcastable
=
(
False
,
False
))
sa
=
T
.
tensor
(
dtype
=
dtype
,
broadcastable
=
())
sx
=
T
.
tensor
(
dtype
=
dtype
,
broadcastable
=
(
False
,))
sy
=
T
.
tensor
(
dtype
=
dtype
,
broadcastable
=
(
False
,))
sZ
=
ger
(
sA
,
sa
,
sx
,
sy
)
node
=
sZ
.
owner
storage_map
=
{
sA
:[
None
],
sa
:[
None
],
sx
:[
None
],
sy
:[
None
],
sZ
:[
None
]}
thunk
=
ger
.
make_thunk
(
node
,
storage_map
,
compute_map
=
{},
no_recycling
=
[])
# non-standard for make_thunk to receive node.op != self,
# but works for now.
thunk_d
=
ger_destructive
.
make_thunk
(
node
,
storage_map
,
compute_map
=
{},
no_recycling
=
[])
def
rand
(
*
shape
):
return
numpy
.
asarray
(
1
+
self
.
rng
.
rand
(
*
shape
),
dtype
=
dtype
)
storage_map
[
sA
][
0
]
=
rand
(
M
,
N
)
storage_map
[
sa
][
0
]
=
rand
()
storage_map
[
sx
][
0
]
=
rand
(
M
)
storage_map
[
sy
][
0
]
=
rand
(
N
)
storage_map_copy
=
dict
([(
k
,[
deepcopy
(
v
[
0
])])
for
k
,
v
in
storage_map
.
items
()])
# TODO: do some DebugMode-type verifications here
# if this can be refactored into a Mixin that does the DebugMode
# stuff on just one thunk at a time. Do it in the style of
# TestOpContractMixin?
# - Compare with Elemwise testers
thunk
()
assert
numpy
.
all
(
storage_map
[
sZ
][
0
]
==
storage_map
[
sA
][
0
]
+
storage_map
[
sa
][
0
]
*
numpy
.
outer
(
storage_map
[
sx
][
0
],
storage_map
[
sy
][
0
]))
assert
storage_map
[
sZ
][
0
]
.
dtype
==
dtype
assert
storage_map
[
sZ
][
0
]
.
shape
==
(
M
,
N
)
thunk_d
()
assert
numpy
.
all
(
storage_map
[
sZ
][
0
]
!=
storage_map
[
sA
][
0
]
+
storage_map
[
sa
][
0
]
*
numpy
.
outer
(
storage_map
[
sx
][
0
],
storage_map
[
sy
][
0
]))
assert
numpy
.
all
(
storage_map
[
sZ
][
0
]
==
storage_map_copy
[
sA
][
0
]
+
storage_map
[
sa
][
0
]
*
numpy
.
outer
(
storage_map
[
sx
][
0
],
storage_map
[
sy
][
0
]))
assert
storage_map
[
sZ
][
0
]
.
dtype
==
dtype
assert
storage_map
[
sZ
][
0
]
.
shape
==
(
M
,
N
)
def
test_f32_0_0
(
self
):
return
self
.
given_dtype
(
'float32'
,
0
,
0
)
def
test_f32_1_0
(
self
):
return
self
.
given_dtype
(
'float32'
,
1
,
0
)
def
test_f32_0_1
(
self
):
return
self
.
given_dtype
(
'float32'
,
0
,
1
)
def
test_f32_1_1
(
self
):
return
self
.
given_dtype
(
'float32'
,
1
,
1
)
def
test_f32_4_4
(
self
):
return
self
.
given_dtype
(
'float32'
,
4
,
4
)
def
test_f64_4_5
(
self
):
return
self
.
given_dtype
(
'float64'
,
4
,
5
)
def
test_c64_7_1
(
self
):
return
self
.
given_dtype
(
'complex64'
,
7
,
1
)
def
test_c128_1_9
(
self
):
return
self
.
given_dtype
(
'complex128'
,
1
,
9
)
# TODO: Refactor and add to this base class as we refactor test code.
class
TestOptimizationMixin
(
object
):
def
assertFunctionContains
(
self
,
f
,
op
,
min
=
1
,
max
=
sys
.
maxint
):
toposort
=
f
.
maker
.
env
.
toposort
()
matches
=
[
node
for
node
in
toposort
if
node
.
op
==
op
]
assert
(
min
<=
len
(
matches
)
<=
max
),
toposort
def
assertFunctionContains0
(
self
,
f
,
op
):
class
TestGer
(
TestCase
,
unittest_tools
.
TestOptimizationMixin
):
return
self
.
assertFunctionContains
(
f
,
op
,
min
=
0
,
max
=
0
)
def
assertFunctionContains1
(
self
,
f
,
op
):
return
self
.
assertFunctionContains
(
f
,
op
,
min
=
1
,
max
=
1
)
def
assertFunctionContainsN
(
self
,
f
,
op
,
N
):
return
self
.
assertFunctionContains
(
f
,
op
,
min
=
N
,
max
=
N
)
def
SkipTest
(
self
):
raise
Exception
(
'how do I skip this test properly?'
)
class
TestGer_local_gemm_to_ger
(
TestCase
,
TestOptimizationMixin
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
mode
=
theano
.
compile
.
get_default_mode
()
.
including
(
'fast_run'
)
self
.
mode
=
theano
.
compile
.
get_default_mode
()
.
including
(
'fast_run'
)
self
.
mode
=
self
.
mode
.
excluding
(
'c_blas'
,
'scipy_blas'
)
dtype
=
self
.
dtype
=
'float64'
# optimization isn't dtype-dependent
dtype
=
self
.
dtype
=
'float64'
# optimization isn't dtype-dependent
self
.
A
=
T
.
tensor
(
dtype
=
dtype
,
broadcastable
=
(
False
,
False
))
self
.
A
=
T
.
tensor
(
dtype
=
dtype
,
broadcastable
=
(
False
,
False
))
self
.
a
=
T
.
tensor
(
dtype
=
dtype
,
broadcastable
=
())
self
.
a
=
T
.
tensor
(
dtype
=
dtype
,
broadcastable
=
())
self
.
x
=
T
.
tensor
(
dtype
=
dtype
,
broadcastable
=
(
False
,))
self
.
x
=
T
.
tensor
(
dtype
=
dtype
,
broadcastable
=
(
False
,))
self
.
y
=
T
.
tensor
(
dtype
=
dtype
,
broadcastable
=
(
False
,))
self
.
y
=
T
.
tensor
(
dtype
=
dtype
,
broadcastable
=
(
False
,))
self
.
origval
=
theano
.
tensor
.
blas_scipy
.
optimizations_enabled
self
.
ger
=
ger
theano
.
tensor
.
blas_scipy
.
optimizations_enabled
=
False
self
.
ger_destructive
=
ger_destructive
self
.
gemm
=
gemm_no_inplace
def
tearDown
(
self
):
theano
.
tensor
.
blas_scipy
.
optimizations_enabled
=
self
.
origval
def
function
(
self
,
inputs
,
outputs
):
def
function
(
self
,
inputs
,
outputs
,
updates
=
{}
):
return
theano
.
function
(
inputs
,
outputs
,
self
.
mode
)
return
theano
.
function
(
inputs
,
outputs
,
self
.
mode
,
updates
=
updates
)
def
b
(
self
,
bval
):
def
b
(
self
,
bval
):
return
T
.
as_tensor_variable
(
numpy
.
asarray
(
bval
,
dtype
=
self
.
dtype
))
return
T
.
as_tensor_variable
(
numpy
.
asarray
(
bval
,
dtype
=
self
.
dtype
))
def
test_b_0_triggers_ger
(
self
):
def
test_b_0_triggers_ger
(
self
):
""" test local_gemm_to_ger opt"""
assert
T
.
blas
.
local_gemm_to_ger
.
transform
(
assert
T
.
blas
.
local_gemm_to_ger
.
transform
(
gemm_no_inplace
(
gemm_no_inplace
(
self
.
A
,
self
.
a
,
self
.
x
.
dimshuffle
(
0
,
'x'
),
self
.
A
,
self
.
a
,
self
.
x
.
dimshuffle
(
0
,
'x'
),
self
.
y
.
dimshuffle
(
'x'
,
0
),
self
.
b
(
0
))
.
owner
)
self
.
y
.
dimshuffle
(
'x'
,
0
),
self
.
b
(
0
))
.
owner
)
def
test_b_1_triggers_ger
(
self
):
def
test_b_1_triggers_ger
(
self
):
""" test local_gemm_to_ger opt"""
assert
T
.
blas
.
local_gemm_to_ger
.
transform
(
assert
T
.
blas
.
local_gemm_to_ger
.
transform
(
gemm_no_inplace
(
gemm_no_inplace
(
self
.
A
,
self
.
a
,
self
.
x
.
dimshuffle
(
0
,
'x'
),
self
.
A
,
self
.
a
,
self
.
x
.
dimshuffle
(
0
,
'x'
),
self
.
y
.
dimshuffle
(
'x'
,
0
),
self
.
b
(
1
))
.
owner
)
self
.
y
.
dimshuffle
(
'x'
,
0
),
self
.
b
(
1
))
.
owner
)
def
test_b_other_does_not_triggers_ger
(
self
):
def
test_b_other_does_not_triggers_ger
(
self
):
""" test local_gemm_to_ger opt"""
assert
not
T
.
blas
.
local_gemm_to_ger
.
transform
(
assert
not
T
.
blas
.
local_gemm_to_ger
.
transform
(
gemm_no_inplace
(
gemm_no_inplace
(
self
.
A
,
self
.
a
,
self
.
x
.
dimshuffle
(
0
,
'x'
),
self
.
A
,
self
.
a
,
self
.
x
.
dimshuffle
(
0
,
'x'
),
...
@@ -1431,19 +1321,77 @@ class TestGer_local_gemm_to_ger(TestCase, TestOptimizationMixin):
...
@@ -1431,19 +1321,77 @@ class TestGer_local_gemm_to_ger(TestCase, TestOptimizationMixin):
def
test_outer
(
self
):
def
test_outer
(
self
):
f
=
self
.
function
([
self
.
x
,
self
.
y
],
T
.
outer
(
self
.
x
,
self
.
y
))
f
=
self
.
function
([
self
.
x
,
self
.
y
],
T
.
outer
(
self
.
x
,
self
.
y
))
self
.
assertFunctionContains
(
f
,
ger_destructive
)
self
.
assertFunctionContains
(
f
,
self
.
ger_destructive
)
f
(
numpy
.
random
.
rand
(
5
)
.
astype
(
self
.
dtype
),
numpy
.
random
.
rand
(
4
)
.
astype
(
self
.
dtype
))
def
test_A_plus_outer
(
self
):
def
test_A_plus_outer
(
self
):
f
=
self
.
function
([
self
.
A
,
self
.
x
,
self
.
y
],
f
=
self
.
function
([
self
.
A
,
self
.
x
,
self
.
y
],
self
.
A
+
T
.
outer
(
self
.
x
,
self
.
y
))
self
.
A
+
T
.
outer
(
self
.
x
,
self
.
y
))
self
.
assertFunctionContains
(
f
,
ger
)
self
.
assertFunctionContains
(
f
,
self
.
ger
)
f
(
numpy
.
random
.
rand
(
5
,
4
)
.
astype
(
self
.
dtype
),
numpy
.
random
.
rand
(
5
)
.
astype
(
self
.
dtype
),
numpy
.
random
.
rand
(
4
)
.
astype
(
self
.
dtype
))
def
test_A_plus_scaled_outer
(
self
):
def
test_A_plus_scaled_outer
(
self
):
f
=
self
.
function
([
self
.
A
,
self
.
x
,
self
.
y
],
f
=
self
.
function
([
self
.
A
,
self
.
x
,
self
.
y
],
self
.
A
+
0.1
*
T
.
outer
(
self
.
x
,
self
.
y
))
self
.
A
+
0.1
*
T
.
outer
(
self
.
x
,
self
.
y
))
self
.
assertFunctionContains
(
f
,
ger
)
self
.
assertFunctionContains
(
f
,
self
.
ger
)
f
(
numpy
.
random
.
rand
(
5
,
4
)
.
astype
(
self
.
dtype
),
numpy
.
random
.
rand
(
5
)
.
astype
(
self
.
dtype
),
numpy
.
random
.
rand
(
4
)
.
astype
(
self
.
dtype
))
def
test_scaled_A_plus_scaled_outer
(
self
):
def
test_scaled_A_plus_scaled_outer
(
self
):
f
=
self
.
function
([
self
.
A
,
self
.
x
,
self
.
y
],
f
=
self
.
function
([
self
.
A
,
self
.
x
,
self
.
y
],
0.2
*
self
.
A
+
0.1
*
T
.
outer
(
self
.
x
,
self
.
y
))
numpy
.
asarray
(
0.2
,
self
.
dtype
)
*
self
.
A
+
self
.
assertFunctionContains
(
f
,
gemm_no_inplace
)
numpy
.
asarray
(
0.1
,
self
.
dtype
)
*
T
.
outer
(
self
.
x
,
self
.
y
))
# Why gemm? This make the graph simpler did we test that it
# make it faster?
self
.
assertFunctionContains
(
f
,
self
.
gemm
)
f
(
numpy
.
random
.
rand
(
5
,
4
)
.
astype
(
self
.
dtype
),
numpy
.
random
.
rand
(
5
)
.
astype
(
self
.
dtype
),
numpy
.
random
.
rand
(
4
)
.
astype
(
self
.
dtype
))
def
given_dtype
(
self
,
dtype
,
M
,
N
):
""" test corner case shape and dtype"""
f
=
self
.
function
([
self
.
A
,
self
.
x
,
self
.
y
],
self
.
A
+
0.1
*
T
.
outer
(
self
.
x
,
self
.
y
))
self
.
assertFunctionContains
(
f
,
self
.
ger
)
f
(
numpy
.
random
.
rand
(
M
,
N
)
.
astype
(
self
.
dtype
),
numpy
.
random
.
rand
(
M
)
.
astype
(
self
.
dtype
),
numpy
.
random
.
rand
(
N
)
.
astype
(
self
.
dtype
))
def
test_f32_0_0
(
self
):
return
self
.
given_dtype
(
'float32'
,
0
,
0
)
def
test_f32_1_0
(
self
):
return
self
.
given_dtype
(
'float32'
,
1
,
0
)
def
test_f32_0_1
(
self
):
return
self
.
given_dtype
(
'float32'
,
0
,
1
)
def
test_f32_1_1
(
self
):
return
self
.
given_dtype
(
'float32'
,
1
,
1
)
def
test_f32_4_4
(
self
):
return
self
.
given_dtype
(
'float32'
,
4
,
4
)
def
test_f64_4_5
(
self
):
return
self
.
given_dtype
(
'float64'
,
4
,
5
)
def
test_c64_7_1
(
self
):
return
self
.
given_dtype
(
'complex64'
,
7
,
1
)
def
test_c128_1_9
(
self
):
return
self
.
given_dtype
(
'complex128'
,
1
,
9
)
def
test_inplace
(
self
):
A
=
theano
.
shared
(
numpy
.
random
.
rand
(
4
,
5
)
.
astype
(
self
.
dtype
))
f
=
self
.
function
([
self
.
x
,
self
.
y
],
[],
updates
=
{
A
:
A
+
T
.
constant
(
0.1
,
dtype
=
self
.
dtype
)
*
T
.
outer
(
self
.
x
,
self
.
y
)})
self
.
assertFunctionContains
(
f
,
self
.
ger_destructive
)
f
(
numpy
.
random
.
rand
(
4
)
.
astype
(
self
.
dtype
),
numpy
.
random
.
rand
(
5
)
.
astype
(
self
.
dtype
))
theano/tensor/tests/test_blas_c.py
0 → 100644
浏览文件 @
f2743791
import
sys
import
numpy
import
theano
import
theano.tensor
as
tensor
from
theano.tensor.blas_c
import
CGer
from
theano.tensor.blas_scipy
import
ScipyGer
from
theano.tensor.blas
import
Ger
from
theano.tensor.blas_c
import
CGemv
from
theano.tensor.blas_scipy
import
ScipyGer
from
theano.tensor.blas
import
Gemv
from
theano.tests
import
unittest_tools
from
theano.tests.unittest_tools
import
TestOptimizationMixin
from
test_blas
import
TestCase
from
test_blas
import
BaseGemv
# Compilation mode with the BLAS-related graph optimizations enabled,
# including the C implementations ('c_blas') and the inplace variants.
mode_blas_opt = theano.compile.get_default_mode().including(
        'BlasOpt', 'specialize', 'InplaceBlasOpt', 'c_blas')
class TestCGer(TestCase, TestOptimizationMixin):
    """Tests of CGer, the C implementation of the Ger (rank-1 update) Op.

    Covers __eq__/__hash__ of CGer instances and checks that the
    optimization pipeline inserts CGer nodes for outer-product graphs.
    """

    def setUp(self, dtype='float64'):
        # dtype: element type used for all symbolic variables and fixtures.
        self.dtype = dtype
        self.mode = theano.compile.get_default_mode().including('fast_run')
        self.A = tensor.tensor(dtype=dtype, broadcastable=(False, False))
        self.a = tensor.tensor(dtype=dtype, broadcastable=())
        self.x = tensor.tensor(dtype=dtype, broadcastable=(False,))
        self.y = tensor.tensor(dtype=dtype, broadcastable=(False,))
        self.Aval = numpy.ones((2, 3), dtype=dtype)
        self.xval = numpy.asarray([1, 2], dtype=dtype)
        self.yval = numpy.asarray([1.5, 2.7, 3.9], dtype=dtype)

    def function(self, inputs, outputs):
        # Compile with the fast_run mode chosen in setUp.
        return theano.function(inputs, outputs,
                mode=self.mode,
                #allow_inplace=True,
                )

    def run_f(self, f):
        # Run a compiled function on the fixture values (A, x, y).
        return f(self.Aval, self.xval, self.yval)

    def b(self, bval):
        # Wrap a python scalar as a constant tensor of self.dtype.
        return tensor.as_tensor_variable(numpy.asarray(bval, dtype=self.dtype))

    def test_eq(self):
        # CGer instances compare equal iff their destructive flag matches.
        self.assert_(CGer(True) == CGer(True))
        self.assert_(CGer(False) == CGer(False))
        self.assert_(CGer(False) != CGer(True))

        # CGer must not compare equal to other Ger implementations.
        self.assert_(CGer(True) != ScipyGer(True))
        self.assert_(CGer(False) != ScipyGer(False))
        self.assert_(CGer(True) != Ger(True))
        self.assert_(CGer(False) != Ger(False))

        # assert that eq works for non-CGer instances
        self.assert_(CGer(False) != None)
        self.assert_(CGer(True) != None)

    def test_hash(self):
        # Hash must be consistent with __eq__.
        self.assert_(hash(CGer(True)) == hash(CGer(True)))
        self.assert_(hash(CGer(False)) == hash(CGer(False)))
        self.assert_(hash(CGer(False)) != hash(CGer(True)))

    def test_optimization_pipeline(self):
        # outer(x, y) alone should be rewritten to a destructive CGer.
        f = self.function([self.x, self.y], tensor.outer(self.x, self.y))
        self.assertFunctionContains(f, CGer(destructive=True))
        f(self.xval, self.yval)  # DebugMode tests correctness

    def test_optimization_pipeline_float(self):
        # Same as test_optimization_pipeline but in float32.
        self.setUp('float32')
        f = self.function([self.x, self.y], tensor.outer(self.x, self.y))
        self.assertFunctionContains(f, CGer(destructive=True))
        f(self.xval, self.yval)  # DebugMode tests correctness

    def test_int_fails(self):
        # CGer only supports float/complex BLAS types; int32 graphs must
        # not be rewritten to CGer in either variant.
        self.setUp('int32')
        f = self.function([self.x, self.y], tensor.outer(self.x, self.y))
        self.assertFunctionContains0(f, CGer(destructive=True))
        self.assertFunctionContains0(f, CGer(destructive=False))

    def test_A_plus_outer(self):
        # A + outer(x, y) should use the non-destructive CGer (A is an
        # input and must not be overwritten).
        f = self.function([self.A, self.x, self.y],
                self.A + tensor.outer(self.x, self.y))
        self.assertFunctionContains(f, CGer(destructive=False))
        self.run_f(f)  # DebugMode tests correctness

    def test_A_plus_scaled_outer(self):
        # A + alpha * outer(x, y) should also map onto non-destructive CGer.
        f = self.function([self.A, self.x, self.y],
                self.A + 0.1 * tensor.outer(self.x, self.y))
        self.assertFunctionContains(f, CGer(destructive=False))
        self.run_f(f)  # DebugMode tests correctness
class TestCGemv(TestCase, TestOptimizationMixin):
    """
    Tests of CGemv specifically.

    Generic tests of Gemv-compatibility, including both dtypes are done below in
    TestCGemvFloat32 and TestCGemvFloat64
    """

    def setUp(self, dtype='float64'):
        # dtype: element type used for the symbolic variables and fixtures.
        self.dtype = dtype
        self.mode = theano.compile.get_default_mode().including('fast_run')
        # matrix
        self.A = tensor.tensor(dtype=dtype, broadcastable=(False, False))
        self.Aval = numpy.ones((2, 3), dtype=dtype)

        # vector
        self.x = tensor.tensor(dtype=dtype, broadcastable=(False,))
        self.y = tensor.tensor(dtype=dtype, broadcastable=(False,))
        self.xval = numpy.asarray([1, 2], dtype=dtype)
        self.yval = numpy.asarray([1.5, 2.7, 3.9], dtype=dtype)

        # scalar
        self.a = tensor.tensor(dtype=dtype, broadcastable=())

    def test_optimizations_vm(self):
        ''' Test vector dot matrix '''
        f = theano.function([self.x, self.A],
                theano.dot(self.x, self.A),
                mode=self.mode)

        # Assert that the dot was optimized somehow
        self.assertFunctionContains0(f, tensor.dot)
        self.assertFunctionContains1(f, CGemv(True))

        # Assert they produce the same output
        assert numpy.allclose(f(self.xval, self.Aval),
                numpy.dot(self.xval, self.Aval))

    def test_optimizations_mv(self):
        ''' Test matrix dot vector '''
        f = theano.function([self.A, self.y],
                theano.dot(self.A, self.y),
                mode=self.mode)

        # Assert that the dot was optimized somehow
        self.assertFunctionContains0(f, tensor.dot)
        self.assertFunctionContains1(f, CGemv(True))

        # Assert they produce the same output
        assert numpy.allclose(f(self.Aval, self.yval),
                numpy.dot(self.Aval, self.yval))

    def t_gemv1(self, m_shp):
        ''' test vector2 + dot(matrix, vector1) '''
        rng = numpy.random.RandomState(unittest_tools.fetch_seed())
        v1 = theano.shared(numpy.array(rng.uniform(size=(m_shp[1],)),
                dtype='float32'))
        v2_orig = numpy.array(rng.uniform(size=(m_shp[0],)), dtype='float32')
        v2 = theano.shared(v2_orig)
        m = theano.shared(numpy.array(rng.uniform(size=m_shp),
                dtype='float32'))

        f = theano.function([], v2 + tensor.dot(m, v1), mode=self.mode)

        # Assert they produce the same output
        assert numpy.allclose(f(),
                numpy.dot(m.get_value(), v1.get_value()) + v2_orig)
        topo = [n.op for n in f.maker.env.toposort()]
        # v2 is an input here, so the non-inplace CGemv must be used.
        assert topo == [CGemv(inplace=False)], topo

        #test the inplace version
        f = theano.function([], [],
                updates={v2: v2 + theano.dot(m, v1)},
                mode=self.mode)

        # Assert they produce the same output
        f()
        assert numpy.allclose(v2.get_value(),
                numpy.dot(m.get_value(), v1.get_value()) + v2_orig)
        topo = [n.op for n in f.maker.env.toposort()]
        # The update form allows destroying v2, so the inplace CGemv applies.
        assert topo == [CGemv(inplace=True)]

    def test_gemv1(self):
        # Exercise corner-case shapes, including empty dimensions.
        self.t_gemv1((3, 2))
        self.t_gemv1((0, 2))
        self.t_gemv1((3, 0))
        self.t_gemv1((0, 0))

    def test_gemv_dimensions(self, dtype='float32'):
        # Mismatched vector lengths must raise ValueError at runtime.
        alpha = theano.shared(theano._asarray(1.0, dtype=dtype),
                name='alpha')
        beta = theano.shared(theano._asarray(1.0, dtype=dtype),
                name='beta')

        z = beta * self.y + alpha * tensor.dot(self.A, self.x)
        f = theano.function([self.A, self.x, self.y], z,
                mode=self.mode)

        # Matrix value
        A_val = numpy.ones((5, 3), dtype=dtype)
        # Different vector length
        ones_3 = numpy.ones(3, dtype=dtype)
        ones_4 = numpy.ones(4, dtype=dtype)
        ones_5 = numpy.ones(5, dtype=dtype)
        ones_6 = numpy.ones(6, dtype=dtype)

        f(A_val, ones_3, ones_5)
        self.assertRaises(ValueError, f, A_val, ones_4, ones_5)
        self.assertRaises(ValueError, f, A_val, ones_3, ones_6)
        self.assertRaises(ValueError, f, A_val, ones_4, ones_6)
class TestCGemvFloat32(TestCase, BaseGemv, TestOptimizationMixin):
    # Run the generic BaseGemv test battery against CGemv in float32.
    mode = mode_blas_opt
    dtype = 'float32'
    gemv = CGemv(inplace=False)
    gemv_inplace = CGemv(inplace=True)
class TestCGemvFloat64(TestCase, BaseGemv, TestOptimizationMixin):
    # Run the generic BaseGemv test battery against CGemv in float64.
    mode = mode_blas_opt
    dtype = 'float64'
    gemv = CGemv(inplace=False)
    gemv_inplace = CGemv(inplace=True)
theano/tensor/tests/test_blas_scipy.py
浏览文件 @
f2743791
...
@@ -4,12 +4,15 @@ import theano
...
@@ -4,12 +4,15 @@ import theano
import
theano.tensor
as
tensor
import
theano.tensor
as
tensor
from
theano.tensor.blas_scipy
import
ScipyGer
from
theano.tensor.blas_scipy
import
ScipyGer
from
test_blas
import
TestCase
,
TestOptimizationMixin
,
gemm_no_inplace
from
test_blas
import
TestCase
,
gemm_no_inplace
from
theano.tests.unittest_tools
import
TestOptimizationMixin
class
TestScipyGer
(
TestCase
,
TestOptimizationMixin
):
class
TestScipyGer
(
TestCase
,
TestOptimizationMixin
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
mode
=
theano
.
compile
.
get_default_mode
()
.
including
(
'fast_run'
)
self
.
mode
=
theano
.
compile
.
get_default_mode
()
self
.
mode
=
self
.
mode
.
including
(
'fast_run'
)
self
.
mode
=
self
.
mode
.
excluding
(
'c_blas'
)
# c_blas trumps scipy Ops
dtype
=
self
.
dtype
=
'float64'
# optimization isn't dtype-dependent
dtype
=
self
.
dtype
=
'float64'
# optimization isn't dtype-dependent
self
.
A
=
tensor
.
tensor
(
dtype
=
dtype
,
broadcastable
=
(
False
,
False
))
self
.
A
=
tensor
.
tensor
(
dtype
=
dtype
,
broadcastable
=
(
False
,
False
))
self
.
a
=
tensor
.
tensor
(
dtype
=
dtype
,
broadcastable
=
())
self
.
a
=
tensor
.
tensor
(
dtype
=
dtype
,
broadcastable
=
())
...
@@ -18,9 +21,10 @@ class TestScipyGer(TestCase, TestOptimizationMixin):
...
@@ -18,9 +21,10 @@ class TestScipyGer(TestCase, TestOptimizationMixin):
self
.
Aval
=
numpy
.
ones
((
2
,
3
),
dtype
=
dtype
)
self
.
Aval
=
numpy
.
ones
((
2
,
3
),
dtype
=
dtype
)
self
.
xval
=
numpy
.
asarray
([
1
,
2
],
dtype
=
dtype
)
self
.
xval
=
numpy
.
asarray
([
1
,
2
],
dtype
=
dtype
)
self
.
yval
=
numpy
.
asarray
([
1.5
,
2.7
,
3.9
],
dtype
=
dtype
)
self
.
yval
=
numpy
.
asarray
([
1.5
,
2.7
,
3.9
],
dtype
=
dtype
)
if
not
theano
.
tensor
.
blas_scipy
.
optimizations_enabled
:
if
not
theano
.
tensor
.
blas_scipy
.
have_fblas
:
self
.
SkipTest
()
self
.
SkipTest
()
def
function
(
self
,
inputs
,
outputs
):
def
function
(
self
,
inputs
,
outputs
):
return
theano
.
function
(
inputs
,
outputs
,
self
.
mode
)
return
theano
.
function
(
inputs
,
outputs
,
self
.
mode
)
...
...
theano/tests/unittest_tools.py
浏览文件 @
f2743791
from
copy
import
copy
,
deepcopy
import
sys
import
sys
import
numpy
import
numpy
import
theano.tensor
as
T
import
theano.tensor
as
T
from
theano.configparser
import
config
,
AddConfigVar
,
StrParam
from
theano.configparser
import
config
,
AddConfigVar
,
StrParam
try:
    from nose.plugins.skip import SkipTest
except ImportError:
    # Fallback when nose is not installed: keep the name importable so
    # code raising SkipTest still runs.
    # NOTE(review): without nose's plugin, a raised SkipTest is presumably
    # reported as an error rather than a skip — confirm with the runner used.
    class SkipTest(Exception):
        """
        Skip this test
        """
AddConfigVar
(
'unittests.rseed'
,
AddConfigVar
(
'unittests.rseed'
,
"Seed to use for randomized unit tests. Special value 'random' means using a seed of None."
,
"Seed to use for randomized unit tests. Special value 'random' means using a seed of None."
,
StrParam
(
666
),
StrParam
(
666
),
in_c_key
=
False
)
in_c_key
=
False
)
def
fetch_seed
(
pseed
=
None
):
def
fetch_seed
(
pseed
=
None
):
"""
"""
Returns the seed to use for running the unit tests.
Returns the seed to use for running the unit tests.
...
@@ -38,6 +48,7 @@ def fetch_seed(pseed=None):
...
@@ -38,6 +48,7 @@ def fetch_seed(pseed=None):
return
seed
return
seed
def
seed_rng
(
pseed
=
None
):
def
seed_rng
(
pseed
=
None
):
"""
"""
Seeds numpy's random number generator with the value returned by fetch_seed.
Seeds numpy's random number generator with the value returned by fetch_seed.
...
@@ -51,6 +62,7 @@ def seed_rng(pseed=None):
...
@@ -51,6 +62,7 @@ def seed_rng(pseed=None):
numpy
.
random
.
seed
(
seed
)
numpy
.
random
.
seed
(
seed
)
return
seed
return
seed
def
verify_grad
(
op
,
pt
,
n_tests
=
2
,
rng
=
None
,
*
args
,
**
kwargs
):
def
verify_grad
(
op
,
pt
,
n_tests
=
2
,
rng
=
None
,
*
args
,
**
kwargs
):
"""
"""
Wrapper for tensor/basic.py:verify_grad
Wrapper for tensor/basic.py:verify_grad
...
@@ -72,3 +84,67 @@ def verify_grad(op, pt, n_tests=2, rng=None, *args, **kwargs):
...
@@ -72,3 +84,67 @@ def verify_grad(op, pt, n_tests=2, rng=None, *args, **kwargs):
# raise
# raise
#
#
verify_grad
.
E_grad
=
T
.
verify_grad
.
E_grad
verify_grad
.
E_grad
=
T
.
verify_grad
.
E_grad
class TestOptimizationMixin(object):
    """Assertion helpers for checking which Ops appear in a compiled function."""

    def assertFunctionContains(self, f, op, min=1, max=sys.maxint):
        """Assert `op` occurs between `min` and `max` times in `f`'s graph."""
        nodes = f.maker.env.toposort()
        hits = [node for node in nodes if node.op == op]
        assert min <= len(hits) <= max, (nodes, hits, str(op), min, max)

    def assertFunctionContains0(self, f, op):
        """Assert `op` does not occur in `f`'s graph."""
        return self.assertFunctionContains(f, op, min=0, max=0)

    def assertFunctionContains1(self, f, op):
        """Assert `op` occurs exactly once in `f`'s graph."""
        return self.assertFunctionContains(f, op, min=1, max=1)

    def assertFunctionContainsN(self, f, op, N):
        """Assert `op` occurs exactly `N` times in `f`'s graph."""
        return self.assertFunctionContains(f, op, min=N, max=N)

    def SkipTest(self, msg='Skip this test'):
        """Raise the module-level SkipTest exception with `msg`."""
        raise SkipTest(msg)
# This object name should not start with Test.
# Otherwise nosetests will execute it!
class T_OpContractMixin(object):
    """Generic tests of the Op contract: __eq__, __hash__ and __str__.

    Subclasses must set `self.ops` to a list of pairwise-distinct Op
    instances and override `clone` to build a fresh instance equal to a
    given op.
    """
    # self.ops should be a list of instantiations of an Op class to test.
    # self.other_op should be an op which is different from every op
    other_op = T.add

    def copy(self, x):
        """Shallow-copy `x`; equality must survive copying."""
        return copy(x)

    def deepcopy(self, x):
        """Deep-copy `x`; equality must survive deep copying."""
        return deepcopy(x)

    def clone(self, op):
        """Return a newly constructed instance equal to `op` (override)."""
        raise NotImplementedError('return new instance like `op`')

    def test_eq(self):
        for i, op_i in enumerate(self.ops):
            assert op_i == op_i
            assert op_i == self.copy(op_i)
            assert op_i == self.deepcopy(op_i)
            assert op_i == self.clone(op_i)
            assert op_i != self.other_op
            for j, op_j in enumerate(self.ops):
                if i == j:
                    continue
                # ops in self.ops are required to be pairwise distinct
                assert op_i != op_j

    def test_hash(self):
        for i, op_i in enumerate(self.ops):
            h_i = hash(op_i)
            assert h_i == hash(op_i)
            assert h_i == hash(self.copy(op_i))
            assert h_i == hash(self.deepcopy(op_i))
            assert h_i == hash(self.clone(op_i))
            assert h_i != hash(self.other_op)
            for j, op_j in enumerate(self.ops):
                if i == j:
                    continue
                # Bug fix: was `assert op_i != hash(op_j)`, which compared
                # an Op instance to an int and was therefore always true
                # (a vacuous assertion). Compare the hash values instead.
                assert h_i != hash(op_j)

    def test_name(self):
        for op in self.ops:
            s = str(op)  # show that str works
            assert s  # names should not be empty
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论