Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
3ca77162
提交
3ca77162
authored
9月 10, 2013
作者:
Pascal Lamblin
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1522 from nouiz/small_stuff
Small stuff
上级
999aeb92
b0350903
隐藏空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
111 行增加
和
72 行删除
+111
-72
do_nightly_build
theano/misc/do_nightly_build
+4
-3
check_whitespace.py
theano/misc/hooks/check_whitespace.py
+1
-1
multinomial.py
theano/sandbox/multinomial.py
+29
-15
test_multinomial.py
theano/sandbox/test_multinomial.py
+53
-31
opt.py
theano/tensor/opt.py
+4
-3
test_subtensor.py
theano/tensor/tests/test_subtensor.py
+1
-1
diverse_tests.py
theano/tests/diverse_tests.py
+19
-18
没有找到文件。
theano/misc/do_nightly_build
浏览文件 @
3ca77162
...
...
@@ -54,9 +54,10 @@ if [ "$RELEASE" ]; then
echo
fi
echo
"Executing tests with mode=FAST_COMPILE with --batch=1000"
echo
"THEANO_FLAGS=
${
FLAGS
}
,mode=FAST_COMPILE
${
NOSETESTS
}
--batch=1000
${
ARGS
}
"
THEANO_FLAGS
=
${
FLAGS
}
,mode
=
FAST_COMPILE
${
NOSETESTS
}
--batch
=
1000
${
ARGS
}
# with --batch=1000" # The buildbot freeze sometimes when collecting the tests to run
echo
"Executing tests with mode=FAST_COMPILE"
echo
"THEANO_FLAGS=
${
FLAGS
}
,mode=FAST_COMPILE
${
NOSETESTS
}
${
ARGS
}
"
THEANO_FLAGS
=
${
FLAGS
}
,mode
=
FAST_COMPILE
${
NOSETESTS
}
${
ARGS
}
echo
"Number of elements in the compiledir:"
ls
${
COMPILEDIR
}
|wc
-l
echo
...
...
theano/misc/hooks/check_whitespace.py
浏览文件 @
3ca77162
...
...
@@ -12,7 +12,7 @@ import tokenize
import
argparse
import
reindent
from
theano.compat.six
.StringIO
import
StringIO
from
theano.compat.six
import
StringIO
SKIP_WHITESPACE_CHECK_FILENAME
=
".hg/skip_whitespace_check"
...
...
theano/sandbox/multinomial.py
浏览文件 @
3ca77162
...
...
@@ -12,22 +12,28 @@ if cuda_available:
from
theano.sandbox.cuda.basic_ops
import
host_from_gpu
,
gpu_from_host
from
theano.sandbox.cuda.opt
import
register_opt
class
MultinomialFromUniform
(
Op
):
'''Converts samples from a uniform into sample from a multinomial.'''
def
__init__
(
self
,
odtype
):
self
.
odtype
=
odtype
self
.
odtype
=
odtype
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
and
self
.
odtype
==
other
.
odtype
return
type
(
self
)
==
type
(
other
)
and
self
.
odtype
==
other
.
odtype
def
__hash__
(
self
):
return
hash
((
type
(
self
),
self
.
odtype
))
def
__str__
(
self
):
return
'
%
s{
%
s}'
%
(
self
.
__class__
.
__name__
,
self
.
odtype
)
return
'
%
s{
%
s}'
%
(
self
.
__class__
.
__name__
,
self
.
odtype
)
def
__setstate__
(
self
,
dct
):
self
.
__dict__
.
update
(
dct
)
try
:
self
.
odtype
except
AttributeError
:
self
.
odtype
=
'auto'
self
.
odtype
=
'auto'
def
make_node
(
self
,
pvals
,
unis
):
pvals
=
T
.
as_tensor_variable
(
pvals
)
unis
=
T
.
as_tensor_variable
(
unis
)
...
...
@@ -35,11 +41,12 @@ class MultinomialFromUniform(Op):
raise
NotImplementedError
(
'pvals ndim should be 2'
,
pvals
.
ndim
)
if
unis
.
ndim
!=
1
:
raise
NotImplementedError
(
'unis ndim should be 1'
,
unis
.
ndim
)
if
self
.
odtype
==
'auto'
:
if
self
.
odtype
==
'auto'
:
odtype
=
pvals
.
dtype
else
:
odtype
=
self
.
odtype
return
Apply
(
self
,
[
pvals
,
unis
],
[
T
.
matrix
(
dtype
=
odtype
)])
out
=
T
.
tensor
(
dtype
=
odtype
,
broadcastable
=
pvals
.
type
.
broadcastable
)
return
Apply
(
self
,
[
pvals
,
unis
],
[
out
])
def
grad
(
self
,
ins
,
outgrads
):
pvals
,
unis
=
ins
...
...
@@ -121,6 +128,7 @@ class MultinomialFromUniform(Op):
}
} // END NESTED SCOPE
"""
%
locals
()
def
perform
(
self
,
node
,
ins
,
outs
):
(
pvals
,
unis
)
=
ins
(
z
,)
=
outs
...
...
@@ -165,15 +173,17 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
raise
TypeError
(
'pvals must be cudandarray'
,
pvals
)
if
not
isinstance
(
unis
.
type
,
CudaNdarrayType
):
raise
TypeError
(
'unis must be cudandarray'
,
unis
)
if
self
.
odtype
==
'auto'
:
if
self
.
odtype
==
'auto'
:
odtype
=
pvals
.
dtype
else
:
odtype
=
self
.
odtype
if
odtype
!=
pvals
.
dtype
:
raise
NotImplementedError
(
'GpuMultinomialFromUniform works only if '
'self.odtype == pvals.dtype'
,
odtype
,
pvals
.
dtype
)
return
Apply
(
self
,
[
pvals
,
unis
],
[
pvals
.
type
()])
'GpuMultinomialFromUniform works only if '
'self.odtype == pvals.dtype'
,
odtype
,
pvals
.
dtype
)
br
=
(
pvals
.
broadcastable
[
1
],
pvals
.
broadcastable
[
0
])
out
=
CudaNdarrayType
(
broadcastable
=
br
)()
return
Apply
(
self
,
[
pvals
,
unis
],
[
out
])
def
perform
(
self
,
node
,
ins
,
outs
):
#The perform from parent don't work with CudaNdarray. We
...
...
@@ -226,7 +236,6 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
"""
%
locals
()
def
c_code
(
self
,
node
,
name
,
ins
,
outs
,
sub
):
(
pvals
,
unis
)
=
ins
(
z
,)
=
outs
...
...
@@ -327,25 +336,30 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
} // END NESTED SCOPE
"""
%
locals
()
@local_optimizer
()
def
local_gpu_multinomial
(
node
):
if
type
(
node
.
op
)
is
MultinomialFromUniform
:
p
,
u
=
node
.
inputs
m
,
=
node
.
outputs
if
(
p
.
dtype
==
u
.
dtype
==
m
.
dtype
==
'float32'
and
any
([
i
.
owner
and
isinstance
(
i
.
owner
.
op
,
theano
.
sandbox
.
cuda
.
HostFromGpu
)
any
([
i
.
owner
and
isinstance
(
i
.
owner
.
op
,
theano
.
sandbox
.
cuda
.
HostFromGpu
)
for
i
in
node
.
inputs
])):
gpu_op
=
GpuMultinomialFromUniform
(
node
.
op
.
odtype
)
return
[
host_from_gpu
(
gpu_op
(
*
[
gpu_from_host
(
i
)
for
i
in
node
.
inputs
]))
.
T
]
return
[
host_from_gpu
(
gpu_op
(
*
[
gpu_from_host
(
i
)
for
i
in
node
.
inputs
]))
.
T
]
if
(
isinstance
(
node
.
op
,
theano
.
sandbox
.
cuda
.
GpuFromHost
)
and
node
.
inputs
[
0
]
.
owner
and
type
(
node
.
inputs
[
0
]
.
owner
.
op
)
is
MultinomialFromUniform
):
node
.
inputs
[
0
]
.
owner
and
type
(
node
.
inputs
[
0
]
.
owner
.
op
)
is
MultinomialFromUniform
):
multi
=
node
.
inputs
[
0
]
.
owner
p
,
u
=
multi
.
inputs
m
,
=
multi
.
outputs
if
(
p
.
dtype
==
u
.
dtype
==
m
.
dtype
==
'float32'
):
gpu_op
=
GpuMultinomialFromUniform
(
multi
.
op
.
odtype
)
ret
=
gpu_op
(
*
[
gpu_from_host
(
i
)
for
i
in
multi
.
inputs
])
.
T
# The dimshuffle is on the cpu, but will be moved to the gpu by an opt.
# The dimshuffle is on the cpu, but will be moved to the
# gpu by an opt.
return
[
gpu_from_host
(
ret
)]
if
cuda_available
:
...
...
theano/sandbox/test_multinomial.py
浏览文件 @
3ca77162
...
...
@@ -9,15 +9,19 @@ from theano.compile.mode import get_default_mode, predefined_linkers
from
theano.gof.python25
import
any
import
theano.sandbox.cuda
as
cuda
def
get_mode
(
gpu
):
mode
=
get_default_mode
()
mode
=
copy
.
copy
(
mode
)
if
gpu
:
mode
=
mode
.
including
(
'gpu'
,
'gpu_local_optimizations'
,
'local_cut_gpu_host_gpu'
,
'local_gpu_multinomial'
)
mode
=
mode
.
including
(
'gpu'
,
'gpu_local_optimizations'
,
'local_cut_gpu_host_gpu'
,
'local_gpu_multinomial'
)
if
isinstance
(
mode
.
linker
,
theano
.
gof
.
PerformLinker
):
mode
.
linker
=
predefined_linkers
[
'c|py'
]
return
mode
def
run_with_c
(
f
,
gpu
=
False
):
mode
=
get_mode
(
gpu
)
f
(
mode
,
gpu
)
...
...
@@ -30,52 +34,54 @@ def test_multinomial_0():
p
=
tensor
.
fmatrix
()
u
=
tensor
.
fvector
()
m
=
multinomial
.
MultinomialFromUniform
(
'auto'
)(
p
,
u
)
m
=
multinomial
.
MultinomialFromUniform
(
'auto'
)(
p
,
u
)
def
body
(
mode
,
gpu
):
#the m*2 allows the multinomial to reuse output
f
=
function
([
p
,
u
],
m
*
2
,
allow_input_downcast
=
True
,
mode
=
mode
)
f
=
function
([
p
,
u
],
m
*
2
,
allow_input_downcast
=
True
,
mode
=
mode
)
if
gpu
:
assert
any
([
type
(
node
.
op
)
is
multinomial
.
GpuMultinomialFromUniform
for
node
in
f
.
maker
.
fgraph
.
toposort
()])
assert
any
([
type
(
node
.
op
)
is
multinomial
.
GpuMultinomialFromUniform
for
node
in
f
.
maker
.
fgraph
.
toposort
()])
# test that both first and second samples can be drawn
assert
numpy
.
allclose
(
f
([[
1
,
0
],
[
0
,
1
]],
[
.
1
,
.
1
]),
[[
2
,
0
],
[
0
,
2
]])
assert
numpy
.
allclose
(
f
([[
1
,
0
],
[
0
,
1
]],
[
.
1
,
.
1
]),
[[
2
,
0
],
[
0
,
2
]])
# test that both second labels can be drawn
r
=
f
([[
.
2
,
.
8
],
[
.
3
,
.
7
]],
[
.
31
,
.
31
])
assert
numpy
.
allclose
(
r
,
[[
0
,
2
],
[
0
,
2
]]),
r
r
=
f
([[
.
2
,
.
8
],
[
.
3
,
.
7
]],
[
.
31
,
.
31
])
assert
numpy
.
allclose
(
r
,
[[
0
,
2
],
[
0
,
2
]]),
r
# test that both first labels can be drawn
r
=
f
([[
.
2
,
.
8
],
[
.
3
,
.
7
]],
[
.
21
,
.
21
])
assert
numpy
.
allclose
(
r
,
[[
0
,
2
],
[
2
,
0
]]),
r
r
=
f
([[
.
2
,
.
8
],
[
.
3
,
.
7
]],
[
.
21
,
.
21
])
assert
numpy
.
allclose
(
r
,
[[
0
,
2
],
[
2
,
0
]]),
r
#change the size to make sure output gets reallocated ok
# and also make sure that the GPU version doesn't screw up the
# transposed-ness
r
=
f
([[
.
2
,
.
8
]
],
[
.
25
])
assert
numpy
.
allclose
(
r
,
[[
0
,
2
]]),
r
r
=
f
([[
.
2
,
.
8
]
],
[
.
25
])
assert
numpy
.
allclose
(
r
,
[[
0
,
2
]]),
r
run_with_c
(
body
)
if
cuda
.
cuda_available
:
run_with_c
(
body
,
True
)
#TODO: check a bigger example (make sure blocking on GPU is handled correctly)
def
test_multinomial_large
():
# DEBUG_MODE will test this on GPU
def
body
(
mode
,
gpu
):
p
=
tensor
.
fmatrix
()
u
=
tensor
.
fvector
()
m
=
multinomial
.
MultinomialFromUniform
(
'auto'
)(
p
,
u
)
f
=
function
([
p
,
u
],
m
*
2
,
allow_input_downcast
=
True
,
mode
=
mode
)
m
=
multinomial
.
MultinomialFromUniform
(
'auto'
)(
p
,
u
)
f
=
function
([
p
,
u
],
m
*
2
,
allow_input_downcast
=
True
,
mode
=
mode
)
if
gpu
:
assert
any
([
type
(
node
.
op
)
is
multinomial
.
GpuMultinomialFromUniform
for
node
in
f
.
maker
.
fgraph
.
toposort
()])
assert
any
([
type
(
node
.
op
)
is
multinomial
.
GpuMultinomialFromUniform
for
node
in
f
.
maker
.
fgraph
.
toposort
()])
pval
=
numpy
.
arange
(
10000
*
4
,
dtype
=
'float32'
)
.
reshape
((
10000
,
4
))
+
0.1
pval
=
pval
/
pval
.
sum
(
axis
=
1
)[:,
None
]
uval
=
numpy
.
ones_like
(
pval
[:,
0
])
*
0.5
mval
=
f
(
pval
,
uval
)
pval
=
pval
/
pval
.
sum
(
axis
=
1
)[:,
None
]
uval
=
numpy
.
ones_like
(
pval
[:,
0
])
*
0.5
mval
=
f
(
pval
,
uval
)
assert
mval
.
shape
==
pval
.
shape
if
config
.
cast_policy
==
'custom'
:
...
...
@@ -88,7 +94,7 @@ def test_multinomial_large():
raise
NotImplementedError
(
config
.
cast_policy
)
assert
numpy
.
allclose
(
mval
.
sum
(
axis
=
1
),
2
)
asdf
=
numpy
.
asarray
([
0
,
0
,
2
,
0
])
+
0
*
pval
assert
numpy
.
allclose
(
mval
,
asdf
)
#
broadcast over all rows
assert
numpy
.
allclose
(
mval
,
asdf
)
#
broadcast over all rows
run_with_c
(
body
)
if
cuda
.
cuda_available
:
run_with_c
(
body
,
True
)
...
...
@@ -97,36 +103,52 @@ def test_multinomial_large():
def
test_multinomial_dtypes
():
p
=
tensor
.
dmatrix
()
u
=
tensor
.
dvector
()
m
=
multinomial
.
MultinomialFromUniform
(
'auto'
)(
p
,
u
)
m
=
multinomial
.
MultinomialFromUniform
(
'auto'
)(
p
,
u
)
assert
m
.
dtype
==
'float64'
,
m
.
dtype
p
=
tensor
.
fmatrix
()
u
=
tensor
.
fvector
()
m
=
multinomial
.
MultinomialFromUniform
(
'auto'
)(
p
,
u
)
m
=
multinomial
.
MultinomialFromUniform
(
'auto'
)(
p
,
u
)
assert
m
.
dtype
==
'float32'
,
m
.
dtype
p
=
tensor
.
fmatrix
()
u
=
tensor
.
fvector
()
m
=
multinomial
.
MultinomialFromUniform
(
'float64'
)(
p
,
u
)
m
=
multinomial
.
MultinomialFromUniform
(
'float64'
)(
p
,
u
)
assert
m
.
dtype
==
'float64'
,
m
.
dtype
def
test_gpu_opt
():
if
not
cuda
.
cuda_available
:
# Skip test if cuda_ndarray is not available.
from
nose.plugins.skip
import
SkipTest
raise
SkipTest
(
'Optional package cuda not available'
)
# We test the case where we put the op on the gpu when the output is moved to the gpu.
# We test the case where we put the op on the gpu when the output
# is moved to the gpu.
p
=
tensor
.
fmatrix
()
u
=
tensor
.
fvector
()
m
=
multinomial
.
MultinomialFromUniform
(
'auto'
)(
p
,
u
)
m
=
multinomial
.
MultinomialFromUniform
(
'auto'
)(
p
,
u
)
assert
m
.
dtype
==
'float32'
,
m
.
dtype
m_gpu
=
cuda
.
gpu_from_host
(
m
)
f
=
function
([
p
,
u
],
m_gpu
,
allow_input_downcast
=
True
,
mode
=
get_mode
(
True
))
assert
any
([
type
(
node
.
op
)
is
multinomial
.
GpuMultinomialFromUniform
for
node
in
f
.
maker
.
fgraph
.
toposort
()])
f
=
function
([
p
,
u
],
m_gpu
,
allow_input_downcast
=
True
,
mode
=
get_mode
(
True
))
assert
any
([
type
(
node
.
op
)
is
multinomial
.
GpuMultinomialFromUniform
for
node
in
f
.
maker
.
fgraph
.
toposort
()])
pval
=
numpy
.
arange
(
10000
*
4
,
dtype
=
'float32'
)
.
reshape
((
10000
,
4
))
+
0.1
pval
=
pval
/
pval
.
sum
(
axis
=
1
)[:,
None
]
uval
=
numpy
.
ones_like
(
pval
[:,
0
])
*
0.5
mval
=
f
(
pval
,
uval
)
pval
=
pval
/
pval
.
sum
(
axis
=
1
)[:,
None
]
uval
=
numpy
.
ones_like
(
pval
[:,
0
])
*
0.5
mval
=
f
(
pval
,
uval
)
# Test with a row, it was failing in the past.
r
=
tensor
.
frow
()
m
=
multinomial
.
MultinomialFromUniform
(
'auto'
)(
r
,
u
)
assert
m
.
dtype
==
'float32'
,
m
.
dtype
m_gpu
=
cuda
.
gpu_from_host
(
m
)
f
=
function
([
r
,
u
],
m_gpu
,
allow_input_downcast
=
True
,
mode
=
get_mode
(
True
))
assert
any
([
type
(
node
.
op
)
is
multinomial
.
GpuMultinomialFromUniform
for
node
in
f
.
maker
.
fgraph
.
toposort
()])
pval
=
numpy
.
arange
(
1
*
4
,
dtype
=
'float32'
)
.
reshape
((
1
,
4
))
+
0.1
pval
=
pval
/
pval
.
sum
(
axis
=
1
)[:,
None
]
uval
=
numpy
.
ones_like
(
pval
[:,
0
])
*
0.5
mval2
=
f
(
pval
,
uval
)
theano/tensor/opt.py
浏览文件 @
3ca77162
...
...
@@ -4656,9 +4656,10 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
else
:
s
=
scalar
.
Scalar
(
i
.
dtype
)
.
make_variable
()
try
:
v
=
gof
.
op
.
get_test_value
(
i
)
if
v
.
size
>
0
:
s
.
tag
.
test_value
=
gof
.
op
.
get_test_value
(
i
)
.
flatten
()[
0
]
if
theano
.
config
.
compute_test_value
!=
'off'
:
v
=
gof
.
op
.
get_test_value
(
i
)
if
v
.
size
>
0
:
s
.
tag
.
test_value
=
v
.
flatten
()[
0
]
except
AttributeError
:
pass
...
...
theano/tensor/tests/test_subtensor.py
浏览文件 @
3ca77162
...
...
@@ -1146,7 +1146,7 @@ class TestAdvancedSubtensor(unittest.TestCase):
subt
=
self
.
m
[
self
.
ix1
,
self
.
ix12
]
a
=
inc_subtensor
(
subt
,
subt
)
typ
=
TensorType
(
self
.
m
.
type
.
dtype
,
self
.
ix2
.
type
.
broadcastable
)
typ
=
tensor
.
TensorType
(
self
.
m
.
type
.
dtype
,
self
.
ix2
.
type
.
broadcastable
)
assert
a
.
type
==
typ
,
(
a
.
type
,
typ
)
f
=
theano
.
function
([
self
.
m
,
self
.
ix1
,
self
.
ix12
],
a
,
allow_input_downcast
=
True
)
...
...
theano/tests/diverse_tests.py
浏览文件 @
3ca77162
from
nose.plugins.skip
import
SkipTest
import
unittest
import
theano
import
numpy
import
random
import
numpy.random
from
theano.tests
import
unittest_tools
as
utt
import
theano
from
theano.tests
import
unittest_tools
as
utt
'''
Different tests that are not connected to any particular Op, or functionality of
Theano. Here will go for example code that we will publish in papers, that we
should ensure that it will remain operational
Different tests that are not connected to any particular Op, or
functionality of Theano. Here will go for example code that we will
publish in papers, that we should ensure that it will remain
operational
'''
class
T_scipy
(
unittest
.
TestCase
):
def
setUp
(
self
):
utt
.
seed_rng
()
self
.
orig_floatX
=
theano
.
config
.
floatX
def
tearDown
(
self
):
theano
.
config
.
floatX
=
self
.
orig_floatX
def
test_scipy_paper_example1
(
self
):
a
=
theano
.
tensor
.
vector
(
'a'
)
# declare variable
b
=
a
+
a
**
10
# build expression
f
=
theano
.
function
([
a
],
b
)
# compile function
assert
numpy
.
all
(
f
([
0
,
1
,
2
])
==
numpy
.
array
([
0
,
2
,
1026
]))
a
=
theano
.
tensor
.
vector
(
'a'
)
# declare variable
b
=
a
+
a
**
10
# build expression
f
=
theano
.
function
([
a
],
b
)
# compile function
assert
numpy
.
all
(
f
([
0
,
1
,
2
])
==
numpy
.
array
([
0
,
2
,
1026
]))
def
test_scipy_paper_example2
(
self
):
''' This just sees if things compile well and if they run '''
...
...
@@ -34,7 +36,7 @@ class T_scipy(unittest.TestCase):
shared
=
theano
.
shared
function
=
theano
.
function
rng
=
numpy
.
random
theano
.
config
.
floatX
=
'float64'
theano
.
config
.
floatX
=
'float64'
#
# ACTUAL SCRIPT FROM PAPER
...
...
@@ -49,18 +51,18 @@ class T_scipy(unittest.TestCase):
xent
=
-
y
*
T
.
log
(
p_1
)
-
(
1
-
y
)
*
T
.
log
(
1
-
p_1
)
prediction
=
p_1
>
0.5
cost
=
xent
.
mean
()
+
0.01
*
(
w
**
2
)
.
sum
()
gw
,
gb
=
T
.
grad
(
cost
,
[
w
,
b
])
gw
,
gb
=
T
.
grad
(
cost
,
[
w
,
b
])
# Compile expressions to functions
train
=
function
(
inputs
=
[
x
,
y
],
inputs
=
[
x
,
y
],
outputs
=
[
prediction
,
xent
],
updates
=
[(
w
,
w
-
0.1
*
gw
),
(
b
,
b
-
0.1
*
gb
)])
predict
=
function
(
inputs
=
[
x
],
outputs
=
prediction
)
N
=
4
feats
=
100
D
=
(
rng
.
randn
(
N
,
feats
),
rng
.
randint
(
size
=
4
,
low
=
0
,
high
=
2
))
D
=
(
rng
.
randn
(
N
,
feats
),
rng
.
randint
(
size
=
4
,
low
=
0
,
high
=
2
))
training_steps
=
10
for
i
in
range
(
training_steps
):
pred
,
err
=
train
(
D
[
0
],
D
[
1
])
...
...
@@ -68,4 +70,3 @@ class T_scipy(unittest.TestCase):
if
__name__
==
'__main__'
:
unittest
.
main
()
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论