Commit 57c388a7 — testgroup / pytensor

Add axis to LogSoftmax

Authored Nov 24, 2021 by Ricardo
Committed by Brandon T. Willard on Dec 13, 2021
Parent: 595ed184

Showing 7 changed files with 49 additions and 49 deletions.
Files changed:

- aesara/link/jax/dispatch.py (+3, -1)
- aesara/link/numba/dispatch/elemwise.py (+12, -5)
- aesara/tensor/nnet/__init__.py (+0, -1)
- aesara/tensor/nnet/basic.py (+0, -0)
- tests/link/test_jax.py (+8, -3)
- tests/link/test_numba.py (+10, -3)
- tests/tensor/nnet/test_basic.py (+16, -36)
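In short, the commit threads the Op's new `axis` parameter through the JAX and Numba backends and updates the tests accordingly. As a rough orientation, here is a plain-NumPy sketch (illustrative only, not code from this commit) of the axis-aware, numerically stable log-softmax the backends implement:

```python
import numpy as np

def log_softmax_ref(x, axis=None):
    # Subtract the max along `axis` for numerical stability; keepdims
    # keeps the reduced axis so the subtraction broadcasts correctly.
    xdev = x - np.max(x, axis=axis, keepdims=True)
    return xdev - np.log(np.sum(np.exp(xdev), axis=axis, keepdims=True))

x = np.arange(6.0).reshape(2, 3)
log_softmax_ref(x, axis=0)     # normalize down each column
log_softmax_ref(x, axis=None)  # normalize over all elements
```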
aesara/link/jax/dispatch.py

```diff
@@ -208,8 +208,10 @@ def jax_funcify_Softmax(op, **kwargs):
 @jax_funcify.register(LogSoftmax)
 def jax_funcify_LogSoftmax(op, **kwargs):
+    axis = op.axis
+
     def log_softmax(x):
-        return jax.nn.log_softmax(x)
+        return jax.nn.log_softmax(x, axis=axis)
 
     return log_softmax
```
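The JAX backend change is minimal: `op.axis` is captured once in the closure and forwarded to `jax.nn.log_softmax`. As a standalone sanity check (an illustration, not part of the diff), `log_softmax` with an explicit axis equals `x - logsumexp(x)` along that axis:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)
# log_softmax(x, axis) == x - logsumexp(x, axis, keepdims=True)
expected = x - jax.scipy.special.logsumexp(x, axis=0, keepdims=True)
assert jnp.allclose(jax.nn.log_softmax(x, axis=0), expected)
```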
aesara/link/numba/dispatch/elemwise.py

```diff
@@ -430,15 +430,22 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
     x_at = node.inputs[0]
     x_dtype = x_at.type.numpy_dtype
     x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
+    axis = op.axis
 
-    # np.max(x, axis=1)
-    reduce_max = create_axis_reducer(np.maximum, -np.inf, 1, x_at.ndim, x_dtype)
-    # np.sum(x, axis=1, keepdims=True)
-    reduce_sum = create_axis_reducer(
-        np.add, 0.0, 1, x_at.ndim, x_dtype, keepdims=True
-    )
+    if axis is not None:
+        reduce_max = create_axis_reducer(
+            np.maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
+        )
+        reduce_sum = create_axis_reducer(
+            np.add, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
+        )
+    else:
+        reduce_max = np.max
+        reduce_sum = np.sum
 
     @numba.njit
     def log_softmax(x):
-        xdev = x - np.expand_dims(reduce_max(x), -1)
+        xdev = x - reduce_max(x)
         lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
         return lsm
```
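Two details are worth noting here. First, `reduce_max` is now built with `keepdims=True`, which lets the kernel drop the old `np.expand_dims(reduce_max(x), -1)`; re-inserting the last axis by hand is only correct when the reduction runs over the last axis. Second, when `axis is None` the reducers degenerate to the scalar `np.max`/`np.sum`, and broadcasting still works. A NumPy illustration of the `keepdims` point (not the Numba kernel itself):

```python
import numpy as np

x = np.random.rand(2, 3, 4)

# Without keepdims, reducing a middle axis breaks the subtraction:
# np.max(x, axis=1) has shape (2, 4), which does not broadcast with x.
m = np.max(x, axis=1, keepdims=True)  # shape (2, 1, 4)
assert (x - m).shape == x.shape
```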
aesara/tensor/nnet/__init__.py

```diff
@@ -27,7 +27,6 @@ from aesara.tensor.nnet.basic import (
     graph_merge_softmax_with_crossentropy_softmax,
     h_softmax,
     logsoftmax,
-    logsoftmax_op,
     prepend_0_to_each_row,
     prepend_1_to_each_row,
     prepend_scalar_to_each_row,
```
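With the `logsoftmax_op` re-export removed, callers are evidently expected to go through the `logsoftmax` helper with an explicit `axis`, as the updated tests below do. A hypothetical migration sketch:

```python
from aesara.tensor import vector
from aesara.tensor.nnet import logsoftmax

x = vector("x")
out = logsoftmax(x, axis=None)  # previously: logsoftmax_op(x)
```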
aesara/tensor/nnet/basic.py

Diff collapsed (too large to display inline).
tests/link/test_jax.py

```diff
@@ -969,16 +969,21 @@ def test_nnet():
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
 
-    out = aet_nnet.logsoftmax(x)
+
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_softmax(axis):
+    x = matrix("x")
+    x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    out = aet_nnet.softmax(x, axis=axis)
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
 
 
 @pytest.mark.parametrize("axis", [None, 0, 1])
-def test_softmax(axis):
+def test_logsoftmax(axis):
     x = matrix("x")
     x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
-    out = aet_nnet.softmax(x, axis=axis)
+    out = aet_nnet.logsoftmax(x, axis=axis)
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
```
tests/link/test_numba.py

```diff
@@ -1930,20 +1930,27 @@ def test_Softmax(x, axis, exc):
 @pytest.mark.parametrize(
-    "x, exc",
+    "x, axis, exc",
     [
         (
             set_test_value(aet.vector(), rng.random(size=(2,)).astype(config.floatX)),
             None,
+            None,
         ),
         (
             set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
+            0,
+            None,
+        ),
+        (
+            set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
+            1,
             None,
         ),
     ],
 )
-def test_LogSoftmax(x, exc):
-    g = nnetb.LogSoftmax()(x)
+def test_LogSoftmax(x, axis, exc):
+    g = nnetb.LogSoftmax(axis=axis)(x)
     g_fg = FunctionGraph(outputs=[g])
 
     cm = contextlib.suppress() if exc is None else pytest.warns(exc)
```
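For reference, the updated test constructs the Op directly with an axis, so an equivalent standalone usage would look roughly like this (illustrative; `nnetb` is `aesara.tensor.nnet.basic`, as in the test module):

```python
import aesara.tensor as aet
from aesara.tensor.nnet import basic as nnetb

x = aet.matrix("x")
g = nnetb.LogSoftmax(axis=0)(x)  # normalize along axis 0
```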
tests/tensor/nnet/test_basic.py

```diff
@@ -47,7 +47,6 @@ from aesara.tensor.nnet.basic import (
     elu,
     h_softmax,
     logsoftmax,
-    logsoftmax_op,
     relu,
     selu,
     sigmoid_binary_crossentropy,
@@ -205,47 +204,28 @@ class TestSoftmaxWithBias(utt.InferShapeTester):
 class TestLogSoftmax(utt.InferShapeTester):
-    def test_basic(self):
-        def f(a):
-            return logsoftmax_op(a)[:, 0]
-
-        utt.verify_grad(f, [np.random.random((3, 4))])
-
-        def f(a):
-            return logsoftmax_op(a)[:, 1]
-
-        utt.verify_grad(f, [np.random.random((3, 4))])
-
-        def f(a):
-            return logsoftmax_op(a)[:, 2]
-
-        utt.verify_grad(f, [np.random.random((3, 4))])
-
-        def f(a):
-            return logsoftmax_op(a)[:, 3]
-
-        utt.verify_grad(f, [np.random.random((3, 4))])
-
-    def test_matrix(self):
+    @pytest.mark.parametrize("column", [0, 1, 2, 3])
+    @pytest.mark.parametrize("axis", [None, 0, 1])
+    def test_matrix_grad(self, axis, column):
         def f(a):
-            return logsoftmax_op(a)
+            return logsoftmax(a, axis=axis)[:, column]
 
         utt.verify_grad(f, [np.random.random((3, 4))])
 
-    def test_vector(self):
+    def test_vector_perform(self):
         x = vector()
-        f = aesara.function([x], logsoftmax_op(x))
+        f = aesara.function([x], logsoftmax(x, axis=None))
 
         xv = np.random.randn(6).astype(config.floatX)
 
-        assert np.allclose(f(xv), np.log(np.exp(xv) / np.exp(xv).sum()))
+        assert np.allclose(f(xv), sp.log_softmax(xv))
 
     def test_vector_grad(self):
         def f(a):
-            return logsoftmax_op(a)
+            return logsoftmax(a, axis=None)
 
         utt.verify_grad(f, [np.random.random((4))])
 
-    def test_allclose(self):
+    def test_matrix_perform_and_opt(self):
         m = config.mode
         m = aesara.compile.get_mode(m)
         m.check_isfinite = False
@@ -284,18 +264,15 @@ class TestLogSoftmax(utt.InferShapeTester):
         grad_ = f3(a, b)
         assert not np.any(np.isnan(grad_))
 
-    def test_isclose(self):
-        def f(a):
-            return logsoftmax_op(a)
-
-    def test_local_softmax_optimization(self):
+    @pytest.mark.parametrize("axis", [None, 0, -1])
+    def test_local_logsoftmax_opt(self, axis):
         # Test the Logsoftmax substitution
         #
         # Check that Log(Softmax(x)) is substituted with Logsoftmax(x). Note that
         # only the forward pass is checked (i.e., doesn't check the gradient)
 
-        x, y = matrices("xy")
-        sm = softmax(x)
+        x = matrix("x")
+        sm = softmax(x, axis=axis)
         logsm = log(sm)
         f = aesara.function([x], logsm)
         assert isinstance(f.maker.fgraph.outputs[0].owner.op, LogSoftmax)
@@ -351,6 +328,9 @@ class TestLogSoftmax(utt.InferShapeTester):
         assert softmax_grad_legacy in [n.op for n in fgraph.toposort()]
 
+    def test_valid_axis(self):
+        valid_axis_tester(LogSoftmax)
+
 
 class TestSoftmaxGrad(utt.InferShapeTester):
     def test_infer_shape(self):
```
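The replaced assertion in `test_vector_perform` is numerically equivalent to the old hand-rolled formula; a quick cross-check, assuming `sp` is `scipy.special` as the new test implies:

```python
import numpy as np
import scipy.special as sp

xv = np.random.randn(6)
manual = np.log(np.exp(xv) / np.exp(xv).sum())  # old test's expected value
assert np.allclose(sp.log_softmax(xv), manual)
```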