Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
bf752d81
提交
bf752d81
authored
11月 17, 2020
作者:
junpenglao
提交者:
Brandon T. Willard
11月 20, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Change tt.config to theano.config in test_jax
上级
0ea34350
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
65 行增加
和
54 行删除
+65
-54
test_jax.py
tests/sandbox/test_jax.py
+65
-54
没有找到文件。
tests/sandbox/test_jax.py
浏览文件 @
bf752d81
...
...
@@ -111,7 +111,7 @@ def test_jax_Alloc():
x
=
tt
.
alloc
(
a
,
20
,
10
)
x_fg
=
theano
.
gof
.
FunctionGraph
([
a
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
ones
(
10
,
dtype
=
t
t
.
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
ones
(
10
,
dtype
=
t
heano
.
config
.
floatX
)])
def
test_jax_compile_ops
():
...
...
@@ -182,8 +182,8 @@ def test_jax_basic():
out_fg
=
theano
.
gof
.
FunctionGraph
([
x
,
y
],
[
out
])
test_input_vals
=
[
np
.
tile
(
np
.
arange
(
10
),
(
10
,
1
))
.
astype
(
t
t
.
config
.
floatX
),
np
.
tile
(
np
.
arange
(
10
,
20
),
(
10
,
1
))
.
astype
(
t
t
.
config
.
floatX
),
np
.
tile
(
np
.
arange
(
10
),
(
10
,
1
))
.
astype
(
t
heano
.
config
.
floatX
),
np
.
tile
(
np
.
arange
(
10
,
20
),
(
10
,
1
))
.
astype
(
t
heano
.
config
.
floatX
),
]
(
jax_res
,)
=
compare_jax_and_py
(
out_fg
,
test_input_vals
)
...
...
@@ -201,43 +201,49 @@ def test_jax_basic():
out
=
tt
.
diagonal
(
x
,
0
)
out_fg
=
theano
.
gof
.
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
out_fg
,
[
np
.
arange
(
10
*
10
)
.
reshape
((
10
,
10
))
.
astype
(
t
t
.
config
.
floatX
)]
out_fg
,
[
np
.
arange
(
10
*
10
)
.
reshape
((
10
,
10
))
.
astype
(
t
heano
.
config
.
floatX
)]
)
out
=
tt
.
slinalg
.
cholesky
(
x
)
out_fg
=
theano
.
gof
.
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
out_fg
,
[(
np
.
eye
(
10
)
+
np
.
random
.
randn
(
10
,
10
)
*
0.01
)
.
astype
(
tt
.
config
.
floatX
)]
out_fg
,
[(
np
.
eye
(
10
)
+
np
.
random
.
randn
(
10
,
10
)
*
0.01
)
.
astype
(
theano
.
config
.
floatX
)],
)
# not sure why this isn't working yet with lower=False
out
=
tt
.
slinalg
.
Cholesky
(
lower
=
False
)(
x
)
out_fg
=
theano
.
gof
.
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
out_fg
,
[(
np
.
eye
(
10
)
+
np
.
random
.
randn
(
10
,
10
)
*
0.01
)
.
astype
(
tt
.
config
.
floatX
)]
out_fg
,
[(
np
.
eye
(
10
)
+
np
.
random
.
randn
(
10
,
10
)
*
0.01
)
.
astype
(
theano
.
config
.
floatX
)],
)
out
=
tt
.
slinalg
.
solve
(
x
,
b
)
out_fg
=
theano
.
gof
.
FunctionGraph
([
x
,
b
],
[
out
])
compare_jax_and_py
(
out_fg
,
[
np
.
eye
(
10
)
.
astype
(
tt
.
config
.
floatX
),
np
.
arange
(
10
)
.
astype
(
tt
.
config
.
floatX
)],
[
np
.
eye
(
10
)
.
astype
(
theano
.
config
.
floatX
),
np
.
arange
(
10
)
.
astype
(
theano
.
config
.
floatX
),
],
)
out
=
tt
.
nlinalg
.
alloc_diag
(
b
)
out_fg
=
theano
.
gof
.
FunctionGraph
([
b
],
[
out
])
compare_jax_and_py
(
out_fg
,
[
np
.
arange
(
10
)
.
astype
(
t
t
.
config
.
floatX
)])
compare_jax_and_py
(
out_fg
,
[
np
.
arange
(
10
)
.
astype
(
t
heano
.
config
.
floatX
)])
out
=
tt
.
nlinalg
.
det
(
x
)
out_fg
=
theano
.
gof
.
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
out_fg
,
[
np
.
arange
(
10
*
10
)
.
reshape
((
10
,
10
))
.
astype
(
t
t
.
config
.
floatX
)]
out_fg
,
[
np
.
arange
(
10
*
10
)
.
reshape
((
10
,
10
))
.
astype
(
t
heano
.
config
.
floatX
)]
)
out
=
tt
.
nlinalg
.
matrix_inverse
(
x
)
out_fg
=
theano
.
gof
.
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
out_fg
,
[(
np
.
eye
(
10
)
+
np
.
random
.
randn
(
10
,
10
)
*
0.01
)
.
astype
(
tt
.
config
.
floatX
)]
out_fg
,
[(
np
.
eye
(
10
)
+
np
.
random
.
randn
(
10
,
10
)
*
0.01
)
.
astype
(
theano
.
config
.
floatX
)],
)
...
...
@@ -261,25 +267,25 @@ def test_jax_basic_multiout():
out_fg
=
theano
.
gof
.
FunctionGraph
([
x
],
outs
)
def
assert_fn
(
x
,
y
):
np
.
testing
.
assert_allclose
(
x
.
astype
(
t
t
.
config
.
floatX
),
y
,
rtol
=
1e-3
)
np
.
testing
.
assert_allclose
(
x
.
astype
(
t
heano
.
config
.
floatX
),
y
,
rtol
=
1e-3
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
t
t
.
config
.
floatX
)],
assert_fn
=
assert_fn
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
t
heano
.
config
.
floatX
)],
assert_fn
=
assert_fn
)
outs
=
tt
.
nlinalg
.
eigh
(
x
)
out_fg
=
theano
.
gof
.
FunctionGraph
([
x
],
outs
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
t
t
.
config
.
floatX
)],
assert_fn
=
assert_fn
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
t
heano
.
config
.
floatX
)],
assert_fn
=
assert_fn
)
outs
=
tt
.
nlinalg
.
qr
(
x
,
mode
=
"full"
)
out_fg
=
theano
.
gof
.
FunctionGraph
([
x
],
outs
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
t
t
.
config
.
floatX
)],
assert_fn
=
assert_fn
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
t
heano
.
config
.
floatX
)],
assert_fn
=
assert_fn
)
outs
=
tt
.
nlinalg
.
qr
(
x
,
mode
=
"reduced"
)
out_fg
=
theano
.
gof
.
FunctionGraph
([
x
],
outs
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
t
t
.
config
.
floatX
)],
assert_fn
=
assert_fn
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
t
heano
.
config
.
floatX
)],
assert_fn
=
assert_fn
)
outs
=
tt
.
nlinalg
.
svd
(
x
)
out_fg
=
theano
.
gof
.
FunctionGraph
([
x
],
outs
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
t
t
.
config
.
floatX
)],
assert_fn
=
assert_fn
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
t
heano
.
config
.
floatX
)],
assert_fn
=
assert_fn
)
# Test that a single output of a multi-output `Op` can be used as input to
# another `Op`
...
...
@@ -357,10 +363,11 @@ def test_jax_scan_multiple_output():
)
s0
,
e0
,
i0
=
100
,
50
,
25
logp_c0
=
np
.
array
(
0.0
)
.
astype
(
tt
.
config
.
floatX
)
logp_d0
=
np
.
array
(
0.0
)
.
astype
(
tt
.
config
.
floatX
)
logp_c0
=
np
.
array
(
0.0
,
dtype
=
theano
.
config
.
floatX
)
logp_d0
=
np
.
array
(
0.0
,
dtype
=
theano
.
config
.
floatX
)
beta_val
,
gamma_val
,
delta_val
=
[
np
.
array
(
val
)
.
astype
(
tt
.
config
.
floatX
)
for
val
in
[
0.277792
,
0.135330
,
0.108753
]
np
.
array
(
val
,
dtype
=
theano
.
config
.
floatX
)
for
val
in
[
0.277792
,
0.135330
,
0.108753
]
]
C
=
np
.
array
([
3
,
5
,
8
,
13
,
21
,
26
,
10
,
3
],
dtype
=
np
.
int32
)
D
=
np
.
array
([
1
,
2
,
3
,
7
,
9
,
11
,
5
,
1
],
dtype
=
np
.
int32
)
...
...
@@ -396,7 +403,7 @@ def test_jax_scan_tap_output():
outputs_info
=
[
{
"initial"
:
tt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
1.3
,
0.0
]
.
astype
(
t
t
.
config
.
floatX
)
np
.
r_
[
-
1.0
,
1.3
,
0.0
]
.
astype
(
t
heano
.
config
.
floatX
)
),
"taps"
:
[
-
1
,
-
3
],
},
...
...
@@ -410,7 +417,7 @@ def test_jax_scan_tap_output():
out_fg
=
theano
.
gof
.
FunctionGraph
([
a_tt
],
[
y_scan_tt
])
test_input_vals
=
[
np
.
array
(
10.0
)
.
astype
(
t
t
.
config
.
floatX
)]
test_input_vals
=
[
np
.
array
(
10.0
)
.
astype
(
t
heano
.
config
.
floatX
)]
compare_jax_and_py
(
out_fg
,
test_input_vals
)
...
...
@@ -457,16 +464,16 @@ def test_jax_Subtensors():
def
test_jax_IncSubtensor
():
x_np
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
(
3
,
4
,
5
))
.
astype
(
t
t
.
config
.
floatX
)
x_tt
=
tt
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))
.
astype
(
t
t
.
config
.
floatX
)
x_np
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
(
3
,
4
,
5
))
.
astype
(
t
heano
.
config
.
floatX
)
x_tt
=
tt
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))
.
astype
(
t
heano
.
config
.
floatX
)
# "Set" basic indices
st_tt
=
tt
.
as_tensor_variable
(
np
.
array
(
-
10.0
,
dtype
=
t
t
.
config
.
floatX
))
st_tt
=
tt
.
as_tensor_variable
(
np
.
array
(
-
10.0
,
dtype
=
t
heano
.
config
.
floatX
))
out_tt
=
tt
.
set_subtensor
(
x_tt
[
1
,
2
,
3
],
st_tt
)
out_fg
=
theano
.
gof
.
FunctionGraph
([],
[
out_tt
])
compare_jax_and_py
(
out_fg
,
[])
st_tt
=
tt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
t
t
.
config
.
floatX
))
st_tt
=
tt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
t
heano
.
config
.
floatX
))
out_tt
=
tt
.
set_subtensor
(
x_tt
[:
2
,
0
,
0
],
st_tt
)
out_fg
=
theano
.
gof
.
FunctionGraph
([],
[
out_tt
])
compare_jax_and_py
(
out_fg
,
[])
...
...
@@ -476,7 +483,7 @@ def test_jax_IncSubtensor():
compare_jax_and_py
(
out_fg
,
[])
# "Set" advanced indices
st_tt
=
tt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
t
t
.
config
.
floatX
))
st_tt
=
tt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
t
heano
.
config
.
floatX
))
out_tt
=
tt
.
set_subtensor
(
x_tt
[[
0
,
2
],
0
,
0
],
st_tt
)
out_fg
=
theano
.
gof
.
FunctionGraph
([],
[
out_tt
])
compare_jax_and_py
(
out_fg
,
[])
...
...
@@ -493,12 +500,12 @@ def test_jax_IncSubtensor():
compare_jax_and_py
(
out_fg
,
[])
# "Increment" basic indices
st_tt
=
tt
.
as_tensor_variable
(
np
.
array
(
-
10.0
,
dtype
=
t
t
.
config
.
floatX
))
st_tt
=
tt
.
as_tensor_variable
(
np
.
array
(
-
10.0
,
dtype
=
t
heano
.
config
.
floatX
))
out_tt
=
tt
.
inc_subtensor
(
x_tt
[
1
,
2
,
3
],
st_tt
)
out_fg
=
theano
.
gof
.
FunctionGraph
([],
[
out_tt
])
compare_jax_and_py
(
out_fg
,
[])
st_tt
=
tt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
t
t
.
config
.
floatX
))
st_tt
=
tt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
t
heano
.
config
.
floatX
))
out_tt
=
tt
.
inc_subtensor
(
x_tt
[:
2
,
0
,
0
],
st_tt
)
out_fg
=
theano
.
gof
.
FunctionGraph
([],
[
out_tt
])
compare_jax_and_py
(
out_fg
,
[])
...
...
@@ -508,7 +515,7 @@ def test_jax_IncSubtensor():
compare_jax_and_py
(
out_fg
,
[])
# "Increment" advanced indices
st_tt
=
tt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
t
t
.
config
.
floatX
))
st_tt
=
tt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
t
heano
.
config
.
floatX
))
out_tt
=
tt
.
inc_subtensor
(
x_tt
[[
0
,
2
],
0
,
0
],
st_tt
)
out_fg
=
theano
.
gof
.
FunctionGraph
([],
[
out_tt
])
compare_jax_and_py
(
out_fg
,
[])
...
...
@@ -545,38 +552,38 @@ def test_jax_ifelse():
def
test_jax_CAReduce
():
a_tt
=
tt
.
vector
(
"a"
)
a_tt
.
tag
.
test_value
=
np
.
r_
[
1
,
2
,
3
]
.
astype
(
t
t
.
config
.
floatX
)
a_tt
.
tag
.
test_value
=
np
.
r_
[
1
,
2
,
3
]
.
astype
(
t
heano
.
config
.
floatX
)
x
=
tt
.
sum
(
a_tt
,
axis
=
None
)
x_fg
=
theano
.
gof
.
FunctionGraph
([
a_tt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1
,
2
,
3
]
.
astype
(
t
t
.
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1
,
2
,
3
]
.
astype
(
t
heano
.
config
.
floatX
)])
a_tt
=
tt
.
matrix
(
"a"
)
a_tt
.
tag
.
test_value
=
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
t
t
.
config
.
floatX
)
a_tt
.
tag
.
test_value
=
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
t
heano
.
config
.
floatX
)
x
=
tt
.
sum
(
a_tt
,
axis
=
0
)
x_fg
=
theano
.
gof
.
FunctionGraph
([
a_tt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
t
t
.
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
t
heano
.
config
.
floatX
)])
x
=
tt
.
sum
(
a_tt
,
axis
=
1
)
x_fg
=
theano
.
gof
.
FunctionGraph
([
a_tt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
t
t
.
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
t
heano
.
config
.
floatX
)])
a_tt
=
tt
.
matrix
(
"a"
)
a_tt
.
tag
.
test_value
=
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
t
t
.
config
.
floatX
)
a_tt
.
tag
.
test_value
=
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
t
heano
.
config
.
floatX
)
x
=
tt
.
prod
(
a_tt
,
axis
=
0
)
x_fg
=
theano
.
gof
.
FunctionGraph
([
a_tt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
t
t
.
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
t
heano
.
config
.
floatX
)])
x
=
tt
.
all
(
a_tt
)
x_fg
=
theano
.
gof
.
FunctionGraph
([
a_tt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
t
t
.
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
t
heano
.
config
.
floatX
)])
def
test_jax_MakeVector
():
...
...
@@ -615,28 +622,32 @@ def test_jax_Dimshuffle():
x
=
a_tt
.
T
x_fg
=
theano
.
gof
.
FunctionGraph
([
a_tt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
tt
.
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
theano
.
config
.
floatX
)]
)
x
=
a_tt
.
dimshuffle
([
0
,
1
,
"x"
])
x_fg
=
theano
.
gof
.
FunctionGraph
([
a_tt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
tt
.
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
theano
.
config
.
floatX
)]
)
a_tt
=
tt
.
tensor
(
dtype
=
t
t
.
config
.
floatX
,
broadcastable
=
[
False
,
True
])
a_tt
=
tt
.
tensor
(
dtype
=
t
heano
.
config
.
floatX
,
broadcastable
=
[
False
,
True
])
x
=
a_tt
.
dimshuffle
((
0
,))
x_fg
=
theano
.
gof
.
FunctionGraph
([
a_tt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
t
t
.
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
t
heano
.
config
.
floatX
)])
a_tt
=
tt
.
tensor
(
dtype
=
t
t
.
config
.
floatX
,
broadcastable
=
[
False
,
True
])
a_tt
=
tt
.
tensor
(
dtype
=
t
heano
.
config
.
floatX
,
broadcastable
=
[
False
,
True
])
x
=
tt
.
elemwise
.
DimShuffle
([
False
,
True
],
(
0
,),
inplace
=
True
)(
a_tt
)
x_fg
=
theano
.
gof
.
FunctionGraph
([
a_tt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
t
t
.
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
t
heano
.
config
.
floatX
)])
def
test_jax_variadic_Scalar
():
mu
=
tt
.
vector
(
"mu"
,
dtype
=
t
t
.
config
.
floatX
)
mu
.
tag
.
test_value
=
np
.
r_
[
0.1
,
1.1
]
.
astype
(
t
t
.
config
.
floatX
)
tau
=
tt
.
vector
(
"tau"
,
dtype
=
t
t
.
config
.
floatX
)
tau
.
tag
.
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
t
t
.
config
.
floatX
)
mu
=
tt
.
vector
(
"mu"
,
dtype
=
t
heano
.
config
.
floatX
)
mu
.
tag
.
test_value
=
np
.
r_
[
0.1
,
1.1
]
.
astype
(
t
heano
.
config
.
floatX
)
tau
=
tt
.
vector
(
"tau"
,
dtype
=
t
heano
.
config
.
floatX
)
tau
.
tag
.
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
t
heano
.
config
.
floatX
)
res
=
-
tau
*
mu
...
...
@@ -654,13 +665,13 @@ def test_jax_variadic_Scalar():
def
test_jax_logp
():
mu
=
tt
.
vector
(
"mu"
)
mu
.
tag
.
test_value
=
np
.
r_
[
0.0
,
0.0
]
.
astype
(
t
t
.
config
.
floatX
)
mu
.
tag
.
test_value
=
np
.
r_
[
0.0
,
0.0
]
.
astype
(
t
heano
.
config
.
floatX
)
tau
=
tt
.
vector
(
"tau"
)
tau
.
tag
.
test_value
=
np
.
r_
[
1.0
,
1.0
]
.
astype
(
t
t
.
config
.
floatX
)
tau
.
tag
.
test_value
=
np
.
r_
[
1.0
,
1.0
]
.
astype
(
t
heano
.
config
.
floatX
)
sigma
=
tt
.
vector
(
"sigma"
)
sigma
.
tag
.
test_value
=
(
1.0
/
get_test_value
(
tau
))
.
astype
(
t
t
.
config
.
floatX
)
sigma
.
tag
.
test_value
=
(
1.0
/
get_test_value
(
tau
))
.
astype
(
t
heano
.
config
.
floatX
)
value
=
tt
.
vector
(
"value"
)
value
.
tag
.
test_value
=
np
.
r_
[
0.1
,
-
10
]
.
astype
(
t
t
.
config
.
floatX
)
value
.
tag
.
test_value
=
np
.
r_
[
0.1
,
-
10
]
.
astype
(
t
heano
.
config
.
floatX
)
logp
=
(
-
tau
*
(
value
-
mu
)
**
2
+
tt
.
log
(
tau
/
np
.
pi
/
2.0
))
/
2.0
conditions
=
[
sigma
>
0
]
...
...
@@ -674,9 +685,9 @@ def test_jax_logp():
def
test_jax_multioutput
():
x
=
tt
.
vector
(
"x"
)
x
.
tag
.
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
t
t
.
config
.
floatX
)
x
.
tag
.
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
t
heano
.
config
.
floatX
)
y
=
tt
.
vector
(
"y"
)
y
.
tag
.
test_value
=
np
.
r_
[
3.0
,
4.0
]
.
astype
(
t
t
.
config
.
floatX
)
y
.
tag
.
test_value
=
np
.
r_
[
3.0
,
4.0
]
.
astype
(
t
heano
.
config
.
floatX
)
w
=
tt
.
cosh
(
x
**
2
+
y
/
3.0
)
v
=
tt
.
cosh
(
x
/
3.0
+
y
**
2
)
...
...
@@ -688,7 +699,7 @@ def test_jax_multioutput():
def
test_nnet
():
x
=
tt
.
vector
(
"x"
)
x
.
tag
.
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
t
t
.
config
.
floatX
)
x
.
tag
.
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
t
heano
.
config
.
floatX
)
out
=
tt
.
nnet
.
sigmoid
(
x
)
fgraph
=
theano
.
gof
.
FunctionGraph
([
x
],
[
out
])
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论