Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
6cd90ee9
提交
6cd90ee9
authored
7月 04, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
7月 09, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix RaiseAndCheck C implementation with tensor conditions.
For performance, the Op now always converts the inputs to boolean scalars. Also do not constant-fold if it would raise.
上级
b9fc4f8e
显示空白字符变更
内嵌
并排
正在显示
8 个修改的文件
包含
66 行增加
和
73 行删除
+66
-73
basic.py
pytensor/link/jax/dispatch/basic.py
+8
-4
basic.py
pytensor/link/pytorch/dispatch/basic.py
+10
-1
raise_op.py
pytensor/raise_op.py
+17
-46
basic.py
pytensor/tensor/rewriting/basic.py
+2
-7
test_basic.py
tests/tensor/rewriting/test_basic.py
+6
-6
test_elemwise.py
tests/tensor/rewriting/test_elemwise.py
+1
-1
test_extra_ops.py
tests/tensor/test_extra_ops.py
+10
-3
test_raise_op.py
tests/test_raise_op.py
+12
-5
没有找到文件。
pytensor/link/jax/dispatch/basic.py
浏览文件 @
6cd90ee9
...
...
@@ -10,10 +10,11 @@ from pytensor.compile import JAX
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.ops
import
DeepCopyOp
,
TypeCastingOp
from
pytensor.configdefaults
import
config
from
pytensor.graph
import
Constant
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.ifelse
import
IfElse
from
pytensor.link.utils
import
fgraph_to_python
from
pytensor.raise_op
import
Assert
,
CheckAndRaise
from
pytensor.raise_op
import
CheckAndRaise
if
config
.
floatX
==
"float64"
:
...
...
@@ -73,11 +74,14 @@ def jax_funcify_IfElse(op, **kwargs):
return
ifelse
@jax_funcify.register
(
Assert
)
@jax_funcify.register
(
CheckAndRaise
)
def
jax_funcify_CheckAndRaise
(
op
,
**
kwargs
):
def
jax_funcify_CheckAndRaise
(
op
,
node
,
**
kwargs
):
conds
=
node
.
inputs
[
1
:]
if
any
(
isinstance
(
cond
,
Constant
)
and
not
bool
(
cond
.
data
)
for
cond
in
conds
):
raise
op
.
exc_type
(
op
.
msg
)
warnings
.
warn
(
f
"""Skipping
`CheckAndRaise`
Op (assertion: {op.msg}) as JAX tracing would remove it."""
,
f
"""Skipping
{op}
Op (assertion: {op.msg}) as JAX tracing would remove it."""
,
stacklevel
=
2
,
)
...
...
pytensor/link/pytorch/dispatch/basic.py
浏览文件 @
6cd90ee9
...
...
@@ -22,6 +22,7 @@ from pytensor.tensor.basic import (
Eye
,
Join
,
MakeVector
,
ScalarFromTensor
,
Split
,
TensorFromScalar
,
)
...
...
@@ -79,6 +80,14 @@ def pytorch_funcify_CastingOp(op, node, **kwargs):
return
type_cast
@pytorch_funcify.register
(
ScalarFromTensor
)
def
pytorch_funcify_ScalarFromTensor
(
op
,
node
,
**
kwargs
):
def
scalar_from_tensor
(
x
):
return
x
[()]
return
scalar_from_tensor
@pytorch_funcify.register
(
CheckAndRaise
)
def
pytorch_funcify_CheckAndRaise
(
op
,
**
kwargs
):
error
=
op
.
exc_type
...
...
@@ -86,7 +95,7 @@ def pytorch_funcify_CheckAndRaise(op, **kwargs):
def
assert_fn
(
x
,
*
conditions
):
for
cond
in
conditions
:
if
not
cond
.
item
()
:
if
not
cond
:
raise
error
(
msg
)
return
x
...
...
pytensor/raise_op.py
浏览文件 @
6cd90ee9
...
...
@@ -2,15 +2,13 @@
from
textwrap
import
indent
import
numpy
as
np
from
pytensor.gradient
import
DisconnectedType
from
pytensor.graph.basic
import
Apply
,
Variable
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.link.c.op
import
COp
from
pytensor.link.c.params_type
import
ParamsType
from
pytensor.link.c.type
import
Generic
from
pytensor.scalar.basic
import
ScalarType
from
pytensor.scalar.basic
import
ScalarType
,
as_scalar
from
pytensor.tensor.type
import
DenseTensorType
...
...
@@ -56,18 +54,6 @@ class CheckAndRaise(COp):
msg
=
self
.
msg
return
f
"{name}{{raises={exc_name}, msg='{msg}'}}"
def
__eq__
(
self
,
other
):
if
type
(
self
)
is
not
type
(
other
):
return
False
if
self
.
msg
==
other
.
msg
and
self
.
exc_type
==
other
.
exc_type
:
return
True
return
False
def
__hash__
(
self
):
return
hash
((
self
.
msg
,
self
.
exc_type
))
def
make_node
(
self
,
value
:
Variable
,
*
conds
:
Variable
):
"""
...
...
@@ -84,12 +70,10 @@ class CheckAndRaise(COp):
if
not
isinstance
(
value
,
Variable
):
value
=
pt
.
as_tensor_variable
(
value
)
conds
=
[
pt
.
as_tensor_variable
(
c
)
if
not
isinstance
(
c
,
Variable
)
else
c
for
c
in
conds
]
assert
all
(
c
.
type
.
ndim
==
0
for
c
in
conds
)
conds
=
[
as_scalar
(
c
)
for
c
in
conds
]
for
i
,
cond
in
enumerate
(
conds
):
if
cond
.
dtype
!=
"bool"
:
conds
[
i
]
=
cond
.
astype
(
"bool"
)
return
Apply
(
self
,
...
...
@@ -101,7 +85,7 @@ class CheckAndRaise(COp):
(
out
,)
=
outputs
val
,
*
conds
=
inputs
out
[
0
]
=
val
if
not
np
.
all
(
conds
):
if
not
all
(
conds
):
raise
self
.
exc_type
(
self
.
msg
)
def
grad
(
self
,
input
,
output_gradients
):
...
...
@@ -117,28 +101,13 @@ class CheckAndRaise(COp):
)
value_name
,
*
cond_names
=
inames
out_name
=
onames
[
0
]
check
=
[]
fail_code
=
props
[
"fail"
]
param_struct_name
=
props
[
"params"
]
msg
=
self
.
msg
.
replace
(
'"'
,
'
\\
"'
)
.
replace
(
"
\n
"
,
"
\\
n"
)
for
idx
,
cond_name
in
enumerate
(
cond_names
):
if
isinstance
(
node
.
inputs
[
0
]
.
type
,
DenseTensorType
):
check
.
append
(
f
"""
if(PyObject_IsTrue((PyObject *){cond_name}) == 0) {{
PyObject * exc_type = {param_struct_name}->exc_type;
Py_INCREF(exc_type);
PyErr_SetString(exc_type, "{msg}");
Py_XDECREF(exc_type);
{indent(fail_code, " " * 4)}
}}
"""
)
else
:
check
.
append
(
f
"""
if({cond_name} == 0) {{
all_conds
=
" && "
.
join
(
cond_names
)
check
=
f
"""
if(!({all_conds})) {{
PyObject * exc_type = {param_struct_name}->exc_type;
Py_INCREF(exc_type);
PyErr_SetString(exc_type, "{msg}");
...
...
@@ -146,9 +115,6 @@ class CheckAndRaise(COp):
{indent(fail_code, " " * 4)}
}}
"""
)
check
=
"
\n
"
.
join
(
check
)
if
isinstance
(
node
.
inputs
[
0
]
.
type
,
DenseTensorType
):
res
=
f
"""
...
...
@@ -162,14 +128,19 @@ class CheckAndRaise(COp):
{check}
{out_name} = {value_name};
"""
return
res
return
"
\n
"
.
join
((
check
,
res
))
def
c_code_cache_version
(
self
):
return
(
1
,
1
)
return
(
2
,
)
def
infer_shape
(
self
,
fgraph
,
node
,
input_shapes
):
return
[
input_shapes
[
0
]]
def
do_constant_folding
(
self
,
fgraph
,
node
):
# Only constant-fold if the Assert does not fail
return
all
((
isinstance
(
c
,
Constant
)
and
bool
(
c
.
data
))
for
c
in
node
.
inputs
[
1
:])
class
Assert
(
CheckAndRaise
):
"""Implements assertion in a computational graph.
...
...
pytensor/tensor/rewriting/basic.py
浏览文件 @
6cd90ee9
...
...
@@ -732,20 +732,15 @@ def is_an_upcast(type1, type2):
@register_useless
@register_specialize
@node_rewriter
(
None
)
@node_rewriter
(
[
CheckAndRaise
]
)
def
local_remove_useless_assert
(
fgraph
,
node
):
if
not
isinstance
(
node
.
op
,
CheckAndRaise
):
return
False
new_conds
=
[]
n_conds
=
len
(
node
.
inputs
[
1
:])
for
c
in
node
.
inputs
[
1
:]:
try
:
const
=
get_scalar_constant_value
(
c
)
if
0
!=
const
.
ndim
or
const
==
0
:
# Should we raise an error here? How to be sure it
# is not caught?
if
not
const
:
new_conds
.
append
(
c
)
except
NotScalarConstantError
:
new_conds
.
append
(
c
)
...
...
tests/tensor/rewriting/test_basic.py
浏览文件 @
6cd90ee9
...
...
@@ -487,8 +487,8 @@ class TestUselessCheckAndRaise:
def
test_local_remove_useless_2
(
self
):
"""Remove `CheckAndRaise` conditions that are always true."""
x
=
scalar
()
y
=
scalar
(
)
x
=
scalar
(
"x"
)
y
=
ps
.
bool
(
"y"
)
fg
=
FunctionGraph
(
outputs
=
[
assert_op
(
x
,
y
,
1
)],
clone
=
False
)
fg_res
=
rewrite_graph
(
fg
,
include
=
[
"canonicalize"
,
"specialize"
])
topo
=
fg_res
.
toposort
()
...
...
@@ -497,8 +497,8 @@ class TestUselessCheckAndRaise:
def
test_local_remove_useless_3
(
self
):
"""Don't remove `CheckAndRaise` conditions that are always false."""
x
=
scalar
()
y
=
scalar
(
)
x
=
scalar
(
"x"
)
y
=
ps
.
bool
(
"y"
)
fg
=
FunctionGraph
(
outputs
=
[
assert_op
(
x
,
y
,
0
)],
clone
=
False
)
fg_res
=
rewrite_graph
(
fg
,
include
=
[
"canonicalize"
,
"specialize"
])
topo
=
fg_res
.
toposort
()
...
...
@@ -1559,7 +1559,7 @@ def test_local_merge_alloc():
output
=
pt
.
alloc
(
pt
.
alloc
(
m
,
y
,
1
,
1
),
x
,
y2
,
z
,
w
)
f
=
function
([
m
,
x
,
y
,
y2
,
z
,
w
],
output
,
mode
=
rewrite_mode
)
topo
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
(
topo
)
==
3
assert
len
(
topo
)
==
4
assert
isinstance
(
topo
[
-
2
]
.
op
,
Assert
)
assert
isinstance
(
topo
[
-
1
]
.
op
,
Alloc
)
o
=
f
(
0.0
,
1
,
2
,
2
,
3
,
4
)
...
...
@@ -1616,7 +1616,7 @@ def test_local_useless_alloc():
useless_alloc
.
rewrite
(
g
)
topo
=
g
.
toposort
()
assert
len
(
topo
)
==
3
assert
len
(
topo
)
==
4
assert
isinstance
(
topo
[
-
2
]
.
op
,
Assert
)
assert
isinstance
(
topo
[
-
1
]
.
op
,
Alloc
)
...
...
tests/tensor/rewriting/test_elemwise.py
浏览文件 @
6cd90ee9
...
...
@@ -932,7 +932,7 @@ class TestFusion:
),
(
fx
,),
(
fxv
,),
4
,
5
,
(
np
.
zeros_like
(
fxv
),),
(
"float32"
,),
),
...
...
tests/tensor/test_extra_ops.py
浏览文件 @
6cd90ee9
...
...
@@ -8,6 +8,7 @@ from pytensor import function
from
pytensor
import
tensor
as
pt
from
pytensor.compile.mode
import
Mode
from
pytensor.configdefaults
import
config
from
pytensor.graph
import
rewrite_graph
from
pytensor.graph.basic
import
Constant
,
applys_between
,
equal_computations
from
pytensor.npy_2_compat
import
old_np_unique
from
pytensor.raise_op
import
Assert
...
...
@@ -1252,11 +1253,17 @@ def test_broadcast_shape_symbolic_one_symbolic():
]
res_shape
=
broadcast_shape
(
*
index_shapes
,
arrays_are_shapes
=
True
)
from
pytensor.graph.rewriting.utils
import
rewrite_graph
res_shape
=
rewrite_graph
(
res_shape
)
assert
res_shape
[
0
]
.
data
==
1
assert
res_shape
[
1
]
.
data
==
1
with
pytest
.
raises
(
AssertionError
,
match
=
"Could not broadcast dimensions"
):
# broadcast_shape doesn't treat int_div as a constant 1
res_shape
[
2
]
.
eval
()
res_shape
=
broadcast_shape
(
*
index_shapes
,
arrays_are_shapes
=
True
,
allow_runtime_broadcast
=
True
)
res_shape
=
rewrite_graph
(
res_shape
)
assert
res_shape
[
0
]
.
data
==
1
assert
res_shape
[
1
]
.
data
==
1
assert
res_shape
[
2
]
.
data
==
3
...
...
tests/test_raise_op.py
浏览文件 @
6cd90ee9
...
...
@@ -82,19 +82,26 @@ def test_CheckAndRaise_basic_c(linker):
with
pytest
.
raises
(
CustomException
,
match
=
exc_msg
):
y_fn
(
0
)
assert
y_fn
(
1
)
==
1.0
x
=
pt
.
vector
()
x_val
=
np
.
array
([
1.0
],
dtype
=
pytensor
.
config
.
floatX
)
y
=
check_and_raise
(
x
,
conds
)
y_fn
=
pytensor
.
function
([
conds
,
x
],
y
.
shape
,
mode
=
Mode
(
linker
,
OPT_FAST_RUN
))
y_fn
=
pytensor
.
function
([
conds
,
x
],
y
,
mode
=
Mode
(
linker
,
OPT_FAST_RUN
))
with
pytest
.
raises
(
CustomException
,
match
=
exc_msg
):
y_fn
(
0
,
x_val
)
assert
np
.
array_equal
(
y_fn
(
1
,
x_val
),
x_val
)
x_val
=
np
.
array
([
1.0
],
dtype
=
pytensor
.
config
.
floatX
)
y_fn
=
pytensor
.
function
([
conds
,
x
],
y
.
shape
,
mode
=
Mode
(
linker
,
OPT_FAST_RUN
))
# The shape doesn't depend on y so the Assert is dropped from the graph
assert
np
.
array_equal
(
y_fn
(
0
,
x_val
),
x_val
)
y
=
check_and_raise
(
x
,
pt
.
as_tensor
(
0
))
y_grad
=
pytensor
.
grad
(
y
.
sum
(),
[
x
]
)
y_grad
=
pytensor
.
grad
(
y
.
sum
(),
x
)
y_fn
=
pytensor
.
function
([
x
],
y_grad
,
mode
=
Mode
(
linker
,
OPT_FAST_RUN
))
assert
np
.
array_equal
(
y_fn
(
x_val
),
[
x_val
]
)
# The gradient doesn't depend on y, just it's shape so the Assert is dropped from the graph
assert
np
.
array_equal
(
y_fn
(
x_val
),
x_val
)
@pytest.mark.parametrize
(
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论