Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
aad6fb75
提交
aad6fb75
authored
10月 21, 2024
作者:
ricardoV94
提交者:
Ricardo Vieira
1月 13, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Deprecate `pytensor.get_underlying_scalar_constant`
上级
a120dc27
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
36 行增加
和
25 行删除
+36
-25
__init__.py
pytensor/__init__.py
+8
-7
gradient.py
pytensor/gradient.py
+6
-3
basic.py
pytensor/tensor/basic.py
+19
-12
test_elemwise.py
tests/tensor/test_elemwise.py
+3
-3
没有找到文件。
pytensor/__init__.py
浏览文件 @
aad6fb75
...
@@ -24,6 +24,7 @@ __docformat__ = "restructuredtext en"
...
@@ -24,6 +24,7 @@ __docformat__ = "restructuredtext en"
# pytensor code, since this code may want to log some messages.
# pytensor code, since this code may want to log some messages.
import
logging
import
logging
import
sys
import
sys
import
warnings
from
functools
import
singledispatch
from
functools
import
singledispatch
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
NoReturn
,
Optional
from
typing
import
Any
,
NoReturn
,
Optional
...
@@ -148,13 +149,13 @@ def get_underlying_scalar_constant(v):
...
@@ -148,13 +149,13 @@ def get_underlying_scalar_constant(v):
If `v` is not some view of constant data, then raise a
If `v` is not some view of constant data, then raise a
`NotScalarConstantError`.
`NotScalarConstantError`.
"""
"""
# Is it necessary to test for presence of pytensor.sparse at runtime?
warnings
.
warn
(
sparse
=
globals
()
.
get
(
"sparse"
)
"get_underlying_scalar_constant is deprecated. Use tensor.get_underlying_scalar_constant_value instead."
,
if
sparse
and
isinstance
(
v
.
type
,
sparse
.
SparseTensorType
):
FutureWarning
,
if
v
.
owner
is
not
None
and
isinstance
(
v
.
owner
.
op
,
sparse
.
CSM
):
)
data
=
v
.
owner
.
inputs
[
0
]
from
pytensor.tensor.basic
import
get_underlying_scalar_constant_value
return
tensor
.
get_underlying_scalar_constant_value
(
data
)
return
tensor
.
get_underlying_scalar_constant_value
(
v
)
return
get_underlying_scalar_constant_value
(
v
)
# isort: off
# isort: off
...
...
pytensor/gradient.py
浏览文件 @
aad6fb75
...
@@ -1329,7 +1329,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
...
@@ -1329,7 +1329,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
f
" {i}. Since this input is only connected "
f
" {i}. Since this input is only connected "
"to integer-valued outputs, it should "
"to integer-valued outputs, it should "
"evaluate to zeros, but it evaluates to"
"evaluate to zeros, but it evaluates to"
f
"{pytensor.get_underlying_scalar_constant(term)}."
f
"{pytensor.get_underlying_scalar_constant
_value
(term)}."
)
)
raise
ValueError
(
msg
)
raise
ValueError
(
msg
)
...
@@ -2157,6 +2157,9 @@ def _is_zero(x):
...
@@ -2157,6 +2157,9 @@ def _is_zero(x):
'maybe' means that x is an expression that is complicated enough
'maybe' means that x is an expression that is complicated enough
that we can't tell that it simplifies to 0.
that we can't tell that it simplifies to 0.
"""
"""
from
pytensor.tensor
import
get_underlying_scalar_constant_value
from
pytensor.tensor.exceptions
import
NotScalarConstantError
if
not
hasattr
(
x
,
"type"
):
if
not
hasattr
(
x
,
"type"
):
return
np
.
all
(
x
==
0.0
)
return
np
.
all
(
x
==
0.0
)
if
isinstance
(
x
.
type
,
NullType
):
if
isinstance
(
x
.
type
,
NullType
):
...
@@ -2166,9 +2169,9 @@ def _is_zero(x):
...
@@ -2166,9 +2169,9 @@ def _is_zero(x):
no_constant_value
=
True
no_constant_value
=
True
try
:
try
:
constant_value
=
pytensor
.
get_underlying_scalar_constant
(
x
)
constant_value
=
get_underlying_scalar_constant_value
(
x
)
no_constant_value
=
False
no_constant_value
=
False
except
pytensor
.
tensor
.
exceptions
.
NotScalarConstantError
:
except
NotScalarConstantError
:
pass
pass
if
no_constant_value
:
if
no_constant_value
:
...
...
pytensor/tensor/basic.py
浏览文件 @
aad6fb75
...
@@ -320,6 +320,8 @@ def get_underlying_scalar_constant_value(
...
@@ -320,6 +320,8 @@ def get_underlying_scalar_constant_value(
"""
"""
from
pytensor.compile.ops
import
DeepCopyOp
,
OutputGuard
from
pytensor.compile.ops
import
DeepCopyOp
,
OutputGuard
from
pytensor.sparse
import
CSM
from
pytensor.tensor.subtensor
import
Subtensor
v
=
orig_v
v
=
orig_v
while
True
:
while
True
:
...
@@ -350,16 +352,16 @@ def get_underlying_scalar_constant_value(
...
@@ -350,16 +352,16 @@ def get_underlying_scalar_constant_value(
raise
NotScalarConstantError
()
raise
NotScalarConstantError
()
if
not
only_process_constants
and
getattr
(
v
,
"owner"
,
None
)
and
max_recur
>
0
:
if
not
only_process_constants
and
getattr
(
v
,
"owner"
,
None
)
and
max_recur
>
0
:
op
=
v
.
owner
.
op
max_recur
-=
1
max_recur
-=
1
if
isinstance
(
if
isinstance
(
v
.
owner
.
op
,
op
,
Alloc
|
DimShuffle
|
Unbroadcast
|
OutputGuard
|
DeepCopyOp
Alloc
|
DimShuffle
|
Unbroadcast
|
OutputGuard
|
DeepCopyOp
,
):
):
# OutputGuard is only used in debugmode but we
# OutputGuard is only used in debugmode but we
# keep it here to avoid problems with old pickles
# keep it here to avoid problems with old pickles
v
=
v
.
owner
.
inputs
[
0
]
v
=
v
.
owner
.
inputs
[
0
]
continue
continue
elif
isinstance
(
v
.
owner
.
op
,
Shape_i
):
elif
isinstance
(
op
,
Shape_i
):
i
=
v
.
owner
.
op
.
i
i
=
v
.
owner
.
op
.
i
inp
=
v
.
owner
.
inputs
[
0
]
inp
=
v
.
owner
.
inputs
[
0
]
if
isinstance
(
inp
,
Constant
):
if
isinstance
(
inp
,
Constant
):
...
@@ -373,10 +375,10 @@ def get_underlying_scalar_constant_value(
...
@@ -373,10 +375,10 @@ def get_underlying_scalar_constant_value(
# mess with the stabilization optimization and be too slow.
# mess with the stabilization optimization and be too slow.
# We put all the scalar Ops used by get_canonical_form_slice()
# We put all the scalar Ops used by get_canonical_form_slice()
# to allow it to determine the broadcast pattern correctly.
# to allow it to determine the broadcast pattern correctly.
elif
isinstance
(
v
.
owner
.
op
,
ScalarFromTensor
|
TensorFromScalar
):
elif
isinstance
(
op
,
ScalarFromTensor
|
TensorFromScalar
):
v
=
v
.
owner
.
inputs
[
0
]
v
=
v
.
owner
.
inputs
[
0
]
continue
continue
elif
isinstance
(
v
.
owner
.
op
,
CheckAndRaise
):
elif
isinstance
(
op
,
CheckAndRaise
):
# check if all conditions are constant and true
# check if all conditions are constant and true
conds
=
[
conds
=
[
get_underlying_scalar_constant_value
(
c
,
max_recur
=
max_recur
)
get_underlying_scalar_constant_value
(
c
,
max_recur
=
max_recur
)
...
@@ -385,7 +387,7 @@ def get_underlying_scalar_constant_value(
...
@@ -385,7 +387,7 @@ def get_underlying_scalar_constant_value(
if
builtins
.
all
(
0
==
c
.
ndim
and
c
!=
0
for
c
in
conds
):
if
builtins
.
all
(
0
==
c
.
ndim
and
c
!=
0
for
c
in
conds
):
v
=
v
.
owner
.
inputs
[
0
]
v
=
v
.
owner
.
inputs
[
0
]
continue
continue
elif
isinstance
(
v
.
owner
.
op
,
ps
.
ScalarOp
):
elif
isinstance
(
op
,
ps
.
ScalarOp
):
if
isinstance
(
v
.
owner
.
op
,
ps
.
Second
):
if
isinstance
(
v
.
owner
.
op
,
ps
.
Second
):
# We don't need both input to be constant for second
# We don't need both input to be constant for second
shp
,
val
=
v
.
owner
.
inputs
shp
,
val
=
v
.
owner
.
inputs
...
@@ -402,7 +404,7 @@ def get_underlying_scalar_constant_value(
...
@@ -402,7 +404,7 @@ def get_underlying_scalar_constant_value(
# In fast_compile, we don't enable local_fill_to_alloc, so
# In fast_compile, we don't enable local_fill_to_alloc, so
# we need to investigate Second as Alloc. So elemwise
# we need to investigate Second as Alloc. So elemwise
# don't disable the check for Second.
# don't disable the check for Second.
elif
isinstance
(
v
.
owner
.
op
,
Elemwise
):
elif
isinstance
(
op
,
Elemwise
):
if
isinstance
(
v
.
owner
.
op
.
scalar_op
,
ps
.
Second
):
if
isinstance
(
v
.
owner
.
op
.
scalar_op
,
ps
.
Second
):
# We don't need both input to be constant for second
# We don't need both input to be constant for second
shp
,
val
=
v
.
owner
.
inputs
shp
,
val
=
v
.
owner
.
inputs
...
@@ -418,10 +420,7 @@ def get_underlying_scalar_constant_value(
...
@@ -418,10 +420,7 @@ def get_underlying_scalar_constant_value(
ret
=
[[
None
]]
ret
=
[[
None
]]
v
.
owner
.
op
.
perform
(
v
.
owner
,
const
,
ret
)
v
.
owner
.
op
.
perform
(
v
.
owner
,
const
,
ret
)
return
np
.
asarray
(
ret
[
0
][
0
]
.
copy
())
return
np
.
asarray
(
ret
[
0
][
0
]
.
copy
())
elif
(
elif
isinstance
(
op
,
Subtensor
)
and
v
.
ndim
==
0
:
isinstance
(
v
.
owner
.
op
,
pytensor
.
tensor
.
subtensor
.
Subtensor
)
and
v
.
ndim
==
0
):
if
isinstance
(
v
.
owner
.
inputs
[
0
],
TensorConstant
):
if
isinstance
(
v
.
owner
.
inputs
[
0
],
TensorConstant
):
from
pytensor.tensor.subtensor
import
get_constant_idx
from
pytensor.tensor.subtensor
import
get_constant_idx
...
@@ -545,6 +544,14 @@ def get_underlying_scalar_constant_value(
...
@@ -545,6 +544,14 @@ def get_underlying_scalar_constant_value(
if
isinstance
(
grandparent
,
Constant
):
if
isinstance
(
grandparent
,
Constant
):
return
np
.
asarray
(
np
.
shape
(
grandparent
.
data
)[
idx
])
return
np
.
asarray
(
np
.
shape
(
grandparent
.
data
)[
idx
])
elif
isinstance
(
op
,
CSM
):
data
=
get_underlying_scalar_constant_value
(
v
.
owner
.
inputs
,
elemwise
=
elemwise
,
max_recur
=
max_recur
)
# Sparse variable can only be constant if zero (or I guess if homogeneously dense)
if
data
==
0
:
return
data
break
raise
NotScalarConstantError
()
raise
NotScalarConstantError
()
...
@@ -4071,7 +4078,7 @@ class Choose(Op):
...
@@ -4071,7 +4078,7 @@ class Choose(Op):
static_out_shape
=
()
static_out_shape
=
()
for
s
in
out_shape
:
for
s
in
out_shape
:
try
:
try
:
s_val
=
pytensor
.
get_underlying_scalar_constant
(
s
)
s_val
=
get_underlying_scalar_constant_value
(
s
)
except
(
NotScalarConstantError
,
AttributeError
):
except
(
NotScalarConstantError
,
AttributeError
):
s_val
=
None
s_val
=
None
...
...
tests/tensor/test_elemwise.py
浏览文件 @
aad6fb75
...
@@ -19,7 +19,7 @@ from pytensor.graph.replace import vectorize_node
...
@@ -19,7 +19,7 @@ from pytensor.graph.replace import vectorize_node
from
pytensor.link.basic
import
PerformLinker
from
pytensor.link.basic
import
PerformLinker
from
pytensor.link.c.basic
import
CLinker
,
OpWiseCLinker
from
pytensor.link.c.basic
import
CLinker
,
OpWiseCLinker
from
pytensor.tensor
import
as_tensor_variable
from
pytensor.tensor
import
as_tensor_variable
from
pytensor.tensor.basic
import
second
from
pytensor.tensor.basic
import
get_scalar_constant_value
,
second
from
pytensor.tensor.elemwise
import
CAReduce
,
DimShuffle
,
Elemwise
from
pytensor.tensor.elemwise
import
CAReduce
,
DimShuffle
,
Elemwise
from
pytensor.tensor.math
import
Any
,
Sum
,
exp
from
pytensor.tensor.math
import
Any
,
Sum
,
exp
from
pytensor.tensor.math
import
all
as
pt_all
from
pytensor.tensor.math
import
all
as
pt_all
...
@@ -807,8 +807,8 @@ class TestElemwise(unittest_tools.InferShapeTester):
...
@@ -807,8 +807,8 @@ class TestElemwise(unittest_tools.InferShapeTester):
assert
len
(
res_shape
)
==
1
assert
len
(
res_shape
)
==
1
assert
len
(
res_shape
[
0
])
==
2
assert
len
(
res_shape
[
0
])
==
2
assert
pytensor
.
get_underlying_scalar_constant
(
res_shape
[
0
][
0
])
==
1
assert
get_scalar_constant_value
(
res_shape
[
0
][
0
])
==
1
assert
pytensor
.
get_underlying_scalar_constant
(
res_shape
[
0
][
1
])
==
1
assert
get_scalar_constant_value
(
res_shape
[
0
][
1
])
==
1
def
test_infer_shape_multi_output
(
self
):
def
test_infer_shape_multi_output
(
self
):
class
CustomElemwise
(
Elemwise
):
class
CustomElemwise
(
Elemwise
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论