Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2307d877
提交
2307d877
authored
6月 24, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
6月 24, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove unused or unreachable code in tensor/sort.py
上级
71592152
显示空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
9 行增加
和
105 行删除
+9
-105
sort.py
pytensor/tensor/sort.py
+9
-105
没有找到文件。
pytensor/tensor/sort.py
浏览文件 @
2307d877
import
numpy
as
np
from
pytensor.gradient
import
grad_undefined
from
pytensor.graph.basic
import
Apply
,
Constant
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.op
import
Op
from
pytensor.misc.safe_asarray
import
_asarray
from
pytensor.tensor.basic
import
arange
,
as_tensor_variable
,
switch
from
pytensor.tensor.math
import
eq
,
ge
,
mul
from
pytensor.tensor.math
import
eq
,
ge
from
pytensor.tensor.type
import
TensorType
def
_variable_is_none
(
var
):
return
isinstance
(
var
,
Constant
)
and
var
.
data
is
None
def
_check_tensor_is_scalar
(
var
):
"""
Checks if a tensor variable is scalar, raise ValueError otherwise
"""
msg
=
"
%(var)
s is expected to be 0d tensor, got
%(ndim)
d"
if
var
.
ndim
!=
0
:
raise
ValueError
(
msg
%
(
var
,
var
.
ndim
))
class
SortOp
(
Op
):
"""
This class is a wrapper for numpy sort function.
...
...
@@ -39,28 +26,16 @@ class SortOp(Op):
def
make_node
(
self
,
input
,
axis
=-
1
):
input
=
as_tensor_variable
(
input
)
axis
=
as_tensor_variable
(
axis
)
axis
=
as_tensor_variable
(
axis
,
ndim
=
0
,
dtype
=
int
)
out_type
=
input
.
type
()
return
Apply
(
self
,
[
input
,
axis
],
[
out_type
])
def
perform
(
self
,
node
,
inputs
,
output_storage
):
a
=
inputs
[
0
]
axis
=
inputs
[
1
]
if
axis
is
not
None
:
if
axis
!=
int
(
axis
):
raise
ValueError
(
"sort axis must be an integer or None"
)
axis
=
int
(
axis
)
a
,
axis
=
inputs
z
=
output_storage
[
0
]
z
[
0
]
=
np
.
sort
(
a
,
axis
,
self
.
kind
,
self
.
order
)
z
[
0
]
=
np
.
sort
(
a
,
int
(
axis
)
,
self
.
kind
,
self
.
order
)
def
infer_shape
(
self
,
fgraph
,
node
,
inputs_shapes
):
if
_variable_is_none
(
node
.
inputs
[
1
]):
# That means axis = None,
# So the array is flattened before being sorted
return
[(
mul
(
*
inputs_shapes
[
0
]),)]
# axis should not be None
# So there should be the same number of dimensions
# in the input and output
assert
node
.
inputs
[
0
]
.
ndim
==
node
.
outputs
[
0
]
.
ndim
assert
inputs_shapes
[
1
]
==
()
return
[
inputs_shapes
[
0
]]
...
...
@@ -172,7 +147,7 @@ class ArgSortOp(Op):
def
make_node
(
self
,
input
,
axis
=-
1
):
input
=
as_tensor_variable
(
input
)
axis
=
as_tensor_variable
(
axis
)
axis
=
as_tensor_variable
(
axis
,
ndim
=
0
,
dtype
=
int
)
return
Apply
(
self
,
[
input
,
axis
],
...
...
@@ -180,22 +155,14 @@ class ArgSortOp(Op):
)
def
perform
(
self
,
node
,
inputs
,
output_storage
):
a
=
inputs
[
0
]
axis
=
inputs
[
1
]
if
axis
is
not
None
:
if
axis
!=
int
(
axis
):
raise
ValueError
(
"sort axis must be an integer or None"
)
axis
=
int
(
axis
)
a
,
axis
=
inputs
z
=
output_storage
[
0
]
z
[
0
]
=
_asarray
(
np
.
argsort
(
a
,
axis
,
self
.
kind
,
self
.
order
),
dtype
=
node
.
outputs
[
0
]
.
dtype
np
.
argsort
(
a
,
int
(
axis
),
self
.
kind
,
self
.
order
),
dtype
=
node
.
outputs
[
0
]
.
dtype
,
)
def
infer_shape
(
self
,
fgraph
,
node
,
inputs_shapes
):
if
_variable_is_none
(
node
.
inputs
[
1
]):
return
[(
mul
(
*
inputs_shapes
[
0
]),)]
# axis should not be None, so there should be the same number of
# dimensions in the input and output
assert
node
.
inputs
[
0
]
.
ndim
==
node
.
outputs
[
0
]
.
ndim
assert
inputs_shapes
[
1
]
==
()
return
[
inputs_shapes
[
0
]]
...
...
@@ -239,66 +206,3 @@ def argsort(a, axis=-1, kind="quicksort", order=None):
a
=
a
.
flatten
()
axis
=
0
return
ArgSortOp
(
kind
,
order
)(
a
,
axis
)
def
_topk_py_impl
(
op
,
x
,
k
,
axis
,
idx_dtype
):
ndim
=
x
.
ndim
assert
-
ndim
<=
axis
<
ndim
axis
%=
ndim
if
k
==
0
:
raise
ValueError
(
"topk: kth cannot be zero"
)
elif
k
>
x
.
shape
[
axis
]:
raise
ValueError
(
f
"topk: kth cannot be larger than the size of specified axis {int(axis)}"
)
if
abs
(
k
)
==
1
:
# negative k means min instead of max
fn_max
=
[
None
,
np
.
max
,
np
.
min
][
k
]
fn_argmax
=
[
None
,
np
.
argmax
,
np
.
argmin
][
k
]
if
not
op
.
return_indices
:
return
np
.
expand_dims
(
fn_max
(
x
,
axis
=
axis
),
axis
)
elif
op
.
return_values
:
zi
=
np
.
expand_dims
(
fn_argmax
(
x
,
axis
=
axis
),
axis
)
idx2
=
tuple
(
np
.
arange
(
s
)
.
reshape
((
s
,)
+
(
1
,)
*
(
ndim
-
i
-
1
))
if
i
!=
axis
else
zi
for
i
,
s
in
enumerate
(
x
.
shape
)
)
zv
=
x
[
idx2
]
return
zv
,
zi
.
astype
(
idx_dtype
)
else
:
zi
=
np
.
expand_dims
(
fn_argmax
(
x
,
axis
=
axis
),
axis
)
return
zi
.
astype
(
idx_dtype
)
if
x
.
shape
[
axis
]
==
abs
(
k
):
if
not
op
.
return_indices
:
return
x
.
copy
()
else
:
l
=
axis
r
=
ndim
-
l
reps
=
list
(
x
.
shape
)
reps
[
axis
]
=
1
zi
=
np
.
arange
(
abs
(
k
),
dtype
=
idx_dtype
)
zi
=
zi
.
reshape
((
1
,)
*
l
+
(
k
,)
+
(
1
,)
*
(
r
-
1
))
zi
=
np
.
tile
(
zi
,
reps
)
if
op
.
return_values
:
return
x
.
copy
(),
zi
else
:
return
zi
idx
=
[
slice
(
None
)]
*
ndim
idx
[
axis
]
=
slice
(
-
k
,
None
)
if
k
>
0
else
slice
(
-
k
)
if
not
op
.
return_indices
:
zv
=
np
.
partition
(
x
,
-
k
,
axis
=
axis
)[
tuple
(
idx
)]
return
zv
elif
op
.
return_values
:
zi
=
np
.
argpartition
(
x
,
-
k
,
axis
=
axis
)[
tuple
(
idx
)]
idx2
=
tuple
(
np
.
arange
(
s
)
.
reshape
((
s
,)
+
(
1
,)
*
(
ndim
-
i
-
1
))
if
i
!=
axis
else
zi
for
i
,
s
in
enumerate
(
x
.
shape
)
)
zv
=
x
[
idx2
]
return
zv
,
zi
.
astype
(
idx_dtype
)
else
:
zi
=
np
.
argpartition
(
x
,
-
k
,
axis
=
axis
)[
tuple
(
idx
)]
return
zi
.
astype
(
idx_dtype
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论