Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
24858525
提交
24858525
authored
8月 07, 2009
作者:
James Bergstra
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
adding cross-entropy but not done yet
上级
a317f101
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
298 行增加
和
29 行删除
+298
-29
basic_ops.py
basic_ops.py
+36
-24
blas.py
blas.py
+162
-0
opt.py
opt.py
+24
-1
test_nnet.py
tests/test_nnet.py
+76
-4
没有找到文件。
basic_ops.py
浏览文件 @
24858525
...
...
@@ -138,16 +138,24 @@ class GpuElemwise(Op):
def
_logical_scalar
(
x
):
return
all
(
x
.
type
.
broadcastable
)
print
>>
sio
,
"static __global__ void kernel_
%
s(unsigned int numEls,"
%
nodename
print
>>
sio
,
"
\t
"
,
", "
.
join
(
"unsigned int log2_dim
%
i"
%
i
for
i
in
xrange
(
nd
))
print
>>
sio
,
"// Elemwise kernel for "
,
str
(
self
.
scalar_op
)
for
ipos
,
i
in
enumerate
(
node
.
inputs
):
print
>>
sio
,
"// Input "
,
ipos
,
str
(
i
.
type
)
for
ipos
,
i
in
enumerate
(
node
.
outputs
):
print
>>
sio
,
"// Output "
,
ipos
,
str
(
i
.
type
)
print
>>
sio
,
"static __global__ void kernel_
%
s(unsigned int numEls"
%
nodename
if
(
nd
):
print
>>
sio
,
"
\t
,"
,
", "
.
join
(
"unsigned int log2_dim
%
i"
%
i
for
i
in
xrange
(
nd
))
#declare inputs
for
ipos
,
i
in
enumerate
(
node
.
inputs
):
print
>>
sio
,
"
\t
,"
,
", "
.
join
(
"int i
%
i_str_
%
i"
%
(
ipos
,
d
)
for
d
in
xrange
(
nd
))
print
>>
sio
,
"
\t
,"
,
"const float * i
%
i_data"
%
ipo
s
s
=
", "
.
join
([
"const float * i
%
i_data"
%
ipos
]
+
list
(
"int i
%
i_str_
%
i"
%
(
ipos
,
d
)
for
d
in
xrange
(
nd
)
))
print
>>
sio
,
"
\t
,"
,
s
#declare outputs
for
ipos
,
i
in
enumerate
(
node
.
outputs
):
print
>>
sio
,
"
\t
,"
,
", "
.
join
(
"int o
%
i_str_
%
i"
%
(
ipos
,
d
)
for
d
in
xrange
(
nd
))
print
>>
sio
,
"
\t
,"
,
"float * o
%
i_data"
%
ipos
s
=
", "
.
join
([
"float * o
%
i_data"
%
ipos
]
+
list
(
"int o
%
i_str_
%
i"
%
(
ipos
,
d
)
for
d
in
xrange
(
nd
)))
print
>>
sio
,
"
\t
,"
,
s
#print >> sio, "\t,", ", ".join("int o%i_str_%i" % (ipos, d) for d in xrange(nd))
#print >> sio, "\t,", "float * o%i_data" % ipos
print
>>
sio
,
"
\t
)
\n
{"
print
>>
sio
,
" const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;"
print
>>
sio
,
" const unsigned int numThreads = blockDim.x * gridDim.x;"
...
...
@@ -183,17 +191,16 @@ class GpuElemwise(Op):
print
>>
sio
,
" ii_o
%
i_data += pos
%
i * o
%
i_str_
%
i;"
%
(
ipos
,
d
,
ipos
,
d
)
# perform the scalar operation on the input and output references
if
d
==
0
:
#TODO: What if the scalar_op needs support_code??
task_code
=
self
.
scalar_op
.
c_code
(
Apply
(
self
.
scalar_op
,
[
scalar
.
Scalar
(
dtype
=
input
.
type
.
dtype
)()
for
input
in
node
.
inputs
],
[
scalar
.
Scalar
(
dtype
=
output
.
type
.
dtype
)()
for
output
in
node
.
outputs
])
,
nodename
+
'_scalar_'
,
[(
'ii_i
%
i_value'
if
_logical_scalar
(
i
)
else
'ii_i
%
i_data[0]'
)
%
ipos
for
ipos
,
i
in
enumerate
(
node
.
inputs
)]
,
[
'ii_o
%
i_data[0]'
%
ipos
for
ipos
,
i
in
enumerate
(
node
.
outputs
)]
,
sub
=
dict
(
fail
=
'return;'
))
#TODO: set a failure code somehow!!!
print
>>
sio
,
" "
,
task_code
#TODO: What if the scalar_op needs support_code??
task_code
=
self
.
scalar_op
.
c_code
(
Apply
(
self
.
scalar_op
,
[
scalar
.
Scalar
(
dtype
=
input
.
type
.
dtype
)()
for
input
in
node
.
inputs
],
[
scalar
.
Scalar
(
dtype
=
output
.
type
.
dtype
)()
for
output
in
node
.
outputs
])
,
nodename
+
'_scalar_'
,
[(
'ii_i
%
i_value'
if
_logical_scalar
(
i
)
else
'ii_i
%
i_data[0]'
)
%
ipos
for
ipos
,
i
in
enumerate
(
node
.
inputs
)]
,
[
'ii_o
%
i_data[0]'
%
ipos
for
ipos
,
i
in
enumerate
(
node
.
outputs
)]
,
sub
=
dict
(
fail
=
'return;'
))
#TODO: set a failure code somehow!!!
print
>>
sio
,
" "
,
task_code
print
>>
sio
,
" }"
#TODO: insert runtime stride checks that select the best loop order either here, or in
...
...
@@ -230,11 +237,17 @@ class GpuElemwise(Op):
kernel_call_args
=
[
"numEls"
]
kernel_call_args
.
extend
(
"log2_dims[
%
i]"
%
di
for
di
in
xrange
(
nd
))
for
ipos
in
xrange
(
len
(
node
.
inputs
)):
strides
=
", "
.
join
(
"i
%
i_str[
%
i]"
%
(
ipos
,
di
)
for
di
in
xrange
(
nd
))
kernel_call_args
.
append
(
"
%
s, i
%
i_data"
%
(
strides
,
ipos
))
kernel_call_args
.
append
(
", "
.
join
([
"i
%
i_data"
%
ipos
]
+
list
(
"i
%
i_str[
%
i]"
%
(
ipos
,
di
)
for
di
in
xrange
(
nd
)))
)
#strides = ", ".join("i%i_str[%i]"%(ipos, di) for di in xrange(nd))
#kernel_call_args.append( "%s, i%i_data" % (strides, ipos))
for
ipos
in
xrange
(
len
(
node
.
outputs
)):
strides
=
", "
.
join
(
"o
%
i_str[
%
i]"
%
(
ipos
,
di
)
for
di
in
xrange
(
nd
))
kernel_call_args
.
append
(
"
%
s, o
%
i_data"
%
(
strides
,
ipos
))
kernel_call_args
.
append
(
", "
.
join
([
"o
%
i_data"
%
ipos
]
+
list
(
"o
%
i_str[
%
i]"
%
(
ipos
,
di
)
for
di
in
xrange
(
nd
)))
)
#strides = ", ".join("o%i_str[%i]"%(ipos, di) for di in xrange(nd))
#kernel_call_args.append( "%s, o%i_data" % (strides, ipos))
kernel_call_args
=
", "
.
join
(
kernel_call_args
)
# the data_pointer_increments are inserted after each recursive call
...
...
@@ -388,7 +401,7 @@ class GpuElemwise(Op):
);
//std::cerr << "calling callkernel returned
\\
n";
cudaThreadSynchronize()
;
CNDA_THREAD_SYNC
;
cudaError_t err = cudaGetLastError();
if( cudaSuccess != err)
{
...
...
@@ -568,7 +581,7 @@ class GpuDimShuffle(Op):
}
"""
%
locals
()
if
1
:
if
0
:
# print full code to stdout
print
'--------------------------------------'
print
'C_CODE'
print
''
...
...
@@ -724,7 +737,6 @@ class GpuSubtensor(tensor.Subtensor):
#sys.stdout.flush()
#sys.exit()
class
GpuShape
(
tensor
.
Shape
):
def
make_node
(
self
,
x
):
return
Apply
(
self
,
[
x
],
[
tensor
.
lvector
()])
...
...
blas.py
浏览文件 @
24858525
...
...
@@ -3,6 +3,7 @@ from theano import tensor, scalar
import
StringIO
import
cuda_ndarray
from
.type
import
CudaNdarrayType
class
GpuDot22
(
Op
):
def
__str__
(
self
):
...
...
@@ -184,3 +185,164 @@ class GpuConv(Op):
logical_kern_shape
=
self
.
logical_kern_hw
,
kern_align
=
self
.
logical_kern_align_top
)
class
GpuCrossentropySoftmaxArgmax1HotWithBias
(
Op
):
nin
=
3
nout
=
3
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
,
b
,
y_idx
):
nll
=
y_idx
.
type
()
#N.B. won't work when we don't cast y_idx to float anymore
sm
=
x
.
type
()
am
=
y_idx
.
type
()
return
Apply
(
self
,
[
x
,
b
,
y_idx
],
[
nll
,
sm
,
am
])
def
c_support_code
(
self
):
return
"""
__global__ void k_xent_sm_1hot_bias(int M, int N,
const float * x_data, int xs0, int xs1,
const float * b, int bs0,
const float * y_idx_data, int y_idxs0,
float * nll_data, int nlls0,
float * sm_data, int sms0, int sms1,
float * am_data, int ams0)
{
const int row = blockIdx.x;
const float * x = x_data + xs0 * row;
const int y_idx = (int)y_idx_data[row * y_idxs0];
float * sm = sm_data + sms0 * row;
float sum = 0.0;
int row_max_j = 0;
float row_max = x[0] + b[0];
for (int j = 1; j < N; ++j)
{
float row_ij = x[j*xs1] + b[j*bs0];
//todo: store to shared memory
row_max_j = (row_ij > row_max) ? j : row_max_j;
row_max = (row_ij > row_max) ? row_ij : row_max;
}
//compute the exp
for (int j = 0; j < N; ++j)
{
float row_ij = x[j*xs1] + b[j*bs0];
float sm_ij = exp(row_ij - row_max);
sum += sm_ij;
sm[j * sms1] = sm_ij;
}
float sum_inv = 1.0 / sum;
for (int j = 0; j < N; ++j)
{
sm[j * sms1] *= sum_inv;
}
if ((y_idx >= N) || (y_idx < 0))
{
//TODO: set raise an error bit in a global var?
nll_data[row*nlls0] = 0.0; // raise some suspicion at least...
}
else
{
nll_data[row*nlls0] = - x[y_idx*xs1]
- b[y_idx*bs0]
+ row_max
+ log(sum);
}
am_data[row*ams0] = row_max_j;
}
"""
def
c_code
(
self
,
node
,
nodename
,
(
x
,
b
,
y_idx
),
(
nll
,
sm
,
am
),
sub
):
classname
=
self
.
__class__
.
__name__
fail
=
sub
[
'fail'
]
sio
=
StringIO
.
StringIO
()
print
>>
sio
,
"""
if (cnda_
%(y_idx)
s->nd != 1)
{
PyErr_SetString(PyExc_ValueError, "y_idx not 1d tensor");
%(fail)
s;
}
if (cnda_
%(x)
s->nd != 2)
{
PyErr_SetString(PyExc_ValueError, "x not 2d tensor");
%(fail)
s;
}
if (cnda_
%(b)
s->nd != 1)
{
PyErr_SetString(PyExc_ValueError, "b not 1d tensor");
%(fail)
s;
}
if (CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[0] != CudaNdarray_HOST_DIMS(cnda_
%(y_idx)
s)[0])
{
PyErr_SetString(PyExc_ValueError, "dimension mismatch in x,y_idx arguments");
%(fail)
s;
}
if (CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[1] != CudaNdarray_HOST_DIMS(cnda_
%(b)
s)[0])
{
PyErr_SetString(PyExc_ValueError, "dimension mismatch in x,b arguments");
%(fail)
s;
}
if ((NULL == cnda_
%(nll)
s) //initial condition
|| (CudaNdarray_HOST_DIMS(cnda_
%(nll)
s)[0] != CudaNdarray_HOST_DIMS(cnda_
%(y_idx)
s)[0]))
{
Py_XDECREF(cnda_
%(nll)
s);
cnda_
%(nll)
s = (CudaNdarray*)CudaNdarray_NewDims(1, CudaNdarray_HOST_DIMS(cnda_
%(y_idx)
s));
if(!cnda_
%(nll)
s)
{
%(fail)
s;
}
}
if ((NULL == cnda_
%(sm)
s)
|| (CudaNdarray_HOST_DIMS(cnda_
%(sm)
s)[0] != CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[0])
|| (CudaNdarray_HOST_DIMS(cnda_
%(sm)
s)[1] != CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[1]))
{
Py_XDECREF(cnda_
%(sm)
s);
cnda_
%(sm)
s = (CudaNdarray*) CudaNdarray_NewDims(2, CudaNdarray_HOST_DIMS(cnda_
%(x)
s));
if(!cnda_
%(sm)
s)
{
PyErr_SetString(PyExc_MemoryError, "failed to alloc sm output");
// no need to decref cnda_nll, the cleanup code should pick it up.
%(fail)
s;
}
}
if ((NULL == cnda_
%(am)
s)
|| (CudaNdarray_HOST_DIMS(cnda_
%(am)
s)[0] != CudaNdarray_HOST_DIMS(cnda_
%(y_idx)
s)[0]))
{
Py_XDECREF(cnda_
%(am)
s);
cnda_
%(am)
s = (CudaNdarray*) CudaNdarray_NewDims(1, CudaNdarray_HOST_DIMS(cnda_
%(y_idx)
s));
if(!cnda_
%(am)
s)
{
PyErr_SetString(PyExc_MemoryError, "failed to alloc am output");
// no need to decref nll amd sm, the cleanup code should pick it up.
%(fail)
s;
}
}
{
int n_blocks = CudaNdarray_HOST_DIMS(cnda_
%(sm)
s)[0];
int n_threads = 1; //TODO: launch more threads per row and do parallel sum and max reductions.
int n_shared_bytes = 0; //n_threads * sizeof(float);
k_xent_sm_1hot_bias<<<n_blocks, n_threads, n_shared_bytes>>>(
CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[0],
CudaNdarray_HOST_DIMS(cnda_
%(x)
s)[1],
CudaNdarray_DEV_DATA(cnda_
%(x)
s), CudaNdarray_HOST_STRIDES(cnda_
%(x)
s)[0], CudaNdarray_HOST_STRIDES(cnda_
%(x)
s)[1],
CudaNdarray_DEV_DATA(cnda_
%(b)
s), CudaNdarray_HOST_STRIDES(cnda_
%(b)
s)[0],
CudaNdarray_DEV_DATA(cnda_
%(y_idx)
s), CudaNdarray_HOST_STRIDES(cnda_
%(y_idx)
s)[0],
CudaNdarray_DEV_DATA(cnda_
%(nll)
s), CudaNdarray_HOST_STRIDES(cnda_
%(nll)
s)[0],
CudaNdarray_DEV_DATA(cnda_
%(sm)
s), CudaNdarray_HOST_STRIDES(cnda_
%(sm)
s)[0], CudaNdarray_HOST_STRIDES(cnda_
%(sm)
s)[1],
CudaNdarray_DEV_DATA(cnda_
%(am)
s), CudaNdarray_HOST_STRIDES(cnda_
%(am)
s)[0]);
CNDA_THREAD_SYNC;
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err)
{
PyErr_Format(PyExc_RuntimeError, "Cuda error:
%(classname)
s
%(nodename)
s:
%%
s.
\\
n", cudaGetErrorString(err));
// no need to decref output vars the cleanup code should pick them up.
%(fail)
s;
}
}
"""
%
locals
()
return
sio
.
getvalue
()
opt.py
浏览文件 @
24858525
...
...
@@ -3,7 +3,7 @@ from theano import tensor, scalar, compile
from
theano.gof
import
local_optimizer
,
EquilibriumDB
,
SequenceDB
from
.basic_ops
import
*
from
.blas
import
gpu_dot22
,
gpu_gemm
,
GpuConv
from
.blas
import
gpu_dot22
,
gpu_gemm
,
GpuConv
,
GpuCrossentropySoftmaxArgmax1HotWithBias
from
theano.compile
import
optdb
#optdb.print_summary() # this shows what is currently registered (in a so-far crude way...)
...
...
@@ -229,3 +229,26 @@ def local_gpu_shape(node):
return
[
gpu_shape
(
gpu_x
)]
return
False
def
cast
(
x
,
dtype
):
stype
=
theano
.
scalar
.
Scalar
(
dtype
)
cast_op
=
theano
.
tensor
.
Elemwise
(
scalar
.
Identity
(
scalar
.
specific_out
(
stype
)))
return
cast_op
(
x
)
import
theano.tensor.nnet
@register_opt
()
@local_optimizer
([])
def
local_gpu_crossentorpy_softmax_argmax_1hot_with_bias
(
node
):
if
isinstance
(
node
.
op
,
tensor
.
nnet
.
CrossentropySoftmaxArgmax1HotWithBias
):
x
,
b
,
y
=
node
.
inputs
if
x
.
owner
and
x
.
owner
.
op
==
host_from_gpu
:
gpu_x
,
=
x
.
owner
.
inputs
gpu_nll
,
gpu_sm
,
gpu_am
=
GpuCrossentropySoftmaxArgmax1HotWithBias
()(
gpu_x
,
gpu_from_host
(
b
),
gpu_from_host
(
cast
(
y
,
'float32'
)))
am_dtype
=
node
.
outputs
[
2
]
.
type
.
dtype
return
[
host_from_gpu
(
gpu_nll
),
host_from_gpu
(
gpu_sm
),
cast
(
host_from_gpu
(
gpu_am
),
am_dtype
)]
return
False
tests/test_nnet.py
浏览文件 @
24858525
...
...
@@ -3,6 +3,7 @@ import theano, theano.sandbox.conv
from
theano.compile.sandbox.sharedvalue
import
shared
from
theano.compile.sandbox.pfunc
import
pfunc
from
theano
import
tensor
import
theano.tensor.nnet
import
numpy
...
...
@@ -120,8 +121,9 @@ def test_conv_nnet1():
rval_gpu
=
run_conv_nnet1
(
tcn
.
shared_constructor
)
assert
numpy
.
allclose
(
rval_cpu
,
rval_gpu
,
rtol
=
1e-4
,
atol
=
1e-6
)
def
run_conv_nnet2
(
shared_fn
):
n_batch
=
16
def
run_conv_nnet2
(
shared_fn
):
# pretend we are training LeNet for MNIST
n_batch
=
60
shape_img
=
(
n_batch
,
1
,
32
,
32
)
n_kern
=
20
...
...
@@ -168,12 +170,15 @@ def run_conv_nnet2(shared_fn):
print
i
,
n
xval
=
numpy
.
asarray
(
numpy
.
random
.
rand
(
*
shape_img
),
dtype
=
'float32'
)
yval
=
numpy
.
asarray
(
numpy
.
random
.
rand
(
n_batch
,
n_out
),
dtype
=
'floa
t32'
)
yval
=
numpy
.
asarray
(
numpy
.
random
.
rand
(
n_batch
,
n_out
),
dtype
=
'in
t32'
)
lr
=
numpy
.
asarray
(
0.01
,
dtype
=
'float32'
)
for
i
in
xrange
(
10
):
rval
=
train
(
xval
,
yval
,
lr
)
mode
.
print_summary
()
try
:
mode
.
print_summary
()
except
:
pass
return
rval
def
test_conv_nnet2
():
...
...
@@ -183,3 +188,70 @@ def test_conv_nnet2():
rval_gpu
=
run_conv_nnet2
(
tcn
.
shared_constructor
)
assert
numpy
.
allclose
(
rval_cpu
,
rval_gpu
,
rtol
=
1e-4
,
atol
=
1e-6
)
def
run_conv_nnet2_classif
(
shared_fn
):
# pretend we are training LeNet for MNIST
n_batch
=
60
shape_img
=
(
n_batch
,
1
,
32
,
32
)
n_kern
=
20
shape_kern
=
(
n_kern
,
1
,
5
,
5
)
n_kern1
=
30
shape_kern1
=
(
n_kern1
,
n_kern
,
5
,
5
)
logical_hid_shape
=
tcn
.
blas
.
GpuConv
.
logical_output_shape_2d
((
32
,
32
),
(
5
,
5
),
'valid'
)
logical_hid_shape1
=
tcn
.
blas
.
GpuConv
.
logical_output_shape_2d
((
logical_hid_shape
[
0
]
/
2
,
logical_hid_shape
[
1
]
/
2
),
(
5
,
5
),
'valid'
)
n_hid
=
n_kern1
*
logical_hid_shape1
[
0
]
*
logical_hid_shape1
[
1
]
n_out
=
10
w0
=
shared_fn
(
numpy
.
asarray
(
0.01
*
(
numpy
.
random
.
rand
(
*
shape_kern
)
-
0.5
),
dtype
=
'float32'
),
'w0'
)
b0
=
shared_fn
(
numpy
.
asarray
(
numpy
.
zeros
((
n_kern
,
1
,
1
)),
dtype
=
'float32'
),
'b0'
)
w1
=
shared_fn
(
numpy
.
asarray
(
0.01
*
(
numpy
.
random
.
rand
(
*
shape_kern1
)
-
0.5
),
dtype
=
'float32'
),
'w1'
)
b1
=
shared_fn
(
numpy
.
asarray
(
numpy
.
zeros
((
n_kern1
,
1
,
1
)),
dtype
=
'float32'
),
'b1'
)
v
=
shared_fn
(
numpy
.
asarray
(
numpy
.
zeros
((
n_hid
,
n_out
)),
dtype
=
'float32'
),
'c'
)
c
=
shared_fn
(
numpy
.
asarray
(
numpy
.
zeros
(
n_out
),
dtype
=
'float32'
),
'c'
)
x
=
tensor
.
Tensor
(
dtype
=
'float32'
,
broadcastable
=
(
0
,
0
,
0
,
0
))(
'x'
)
y
=
tensor
.
fmatrix
(
'y'
)
lr
=
tensor
.
fscalar
(
'lr'
)
conv_op
=
theano
.
sandbox
.
conv
.
ConvOp
(
shape_img
[
2
:],
shape_kern
[
2
:],
n_kern
,
n_batch
,
1
,
1
)
conv_op1
=
theano
.
sandbox
.
conv
.
ConvOp
((
n_kern
,
logical_hid_shape
[
0
]
/
2
,
logical_hid_shape
[
1
]
/
2
),
shape_kern1
[
2
:],
n_kern1
,
n_batch
,
1
,
1
)
hid
=
tensor
.
tanh
(
conv_op
(
x
,
w0
)
+
b0
)
hid1
=
tensor
.
tanh
(
conv_op1
(
hid
[:,:,::
2
,::
2
],
w1
)
+
b1
)
hid_flat
=
hid1
.
reshape
((
n_batch
,
n_hid
))
out
=
tensor
.
tanh
(
tensor
.
dot
(
hid_flat
,
v
)
+
c
)
loss
=
tensor
.
sum
(
0.5
*
(
out
-
y
)
**
2
*
lr
)
print
'loss type'
,
loss
.
type
params
=
[
w0
,
b0
,
w1
,
b1
,
v
,
c
]
gparams
=
tensor
.
grad
(
loss
,
params
)
mode
=
theano
.
compile
.
ProfileMode
()
print
'building pfunc ...'
train
=
pfunc
([
x
,
y
,
lr
],
[
loss
],
mode
=
mode
,
updates
=
[(
p
,
p
-
g
)
for
p
,
g
in
zip
(
params
,
gparams
)])
for
i
,
n
in
enumerate
(
train
.
maker
.
env
.
toposort
()):
print
i
,
n
xval
=
numpy
.
asarray
(
numpy
.
random
.
rand
(
*
shape_img
),
dtype
=
'float32'
)
yval
=
numpy
.
asarray
(
numpy
.
random
.
rand
(
n_batch
,
n_out
),
dtype
=
'int32'
)
lr
=
numpy
.
asarray
(
0.01
,
dtype
=
'float32'
)
for
i
in
xrange
(
10
):
rval
=
train
(
xval
,
yval
,
lr
)
try
:
mode
.
print_summary
()
except
:
pass
return
rval
def
test_conv_nnet2_classif
():
numpy
.
random
.
seed
(
23456
)
rval_cpu
=
run_conv_nnet2
(
shared
)
numpy
.
random
.
seed
(
23456
)
rval_gpu
=
run_conv_nnet2
(
tcn
.
shared_constructor
)
assert
numpy
.
allclose
(
rval_cpu
,
rval_gpu
,
rtol
=
1e-4
,
atol
=
1e-6
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论