testgroup / pytensor · Commits · 58cb5c30

Commit 58cb5c30
Authored Nov 24, 2021 by Ricardo
Committed by Brandon T. Willard on Dec 13, 2021

Add axis to Softmax and SoftmaxGrad Ops

Parent: c6c85acb

Showing 10 changed files with 42 additions and 22 deletions:
aesara/link/jax/dispatch.py              +3   -1
aesara/link/numba/dispatch/elemwise.py   +13  -6
aesara/tensor/nnet/__init__.py           +2   -2
aesara/tensor/nnet/basic.py              +0   -0
tests/gpuarray/test_dnn.py               +3   -3
tests/gpuarray/test_nnet.py              +3   -3
tests/link/test_jax.py                   +7   -2
tests/link/test_numba.py                 +10  -3
tests/tensor/nnet/test_basic.py          +0   -0
tests/test_rop.py                        +1   -2
aesara/link/jax/dispatch.py

@@ -198,8 +198,10 @@ def jax_funcify_Identity(op, **kwargs):
 @jax_funcify.register(Softmax)
 def jax_funcify_Softmax(op, **kwargs):
+    axis = op.axis
+
     def softmax(x):
-        return jax.nn.softmax(x)
+        return jax.nn.softmax(x, axis=axis)

     return softmax
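For reference, the dispatcher above simply forwards the Op's axis to jax.nn.softmax, whose axis argument normalizes along the given dimension, with axis=None normalizing over the whole array. A standalone check of that behavior (not part of the commit):

    import jax.numpy as jnp
    from jax import nn

    x = jnp.arange(6.0).reshape(2, 3)
    print(nn.softmax(x, axis=1).sum(axis=1))  # each row sums to 1.0
    print(nn.softmax(x, axis=0).sum(axis=0))  # each column sums to 1.0
    print(nn.softmax(x, axis=None).sum())     # the whole array sums to 1.0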
aesara/link/numba/dispatch/elemwise.py

@@ -400,17 +400,24 @@ def numba_funcify_Softmax(op, node, **kwargs):
     x_at = node.inputs[0]
     x_dtype = x_at.type.numpy_dtype
     x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
+    axis = op.axis

-    # np.max(x, axis=1)
-    reduce_max = create_axis_reducer(np.maximum, -np.inf, 1, x_at.ndim, x_dtype)
-    # np.sum(x, axis=1)
-    reduce_sum = create_axis_reducer(np.add, 0.0, 1, x_at.ndim, x_dtype)
+    if axis is not None:
+        reduce_max = create_axis_reducer(
+            np.maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
+        )
+        reduce_sum = create_axis_reducer(
+            np.add, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
+        )
+    else:
+        reduce_max = np.max
+        reduce_sum = np.sum

     @numba.njit
     def softmax(x):
-        z = np.expand_dims(reduce_max(x), -1)
+        z = reduce_max(x)
         e_x = np.exp(x - z)
-        w = np.expand_dims(reduce_sum(e_x), -1)
+        w = reduce_sum(e_x)
         sm = e_x / w
         return sm
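In plain NumPy terms, the compiled kernel above is equivalent to the following reference sketch (softmax_ref is a hypothetical name, not the commit's code). The keepdims=True reducers are what let x - z and e_x / w broadcast when an axis is given:

    import numpy as np

    def softmax_ref(x, axis=None):
        # axis=None reduces over everything (scalar z and w);
        # otherwise keep the reduced dimension so broadcasting works.
        keep = axis is not None
        z = np.max(x, axis=axis, keepdims=keep)
        e_x = np.exp(x - z)
        return e_x / np.sum(e_x, axis=axis, keepdims=keep)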
aesara/tensor/nnet/__init__.py

@@ -35,9 +35,9 @@ from aesara.tensor.nnet.basic import (
     selu,
     sigmoid_binary_crossentropy,
     softmax,
-    softmax_grad,
+    softmax_grad_legacy,
     softmax_graph,
-    softmax_op,
+    softmax_legacy,
     softmax_simplifier,
     softmax_with_bias,
     softsign,
aesara/tensor/nnet/basic.py
(diff collapsed in the original view)
tests/gpuarray/test_dnn.py
浏览文件 @
58cb5c30
...
...
@@ -32,7 +32,7 @@ from aesara.tensor.math import (
sqrt
,
)
from
aesara.tensor.math
import
sum
as
aet_sum
from
aesara.tensor.nnet
import
batchnorm
,
conv2d
,
softmax
,
softmax_
op
from
aesara.tensor.nnet
import
batchnorm
,
conv2d
,
softmax
,
softmax_
legacy
from
aesara.tensor.nnet.abstract_conv
import
(
get_conv_gradinputs_shape
,
get_conv_output_shape
,
...
...
@@ -1456,7 +1456,7 @@ class TestSoftMax(test_nnet.TestSoftMax):
def
test_softmax_f16
(
self
):
x
=
matrix
(
"x"
,
"float16"
)
x_gpu
=
tensor4
(
"x_gpu"
,
"float16"
)
f_z
=
softmax_
op
f_z
=
softmax_
legacy
f_gpu
=
dnn
.
GpuDnnSoftmax
(
"accurate"
,
"channel"
)
def
cmp
(
n
,
m
,
f
,
f_gpu
):
...
...
@@ -1480,7 +1480,7 @@ class TestSoftMax(test_nnet.TestSoftMax):
x
=
matrix
(
"x"
)
x_gpu
=
tensor4
(
"x_gpu"
)
f_z
=
softmax_
op
f_z
=
softmax_
legacy
f_gpu
=
dnn
.
GpuDnnSoftmax
(
"accurate"
,
"channel"
)
# Verify the grad operation
...
...
tests/gpuarray/test_nnet.py

@@ -210,7 +210,7 @@ def softmax_unittest_template(dtypeInput):
     z = aesara.tensor.nnet.softmax(x)
     f = aesara.function([x], z, mode=mode_without_gpu)
     f_gpu = aesara.function([x], z, mode=mode_wo_cudnn)
-    assert f.maker.fgraph.toposort()[-1].op == aesara.tensor.nnet.softmax_op
+    assert f.maker.fgraph.toposort()[-1].op == aesara.tensor.nnet.softmax_legacy
     assert isinstance(f_gpu.maker.fgraph.toposort()[-2].op, GpuSoftmax)

     def cmp(n, m):

@@ -300,7 +300,7 @@ class TestSoftMax:
     def test_softmax(self):
         x = fmatrix("x")
-        z = aesara.tensor.nnet.softmax_op
+        z = aesara.tensor.nnet.softmax_legacy
         f, f_gpu = self._test_softmax(x, x, z, z, self._cmp)

@@ -308,7 +308,7 @@ class TestSoftMax:
     def test_softmax_shape_0(self):
         x = fmatrix("x")
-        z = aesara.tensor.nnet.softmax_op
+        z = aesara.tensor.nnet.softmax_legacy
         f, f_gpu = self._test_softmax(x, x, z, z, self._cmp)

     # Aesara can handle that case, but cudnn can't
tests/link/test_jax.py

@@ -969,11 +969,16 @@ def test_nnet():
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

-    out = aet_nnet.softmax(x)
+    out = aet_nnet.logsoftmax(x)
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

-    out = aet_nnet.logsoftmax(x)
+
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_softmax(axis):
+    x = matrix("x")
+    x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    out = aet_nnet.softmax(x, axis=axis)
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
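The new parametrized test also documents the user-facing keyword, softmax(x, axis=...). A minimal usage sketch along the lines of the test (variable names are illustrative):

    import aesara
    import aesara.tensor as aet
    import aesara.tensor.nnet as aet_nnet
    import numpy as np

    x = aet.matrix("x")
    f = aesara.function([x], aet_nnet.softmax(x, axis=0))  # normalize columns
    print(f(np.arange(6, dtype=aesara.config.floatX).reshape(2, 3)))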
tests/link/test_numba.py

@@ -1894,20 +1894,27 @@ def test_Dot(x, y, exc):
 @pytest.mark.parametrize(
-    "x, exc",
+    "x, axis, exc",
     [
         (
             set_test_value(aet.vector(), rng.random(size=(2,)).astype(config.floatX)),
+            None,
             None,
         ),
         (
             set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
+            None,
             None,
         ),
+        (
+            set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
+            0,
+            None,
+        ),
     ],
 )
-def test_Softmax(x, exc):
-    g = nnetb.Softmax()(x)
+def test_Softmax(x, axis, exc):
+    g = nnetb.Softmax(axis=axis)(x)
     g_fg = FunctionGraph(outputs=[g])

     cm = contextlib.suppress() if exc is None else pytest.warns(exc)
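At the Op level, the reduction axis now travels through the Softmax constructor, which is what the updated test signature exercises. A sketch of the two cases the parametrization covers (assuming nnetb aliases aesara.tensor.nnet.basic as in the test):

    import aesara.tensor as aet
    from aesara.tensor.nnet import basic as nnetb

    x = aet.matrix("x")
    g_col = nnetb.Softmax(axis=0)(x)     # normalize along axis 0
    g_all = nnetb.Softmax(axis=None)(x)  # axis=None: normalize over all elements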
tests/tensor/nnet/test_basic.py
(diff collapsed in the original view)
tests/test_rop.py

@@ -384,8 +384,7 @@ class TestRopLop(RopLopChecker):
         self.check_mat_rop_lop(self.mx.sum(axis=1), (self.mat_in_shape[0],))

     def test_softmax(self):
-        # Softmax adds an extra dimnesion !
-        self.check_rop_lop(aesara.tensor.nnet.softmax(self.x)[0], self.in_shape[0])
+        self.check_rop_lop(aesara.tensor.nnet.softmax(self.x), self.in_shape)

     def test_alloc(self):
         # Alloc of the sum of x into a vector