Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ce0b503c
提交
ce0b503c
authored
6月 20, 2023
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
6月 29, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Extend log_softmax rewrite and run it in `stabilize`
上级
39bda72a
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
89 行增加
和
51 行删除
+89
-51
special.py
pytensor/tensor/rewriting/special.py
+56
-27
test_special.py
tests/tensor/rewriting/test_special.py
+33
-24
没有找到文件。
pytensor/tensor/rewriting/special.py
浏览文件 @
ce0b503c
from
pytensor
import
scalar
as
aes
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
node_rewriter
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
node_rewriter
from
pytensor.tensor.elemwise
import
DimShuffle
,
Elemwise
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.math
import
Sum
,
exp
from
pytensor.tensor.math
import
Sum
,
exp
,
log
from
pytensor.tensor.math
import
sum
as
at_sum
from
pytensor.tensor.math
import
sum
as
at_sum
from
pytensor.tensor.math
import
true_div
from
pytensor.tensor.math
import
true_div
from
pytensor.tensor.rewriting.basic
import
register_s
pecia
lize
from
pytensor.tensor.rewriting.basic
import
register_s
tabi
lize
from
pytensor.tensor.rewriting.math
import
local_mul_canonizer
from
pytensor.tensor.rewriting.math
import
local_mul_canonizer
from
pytensor.tensor.special
import
LogSoftmax
,
Softmax
,
SoftmaxGrad
from
pytensor.tensor.special
import
Softmax
,
SoftmaxGrad
,
log_softmax
from
pytensor.tensor.subtensor
import
AdvancedIncSubtensor
from
pytensor.tensor.subtensor
import
(
AdvancedIncSubtensor
,
AdvancedSubtensor
,
AdvancedSubtensor1
,
Subtensor
,
)
from
pytensor.tensor.type
import
(
from
pytensor.tensor.type
import
(
values_eq_approx_remove_inf
,
values_eq_approx_remove_inf
,
values_eq_approx_remove_nan
,
values_eq_approx_remove_nan
,
)
)
# Subtensor-like ops that may safely be lifted above a log(softmax(x)) pair:
# indexing commutes with the (elementwise) log, so it can be re-applied on top
# of the fused LogSoftmax output.
subtensor_ops = (
    Subtensor,
    AdvancedSubtensor,
    AdvancedSubtensor1,
)


@register_stabilize
@node_rewriter([log])
def local_logsoftmax(fgraph, node):
    """Rewrite ``log(softmax(x))`` into the numerically stable ``log_softmax(x)``.

    Subtensor or DimShuffle operations sitting between the ``log`` and the
    ``Softmax`` are lifted above the new ``LogSoftmax`` so the fusion still
    applies.

    Note: only the forward pass is affected; the gradient case is handled by
    a separate rewrite.
    """

    def find_softmax_under_lifteable_ops(inp_node, ops_to_lift):
        # Walk down through lifteable ops, recording each (op, extra_inputs)
        # pair so it can be re-applied on top of the rewritten graph later.
        # Returns the Softmax apply node, or None if none is found.
        if inp_node is None:
            return

        if isinstance(inp_node.op, Softmax):
            return inp_node

        if isinstance(inp_node.op, subtensor_ops):
            # Keep the index inputs (everything after the indexed tensor).
            ops_to_lift.append((inp_node.op, inp_node.inputs[1:]))
            return find_softmax_under_lifteable_ops(
                inp_node.inputs[0].owner, ops_to_lift
            )

        if isinstance(inp_node.op, DimShuffle):
            # DimShuffle takes no extra inputs.
            ops_to_lift.append((inp_node.op, ()))
            return find_softmax_under_lifteable_ops(
                inp_node.inputs[0].owner, ops_to_lift
            )

    collected_ops = []
    softmax_node = find_softmax_under_lifteable_ops(node.inputs[0].owner, collected_ops)

    if softmax_node is None:
        # No softmax reachable below the log through lifteable ops: no rewrite.
        return

    ret = log_softmax(softmax_node.inputs[0], axis=softmax_node.op.axis)
    # The stabilized form avoids -inf where log(softmax(x)) would underflow,
    # so exclude inf from value comparisons in DebugMode.
    ret.tag.values_eq_approx = values_eq_approx_remove_inf

    # Re-apply (lift) the ops that used to sit between log and softmax,
    # innermost first.
    for lifted_op, extra_inputs in reversed(collected_ops):
        ret = lifted_op(ret, *extra_inputs)

    copy_stack_trace(node.outputs, ret)
    return [ret]
