Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
f2ad711f
提交
f2ad711f
authored
11月 04, 2024
作者:
ricardoV94
提交者:
Ricardo Vieira
11月 11, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement unconditional constant_folding rewrite
上级
a570dbfd
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
105 行增加
和
50 行删除
+105
-50
basic.py
pytensor/tensor/rewriting/basic.py
+19
-4
test_basic.py
tests/tensor/rewriting/test_basic.py
+86
-46
没有找到文件。
pytensor/tensor/rewriting/basic.py
浏览文件 @
f2ad711f
...
...
@@ -32,6 +32,7 @@ from pytensor.compile.ops import ViewOp
from
pytensor.graph
import
FunctionGraph
from
pytensor.graph.basic
import
Constant
,
Variable
from
pytensor.graph.rewriting.basic
import
(
NodeProcessingGraphRewriter
,
NodeRewriter
,
RemovalNodeRewriter
,
Rewriter
,
...
...
@@ -1101,10 +1102,7 @@ def local_useless_split(fgraph, node):
@node_rewriter
(
None
)
def
constant_folding
(
fgraph
,
node
):
if
not
node
.
op
.
do_constant_folding
(
fgraph
,
node
):
return
False
def
unconditional_constant_folding
(
fgraph
,
node
):
if
not
all
(
isinstance
(
inp
,
Constant
)
for
inp
in
node
.
inputs
):
return
False
...
...
@@ -1151,6 +1149,23 @@ def constant_folding(fgraph, node):
return
rval
topo_unconditional_constant_folding
=
in2out
(
unconditional_constant_folding
,
ignore_newtrees
=
True
,
name
=
"topo_unconditional_constant_folding"
,
# Not all Ops have a perform method, so we ignore failures to constant_fold
failure_callback
=
NodeProcessingGraphRewriter
.
warn_ignore
,
)
@node_rewriter
(
None
)
def
constant_folding
(
fgraph
,
node
):
if
not
node
.
op
.
do_constant_folding
(
fgraph
,
node
):
return
False
return
unconditional_constant_folding
.
transform
(
fgraph
,
node
)
topo_constant_folding
=
in2out
(
constant_folding
,
ignore_newtrees
=
True
,
name
=
"topo_constant_folding"
)
...
...
tests/tensor/rewriting/test_basic.py
浏览文件 @
f2ad711f
...
...
@@ -12,7 +12,8 @@ from pytensor.compile.function import function
from
pytensor.compile.mode
import
get_default_mode
,
get_mode
from
pytensor.compile.ops
import
DeepCopyOp
,
deep_copy_op
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
equal_computations
from
pytensor.graph
import
Op
from
pytensor.graph.basic
import
Constant
,
equal_computations
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.rewriting.basic
import
check_stack_trace
,
out2in
from
pytensor.graph.rewriting.db
import
RewriteDatabaseQuery
...
...
@@ -29,6 +30,7 @@ from pytensor.tensor.basic import (
TensorFromScalar
,
as_tensor
,
cast
,
constant
,
join
,
tile
,
)
...
...
@@ -65,6 +67,8 @@ from pytensor.tensor.rewriting.basic import (
local_merge_alloc
,
local_useless_alloc
,
local_useless_elemwise
,
topo_constant_folding
,
topo_unconditional_constant_folding
,
topological_fill_sink
,
)
from
pytensor.tensor.rewriting.math
import
local_lift_transpose_through_dot
...
...
@@ -742,56 +746,92 @@ class TestCastCast:
)
or
(
len
(
topo
)
>
1
)
def
test_constant_folding
():
# Test that constant folding get registered at fast_compile
# An error removed that registration during the registration.
x
=
dvector
()
mode
=
get_mode
(
"FAST_COMPILE"
)
.
excluding
(
"fusion"
)
f
=
function
([
x
],
[
x
*
2
,
x
+
x
],
mode
=
mode
)
topo
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
(
topo
)
==
2
# Test that we do not crash when constant folding elemwise scalar
# as they should not generate c code.
class
TestConstantFolding
:
def
test_constant_folding
(
self
):
# Test that constant folding get registered at fast_compile
# An error removed that registration during the registration.
x
=
dvector
()
mode
=
get_mode
(
"FAST_COMPILE"
)
.
excluding
(
"fusion"
)
f
=
function
([
x
],
[
x
*
2
,
x
+
x
],
mode
=
mode
)
topo
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
(
topo
)
==
2
x
=
pt
.
constant
(
3
)
assert
x
.
ndim
==
0
mode
=
get_mode
(
"FAST_COMPILE"
)
.
excluding
(
"fusion"
)
f
=
function
([],
[
x
*
2
,
x
+
x
],
mode
=
mode
)
topo
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
(
topo
)
==
2
assert
all
(
isinstance
(
n
.
op
,
DeepCopyOp
)
for
n
in
topo
)
# Test that we do not crash when constant folding elemwise scalar
# as they should not generate c code.
x
=
pt
.
constant
(
3
)
assert
x
.
ndim
==
0
mode
=
get_mode
(
"FAST_COMPILE"
)
.
excluding
(
"fusion"
)
f
=
function
([],
[
x
*
2
,
x
+
x
],
mode
=
mode
)
topo
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
(
topo
)
==
2
assert
all
(
isinstance
(
n
.
op
,
DeepCopyOp
)
for
n
in
topo
)
@pytest.mark.xfail
(
reason
=
"PyTensor rewrites constants before stabilization. "
"This breaks stabilization rewrites in some cases. See #504."
,
raises
=
AssertionError
,
)
def
test_constant_get_stabilized
(
):
# Currently PyTensor enables the `constant_folding` rewrite before stabilization rewrites.
# This caused some stabilization rewrites to not be activated and that
# caused inf values to appear when they should not.
@pytest.mark.xfail
(
reason
=
"PyTensor rewrites constants before stabilization. "
"This breaks stabilization rewrites in some cases. See #504."
,
raises
=
AssertionError
,
)
def
test_constant_get_stabilized
(
self
):
# Currently PyTensor enables the `constant_folding` rewrite before stabilization rewrites.
# This caused some stabilization rewrites to not be activated and that
# caused inf values to appear when they should not.
# We can't simply move the `constant_folding` rewrite to
# specialize since this will break other rewrites. We will need to
# partially duplicate some canonicalize rewrites to fix this issue.
# We can't simply move the `constant_folding` rewrite to
# specialize since this will break other rewrites. We will need to
# partially duplicate some canonicalize rewrites to fix this issue.
x2
=
scalar
()
y2
=
log
(
1
+
exp
(
x2
))
mode
=
get_default_mode
()
mode
.
check_isfinite
=
False
f2
=
function
([
x2
],
y2
,
mode
=
mode
)
assert
len
(
f2
.
maker
.
fgraph
.
toposort
())
==
1
assert
f2
.
maker
.
fgraph
.
toposort
()[
0
]
.
op
==
softplus
assert
f2
(
800
)
==
800
x
=
pt
.
as_tensor_variable
(
800
)
y
=
log
(
1
+
exp
(
x
))
f
=
function
([],
y
,
mode
=
mode
)
# When this error is fixed, the following line should be ok.
assert
f
()
==
800
,
f
()
x2
=
scalar
()
y2
=
log
(
1
+
exp
(
x2
))
mode
=
get_default_mode
()
mode
.
check_isfinite
=
False
f2
=
function
([
x2
],
y2
,
mode
=
mode
)
assert
len
(
f2
.
maker
.
fgraph
.
toposort
())
==
1
assert
f2
.
maker
.
fgraph
.
toposort
()[
0
]
.
op
==
softplus
assert
f2
(
800
)
==
800
x
=
pt
.
as_tensor_variable
(
800
)
y
=
log
(
1
+
exp
(
x
))
f
=
function
([],
y
,
mode
=
mode
)
# When this error is fixed, the following line should be ok.
assert
f
()
==
800
,
f
()
def
test_unconditional
(
self
):
x
=
pt
.
alloc
(
np
.
e
,
*
(
3
,
5
))
fg
=
FunctionGraph
(
outputs
=
[
x
],
clone
=
False
)
# Default constant folding doesn't apply to Alloc used as outputs
topo_constant_folding
.
apply
(
fg
)
assert
not
isinstance
(
fg
.
outputs
[
0
],
Constant
)
# Unconditional constant folding does apply
topo_unconditional_constant_folding
.
apply
(
fg
)
assert
isinstance
(
fg
.
outputs
[
0
],
Constant
)
np
.
testing
.
assert_allclose
(
fg
.
outputs
[
0
]
.
data
,
np
.
full
((
3
,
5
),
np
.
e
))
def
test_unconditional_no_perform_method
(
self
):
"""Test that errors are caught when the Op does not have a perform method."""
class
OpNoPerform
(
Op
):
itypes
=
[
scalar
(
dtype
=
"float64"
)
.
type
]
otypes
=
[
scalar
(
dtype
=
"float64"
)
.
type
]
def
perform
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
(
"This Op cannot be evaluated"
)
x
=
constant
(
np
.
array
(
5.0
))
out
=
OpNoPerform
()(
x
)
fg
=
FunctionGraph
(
outputs
=
[
out
],
clone
=
False
)
# Default constant_folding will raise
with
pytest
.
raises
(
NotImplementedError
):
topo_constant_folding
.
apply
(
fg
)
# Unconditional constant folding will be silent
topo_unconditional_constant_folding
.
apply
(
fg
)
assert
not
isinstance
(
fg
.
outputs
[
0
],
Constant
)
assert
isinstance
(
fg
.
outputs
[
0
]
.
owner
.
op
,
OpNoPerform
)
class
TestLocalSwitchSink
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论