Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
f462a1c0
提交
f462a1c0
authored
3月 20, 2013
作者:
lamblin
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1291 from gdesjardins/erfinv_to_erfinvgpu_rebased
make erfinv work on the gpu.
上级
4ecb1ee6
798623ff
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
67 行增加
和
10 行删除
+67
-10
elemwise.py
theano/sandbox/cuda/elemwise.py
+24
-0
opt.py
theano/sandbox/cuda/opt.py
+21
-10
test_opt.py
theano/sandbox/cuda/tests/test_opt.py
+13
-0
basic_scipy.py
theano/scalar/basic_scipy.py
+9
-0
没有找到文件。
theano/sandbox/cuda/elemwise.py
浏览文件 @
f462a1c0
...
...
@@ -9,6 +9,8 @@ import copy, logging, StringIO, sys
import
numpy
from
theano.scalar.basic
import
upgrade_to_float_no_complex
,
complex_types
from
theano.scalar.basic_scipy
import
Erfinv
from
theano
import
Apply
,
Constant
,
Op
,
Type
,
Variable
from
theano
import
gof
,
scalar
,
tensor
...
...
@@ -1021,3 +1023,25 @@ nd_collapse_[i]=0;
#print sio.getvalue()
return
sio
.
getvalue
()
class
ErfinvGPU
(
Erfinv
):
"""
Provides a c-code implementation of the inverse error function for GPU.
Note: We do not add this c_code to theano.scalar.basic_scipy.Erfinv, as we
currently rely on Nvidia's cublas library to provide the erfinv
c-implementation (which requires different c_headers). As it stands,
theano.scalar.basic_scipy.Erfinv does not have c_code as scipy does not
export the required C function
"""
def
c_headers
(
self
):
return
[
'math_functions.h'
,
'cublas_v2.h'
]
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
x
,
=
inp
z
,
=
out
if
node
.
inputs
[
0
]
.
type
in
complex_types
:
raise
NotImplementedError
(
'type not supported'
,
type
)
return
"
%(z)
s = erfinv(
%(x)
s);"
%
locals
()
erfinv_gpu
=
ErfinvGPU
(
upgrade_to_float_no_complex
,
name
=
'erfinv_gpu'
)
theano/sandbox/cuda/opt.py
浏览文件 @
f462a1c0
...
...
@@ -33,6 +33,8 @@ from theano.sandbox.cuda.nnet import (
GpuCrossentropySoftmax1HotWithBiasDx
,
GpuSoftmax
,
GpuSoftmaxWithBias
)
from
theano.sandbox.cuda.elemwise
import
SupportCodeError
from
theano.scalar.basic_scipy
import
Erfinv
from
theano.sandbox.cuda.elemwise
import
ErfinvGPU
,
erfinv_gpu
from
theano.sandbox.cuda.var
import
CudaNdarrayConstant
from
theano.scan_module
import
scan_utils
,
scan_op
from
theano.tensor.blas
import
_is_real_vector
,
_is_real_matrix
...
...
@@ -177,11 +179,15 @@ def local_gpu_elemwise_0(node):
if
numpy
.
all
([
o
.
type
.
dtype
==
'float32'
for
o
in
node
.
outputs
]):
# Don't set any inplace pattern.
# gpu_inplace_elemwise_optimizer will do it later
try
:
new_op
=
GpuElemwise
(
node
.
op
.
scalar_op
)
except
SupportCodeError
:
# This happens when scalar_op requires support code
return
False
if
isinstance
(
node
.
op
.
scalar_op
,
Erfinv
):
new_op
=
GpuElemwise
(
erfinv_gpu
)
else
:
try
:
new_op
=
GpuElemwise
(
node
.
op
.
scalar_op
)
except
SupportCodeError
:
# This happens when scalar_op requires support code
return
False
# first establish that float32 can store all inputs
upcastable
=
set
([
'float32'
,
'int8'
,
'int16'
,
'uint8'
,
...
...
@@ -234,11 +240,16 @@ def local_gpu_elemwise_1(node):
elemwise_node
=
host_i
.
owner
# Don't set any inplace pattern.
# gpu_inplace_elemwise_optimizer will do it later
try
:
new_op
=
GpuElemwise
(
elemwise_node
.
op
.
scalar_op
)
except
SupportCodeError
:
# This happens when scalar_op requires support code
return
False
if
isinstance
(
node
.
op
.
scalar_op
,
Erfinv
):
new_op
=
GpuElemwise
(
erfinv_gpu
)
else
:
try
:
new_op
=
GpuElemwise
(
elemwise_node
.
op
.
scalar_op
)
except
SupportCodeError
:
# This happens when scalar_op requires support code
return
False
if
all
([
i
.
dtype
==
'float32'
for
i
in
elemwise_node
.
inputs
]):
gpu_elemwise
=
new_op
(
*
[
gpu_from_host
(
i
)
for
i
in
elemwise_node
.
inputs
])
...
...
theano/sandbox/cuda/tests/test_opt.py
浏览文件 @
f462a1c0
...
...
@@ -17,6 +17,7 @@ if cuda.cuda_available == False:
from
theano.sandbox.cuda
import
basic_ops
from
theano.sandbox.cuda.type
import
CudaNdarrayType
from
theano.scalar.basic_scipy
import
erfinv
if
theano
.
config
.
mode
==
'FAST_COMPILE'
:
mode_with_gpu
=
theano
.
compile
.
mode
.
get_mode
(
'FAST_RUN'
)
.
including
(
'gpu'
)
...
...
@@ -368,6 +369,18 @@ def test_incsubtensor_mixed():
client
,
idx
=
packed
assert
isinstance
(
client
.
op
,
cuda
.
GpuFromHost
)
def
test_erfinvgpu
():
""" Test that local_gpu_elemwise_0 replaces Erfinv with ErfinvGPU """
x
=
tensor
.
fmatrix
()
f
=
theano
.
function
([
x
],
tensor
.
Elemwise
(
erfinv
)(
x
),
mode
=
mode_with_gpu
)
f2
=
theano
.
function
([
x
],
tensor
.
Elemwise
(
erfinv
)(
x
),
mode
=
mode_without_gpu
)
assert
isinstance
(
f
.
maker
.
fgraph
.
toposort
()[
1
]
.
op
,
cuda
.
GpuElemwise
)
assert
isinstance
(
f
.
maker
.
fgraph
.
toposort
()[
1
]
.
op
.
scalar_op
,
cuda
.
elemwise
.
ErfinvGPU
)
xv
=
numpy
.
random
.
rand
(
7
,
8
)
.
astype
(
'float32'
)
assert
numpy
.
allclose
(
f
(
xv
),
f2
(
xv
))
if
__name__
==
'__main__'
:
test_gpualloc
()
test_opt_gpujoin_onlyajoin
()
...
...
theano/scalar/basic_scipy.py
浏览文件 @
f462a1c0
...
...
@@ -78,6 +78,15 @@ erfc = Erfc(upgrade_to_float_no_complex, name='erfc')
class
Erfinv
(
UnaryScalarOp
):
"""
Implements the inverse error function.
Note: This op can still be executed on GPU, despite not having c_code. When
running on GPU, sandbox.cuda.opt.local_gpu_elemwise_[0,1] replaces this op
with sandbox.cuda.elemwise.ErfinvGPU.
(TODO) Find a C implementation of erfinv for CPU.
"""
def
impl
(
self
,
x
):
if
imported_scipy_special
:
return
scipy
.
special
.
erfinv
(
x
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论