# This is not registered in stabilize, as it cause some crossentropy
@register_stabilize
# optimization to not be inserted.
@register_specialize
(
"stabilize"
,
"fast_compile"
)
@node_rewriter
([
SoftmaxGrad
])
@node_rewriter
([
SoftmaxGrad
])
def
local_logsoftmax_grad
(
fgraph
,
node
):
def
local_logsoftmax_grad
(
fgraph
,
node
):
"""
"""
...
@@ -50,9 +81,7 @@ def local_logsoftmax_grad(fgraph, node):
...
@@ -50,9 +81,7 @@ def local_logsoftmax_grad(fgraph, node):
Note: only grad is affected
Note: only grad is affected
"""
"""
if
(
if
(
isinstance
(
node
.
op
,
SoftmaxGrad
)
node
.
inputs
[
0
]
.
owner
is
not
None
and
len
(
node
.
inputs
)
==
2
and
node
.
inputs
[
0
]
.
owner
is
not
None
and
node
.
inputs
[
0
]
.
owner
.
op
==
true_div
and
node
.
inputs
[
0
]
.
owner
.
op
==
true_div
and
len
(
node
.
inputs
[
0
]
.
owner
.
inputs
)
>=
2
and
len
(
node
.
inputs
[
0
]
.
owner
.
inputs
)
>=
2
and
node
.
inputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
is
not
None
and
node
.
inputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
is
not
None
...
...
tests/tensor/rewriting/test_special.py
浏览文件 @
ce0b503c
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
scipy.special
import
pytensor
import
pytensor
from
pytensor
import
shared
from
pytensor
import
shared
...
@@ -35,6 +36,37 @@ class TestLogSoftmaxRewrites:
...
@@ -35,6 +36,37 @@ class TestLogSoftmaxRewrites:
_fast_run_rewrites
.
rewrite
(
fgraph
)
_fast_run_rewrites
.
rewrite
(
fgraph
)
assert
isinstance
(
fgraph
.
outputs
[
0
]
.
owner
.
op
,
LogSoftmax
)
assert
isinstance
(
fgraph
.
outputs
[
0
]
.
owner
.
op
,
LogSoftmax
)
assert
check_stack_trace
(
fgraph
,
ops_to_check
=
LogSoftmax
)
assert
check_stack_trace
(
fgraph
,
ops_to_check
=
LogSoftmax
)
assert
check_stack_trace
(
fgraph
,
ops_to_check
=
"all"
)
@pytest.mark.parametrize("axis", [None, 0, -1])
@pytest.mark.parametrize("idx0", [0, slice(1, None), slice(None)])
@pytest.mark.parametrize("idx1", [None, [0, 1, 1, -1]])
def test_logsoftmax_subtensor_dimshuffle(self, axis, idx0, idx1):
    """Test that stabilization is introduced even when subtensor or dimshuffle
    operations are present between log and softmax.
    """
    logit_p = matrix("logit_p")
    p = softmax(logit_p, axis=axis)
    p_indexed = p[(idx0, idx1)]
    out = log(p_indexed)

    # Don't waste time with C compilation
    with config.change_flags(cxx=""):
        mode = get_mode(None).including("stabilize")
        fn = pytensor.function([logit_p], out, mode=mode)

    # After stabilization no raw Softmax should survive in the graph.
    graph_ops = (node.op for node in fn.maker.fgraph.apply_nodes)
    assert not any(isinstance(op, Softmax) for op in graph_ops)

    # This range would lead to underflow to -inf without the stabilization
    test_logit_p = np.array(
        [[-10.0, -10.0, 999.0], [999.0, 990.0, -10.0]],
        dtype=config.floatX,
    )
    expected = scipy.special.log_softmax(test_logit_p, axis=axis)[(idx0, idx1)]
    np.testing.assert_allclose(fn(logit_p=test_logit_p), expected)
@pytest.mark.parametrize
(
"axis"
,
[
None
,
0
,
-
1
])
@pytest.mark.parametrize
(
"axis"
,
[
None
,
0
,
-
1
])
def
test_local_logsoftmax_grad_rewrite
(
self
,
axis
):
def
test_local_logsoftmax_grad_rewrite
(
self
,
axis
):
...
@@ -46,7 +78,7 @@ class TestLogSoftmaxRewrites:
...
@@ -46,7 +78,7 @@ class TestLogSoftmaxRewrites:
"""
"""
m
=
config
.
mode
m
=
config
.
mode
m
=
get_mode
(
m
)
m
=
get_mode
(
m
)
.
including
(
"stabilize"
)
m
.
check_isfinite
=
False
m
.
check_isfinite
=
False
# some inputs that are large to make the gradient explode in the non
# some inputs that are large to make the gradient explode in the non
# rewritten case
# rewritten case
...
@@ -91,29 +123,6 @@ class TestLogSoftmaxRewrites:
...
@@ -91,29 +123,6 @@ class TestLogSoftmaxRewrites:
assert
SoftmaxGrad
(
axis
=-
1
)
in
[
n
.
op
for
n
in
fgraph
.
toposort
()]
assert
SoftmaxGrad
(
axis
=-
1
)
in
[
n
.
op
for
n
in
fgraph
.
toposort
()]
def test_log_softmax_stabilization():
    """Check that log(softmax(x)) is stabilized and still computes correctly."""
    mode = pytensor.compile.mode.get_default_mode()
    mode = mode.including("local_log_softmax", "specialize")

    x = matrix()
    z = log(softmax(x, axis=-1))

    fgraph = FunctionGraph([x], [z])
    _fast_run_rewrites(fgraph)
    assert check_stack_trace(fgraph, ops_to_check="all")

    # Check that the softmax has been rewritten
    assert not any(isinstance(node.op, Softmax) for node in fgraph.toposort())

    # Call the function so debug mode can verify the rewritten version matches
    # the un-rewritten version
    f = pytensor.function([x], z, mode=mode)
    rng = np.random.default_rng(utt.fetch_seed())
    f(np.cast[config.floatX](rng.random((2, 3))))
def
test_softmax_graph
():
def
test_softmax_graph
():
"""Make sure that sotfmax expressions are turned into
"""Make sure that sotfmax expressions are turned into
a softmax Op.
a softmax Op.
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论