Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
4134881f
提交
4134881f
authored
7月 09, 2024
作者:
HarshvirSandhu
提交者:
Ricardo Vieira
9月 01, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement indexing operations in pytorch
Co-authored-by:
Ricardo Vieira
<
28983449+ricardov94@users.noreply.github.com
>
上级
1a1c62bb
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
345 行增加
和
9 行删除
+345
-9
mode.py
pytensor/compile/mode.py
+1
-0
__init__.py
pytensor/link/pytorch/dispatch/__init__.py
+2
-1
basic.py
pytensor/link/pytorch/dispatch/basic.py
+29
-5
subtensor.py
pytensor/link/pytorch/dispatch/subtensor.py
+124
-0
test_basic.py
tests/link/pytorch/test_basic.py
+3
-3
test_subtensor.py
tests/link/pytorch/test_subtensor.py
+186
-0
没有找到文件。
pytensor/compile/mode.py
浏览文件 @
4134881f
...
...
@@ -471,6 +471,7 @@ PYTORCH = Mode(
"BlasOpt"
,
"fusion"
,
"inplace"
,
"local_uint_constant_indices"
,
],
),
)
...
...
pytensor/link/pytorch/dispatch/__init__.py
浏览文件 @
4134881f
...
...
@@ -7,7 +7,8 @@ import pytensor.link.pytorch.dispatch.scalar
import
pytensor.link.pytorch.dispatch.elemwise
import
pytensor.link.pytorch.dispatch.math
import
pytensor.link.pytorch.dispatch.extra_ops
import
pytensor.link.pytorch.dispatch.nlinalg
import
pytensor.link.pytorch.dispatch.shape
import
pytensor.link.pytorch.dispatch.sort
import
pytensor.link.pytorch.dispatch.
nlinalg
import
pytensor.link.pytorch.dispatch.
subtensor
# isort: on
pytensor/link/pytorch/dispatch/basic.py
浏览文件 @
4134881f
from
functools
import
singledispatch
from
types
import
NoneType
import
numpy
as
np
import
torch
from
pytensor.compile.ops
import
DeepCopyOp
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.link.utils
import
fgraph_to_python
from
pytensor.raise_op
import
CheckAndRaise
from
pytensor.tensor.basic
import
Alloc
,
AllocEmpty
,
ARange
,
Eye
,
Join
,
MakeVector
from
pytensor.tensor.basic
import
(
Alloc
,
AllocEmpty
,
ARange
,
Eye
,
Join
,
MakeVector
,
TensorFromScalar
,
)
@singledispatch
def
pytorch_typify
(
data
,
dtype
=
None
,
**
kwargs
):
r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
def
pytorch_typify
(
data
,
**
kwargs
):
raise
NotImplementedError
(
f
"pytorch_typify is not implemented for {type(data)}"
)
@pytorch_typify.register
(
np
.
ndarray
)
@pytorch_typify.register
(
torch
.
Tensor
)
def
pytorch_typify_tensor
(
data
,
dtype
=
None
,
**
kwargs
):
return
torch
.
as_tensor
(
data
,
dtype
=
dtype
)
@pytorch_typify.register
(
slice
)
@pytorch_typify.register
(
NoneType
)
def
pytorch_typify_None
(
data
,
**
kwargs
):
return
None
@pytorch_typify.register
(
np
.
number
)
def
pytorch_typify_no_conversion_needed
(
data
,
**
kwargs
):
return
data
@singledispatch
...
...
@@ -132,3 +148,11 @@ def pytorch_funcify_MakeVector(op, **kwargs):
return
torch
.
tensor
(
x
,
dtype
=
torch_dtype
)
return
makevector
@pytorch_funcify.register
(
TensorFromScalar
)
def
pytorch_funcify_TensorFromScalar
(
op
,
**
kwargs
):
def
tensorfromscalar
(
x
):
return
torch
.
as_tensor
(
x
)
return
tensorfromscalar
pytensor/link/pytorch/dispatch/subtensor.py
0 → 100644
浏览文件 @
4134881f
from
pytensor.link.pytorch.dispatch.basic
import
pytorch_funcify
from
pytensor.tensor.subtensor
import
(
AdvancedIncSubtensor
,
AdvancedIncSubtensor1
,
AdvancedSubtensor
,
AdvancedSubtensor1
,
IncSubtensor
,
Subtensor
,
indices_from_subtensor
,
)
from
pytensor.tensor.type_other
import
MakeSlice
,
SliceType
def
check_negative_steps
(
indices
):
for
index
in
indices
:
if
isinstance
(
index
,
slice
):
if
index
.
step
is
not
None
and
index
.
step
<
0
:
raise
NotImplementedError
(
"Negative step sizes are not supported in Pytorch"
)
@pytorch_funcify.register
(
Subtensor
)
def
pytorch_funcify_Subtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
op
.
idx_list
def
subtensor
(
x
,
*
flattened_indices
):
indices
=
indices_from_subtensor
(
flattened_indices
,
idx_list
)
check_negative_steps
(
indices
)
return
x
[
indices
]
return
subtensor
@pytorch_funcify.register
(
MakeSlice
)
def
pytorch_funcify_makeslice
(
op
,
**
kwargs
):
def
makeslice
(
*
x
):
return
slice
(
x
)
return
makeslice
@pytorch_funcify.register
(
AdvancedSubtensor1
)
@pytorch_funcify.register
(
AdvancedSubtensor
)
def
pytorch_funcify_AdvSubtensor
(
op
,
node
,
**
kwargs
):
def
advsubtensor
(
x
,
*
indices
):
check_negative_steps
(
indices
)
return
x
[
indices
]
return
advsubtensor
@pytorch_funcify.register
(
IncSubtensor
)
def
pytorch_funcify_IncSubtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
op
.
idx_list
inplace
=
op
.
inplace
if
op
.
set_instead_of_inc
:
def
set_subtensor
(
x
,
y
,
*
flattened_indices
):
indices
=
indices_from_subtensor
(
flattened_indices
,
idx_list
)
check_negative_steps
(
indices
)
if
not
inplace
:
x
=
x
.
clone
()
x
[
indices
]
=
y
return
x
return
set_subtensor
else
:
def
inc_subtensor
(
x
,
y
,
*
flattened_indices
):
indices
=
indices_from_subtensor
(
flattened_indices
,
idx_list
)
check_negative_steps
(
indices
)
if
not
inplace
:
x
=
x
.
clone
()
x
[
indices
]
+=
y
return
x
return
inc_subtensor
@pytorch_funcify.register
(
AdvancedIncSubtensor
)
@pytorch_funcify.register
(
AdvancedIncSubtensor1
)
def
pytorch_funcify_AdvancedIncSubtensor
(
op
,
node
,
**
kwargs
):
inplace
=
op
.
inplace
ignore_duplicates
=
getattr
(
op
,
"ignore_duplicates"
,
False
)
if
op
.
set_instead_of_inc
:
def
adv_set_subtensor
(
x
,
y
,
*
indices
):
check_negative_steps
(
indices
)
if
not
inplace
:
x
=
x
.
clone
()
x
[
indices
]
=
y
.
type_as
(
x
)
return
x
return
adv_set_subtensor
elif
ignore_duplicates
:
def
adv_inc_subtensor_no_duplicates
(
x
,
y
,
*
indices
):
check_negative_steps
(
indices
)
if
not
inplace
:
x
=
x
.
clone
()
x
[
indices
]
+=
y
.
type_as
(
x
)
return
x
return
adv_inc_subtensor_no_duplicates
else
:
if
any
(
isinstance
(
idx
.
type
,
SliceType
)
for
idx
in
node
.
inputs
[
2
:]):
raise
NotImplementedError
(
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
)
def
adv_inc_subtensor
(
x
,
y
,
*
indices
):
# Not needed because slices aren't supported
# check_negative_steps(indices)
if
not
inplace
:
x
=
x
.
clone
()
x
.
index_put_
(
indices
,
y
.
type_as
(
x
),
accumulate
=
True
)
return
x
return
adv_inc_subtensor
tests/link/pytorch/test_basic.py
浏览文件 @
4134881f
...
...
@@ -66,10 +66,10 @@ def compare_pytorch_and_py(
py_res
=
pytensor_py_fn
(
*
test_inputs
)
if
len
(
fgraph
.
outputs
)
>
1
:
for
j
,
p
in
zip
(
pytorch_res
,
py_res
):
assert_fn
(
j
.
cpu
(),
p
)
for
pytorch_res_i
,
py_res_i
in
zip
(
pytorch_res
,
py_res
):
assert_fn
(
pytorch_res_i
.
detach
()
.
cpu
()
.
numpy
(),
py_res_i
)
else
:
assert_fn
(
[
pytorch_res
[
0
]
.
cpu
()],
py_res
)
assert_fn
(
pytorch_res
[
0
]
.
detach
()
.
cpu
()
.
numpy
(),
py_res
[
0
]
)
return
pytensor_torch_fn
,
pytorch_res
...
...
tests/link/pytorch/test_subtensor.py
0 → 100644
浏览文件 @
4134881f
import
contextlib
import
numpy
as
np
import
pytest
import
pytensor.scalar
as
ps
import
pytensor.tensor
as
pt
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor
import
inc_subtensor
,
set_subtensor
from
pytensor.tensor
import
subtensor
as
pt_subtensor
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
def
test_pytorch_Subtensor
():
shape
=
(
3
,
4
,
5
)
x_pt
=
pt
.
tensor
(
"x"
,
shape
=
shape
,
dtype
=
"int"
)
x_np
=
np
.
arange
(
np
.
prod
(
shape
))
.
reshape
(
shape
)
out_pt
=
x_pt
[
1
,
2
,
0
]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
out_pt
=
x_pt
[
1
:,
1
,
:]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
out_pt
=
x_pt
[:
2
,
1
,
:]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
out_pt
=
x_pt
[
1
:
2
,
1
,
:]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
# symbolic index
a_pt
=
ps
.
int64
(
"a"
)
a_np
=
1
out_pt
=
x_pt
[
a_pt
,
2
,
a_pt
:
2
]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
x_pt
,
a_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
,
a_np
])
with
pytest
.
raises
(
NotImplementedError
,
match
=
"Negative step sizes are not supported in Pytorch"
):
out_pt
=
x_pt
[::
-
1
]
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
def
test_pytorch_AdvSubtensor
():
shape
=
(
3
,
4
,
5
)
x_pt
=
pt
.
tensor
(
"x"
,
shape
=
shape
,
dtype
=
"int"
)
x_np
=
np
.
arange
(
np
.
prod
(
shape
))
.
reshape
(
shape
)
out_pt
=
pt_subtensor
.
advanced_subtensor1
(
x_pt
,
[
1
,
2
])
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor1
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
out_pt
=
x_pt
[[
1
,
2
],
[
2
,
3
]]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
out_pt
=
x_pt
[[
1
,
2
],
1
:]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
out_pt
=
x_pt
[[
1
,
2
],
:,
[
3
,
4
]]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
out_pt
=
x_pt
[[
1
,
2
],
None
]
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
a_pt
=
ps
.
int64
(
"a"
)
a_np
=
2
out_pt
=
x_pt
[[
1
,
a_pt
],
a_pt
]
out_fg
=
FunctionGraph
([
x_pt
,
a_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
,
a_np
])
# boolean indices
out_pt
=
x_pt
[
np
.
random
.
binomial
(
1
,
0.5
,
size
=
(
3
,
4
,
5
))
.
astype
(
bool
)]
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
a_pt
=
pt
.
tensor3
(
"a"
,
dtype
=
"bool"
)
a_np
=
np
.
random
.
binomial
(
1
,
0.5
,
size
=
(
3
,
4
,
5
))
.
astype
(
bool
)
out_pt
=
x_pt
[
a_pt
]
out_fg
=
FunctionGraph
([
x_pt
,
a_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
,
a_np
])
with
pytest
.
raises
(
NotImplementedError
,
match
=
"Negative step sizes are not supported in Pytorch"
):
out_pt
=
x_pt
[[
1
,
2
],
::
-
1
]
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
@pytest.mark.parametrize
(
"subtensor_op"
,
[
set_subtensor
,
inc_subtensor
])
def
test_pytorch_IncSubtensor
(
subtensor_op
):
x_pt
=
pt
.
tensor3
(
"x"
)
x_test
=
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))
.
astype
(
config
.
floatX
)
st_pt
=
pt
.
as_tensor_variable
(
np
.
array
(
-
10.0
,
dtype
=
config
.
floatX
))
out_pt
=
subtensor_op
(
x_pt
[
1
,
2
,
3
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_test
])
# Test different type update
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
"float32"
))
out_pt
=
subtensor_op
(
x_pt
[:
2
,
0
,
0
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_test
])
out_pt
=
subtensor_op
(
x_pt
[
0
,
1
:
3
,
0
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_test
])
def
inc_subtensor_ignore_duplicates
(
x
,
y
):
return
inc_subtensor
(
x
,
y
,
ignore_duplicates
=
True
)
@pytest.mark.parametrize
(
"advsubtensor_op"
,
[
set_subtensor
,
inc_subtensor
,
inc_subtensor_ignore_duplicates
]
)
def
test_pytorch_AvdancedIncSubtensor
(
advsubtensor_op
):
rng
=
np
.
random
.
default_rng
(
42
)
x_pt
=
pt
.
tensor3
(
"x"
)
x_test
=
(
np
.
arange
(
3
*
4
*
5
)
+
1
)
.
reshape
((
3
,
4
,
5
))
.
astype
(
config
.
floatX
)
st_pt
=
pt
.
as_tensor_variable
(
rng
.
uniform
(
-
1
,
1
,
size
=
(
2
,
4
,
5
))
.
astype
(
config
.
floatX
)
)
out_pt
=
advsubtensor_op
(
x_pt
[
np
.
r_
[
0
,
2
]],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_test
])
# Repeated indices
out_pt
=
advsubtensor_op
(
x_pt
[
np
.
r_
[
0
,
0
]],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_test
])
# Mixing advanced and basic indexing
if
advsubtensor_op
is
inc_subtensor
:
# PyTorch does not support `np.add.at` equivalent with slices
expectation
=
pytest
.
raises
(
NotImplementedError
)
else
:
expectation
=
contextlib
.
nullcontext
()
st_pt
=
pt
.
as_tensor_variable
(
x_test
[[
0
,
2
],
0
,
:
3
])
out_pt
=
advsubtensor_op
(
x_pt
[[
0
,
0
],
0
,
:
3
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
with
expectation
:
compare_pytorch_and_py
(
out_fg
,
[
x_test
])
# Test different dtype update
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
"float32"
))
out_pt
=
advsubtensor_op
(
x_pt
[[
0
,
2
],
0
,
0
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_test
])
# Boolean indices
out_pt
=
advsubtensor_op
(
x_pt
[
x_pt
>
5
],
1.0
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_test
])
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论