Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
7300a687
提交
7300a687
authored
11月 17, 2024
作者:
Ian Schweer
提交者:
Ricardo Vieira
11月 25, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Track generated torch files for torch compiler
上级
4b41e092
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
110 行增加
和
14 行删除
+110
-14
basic.py
pytensor/link/pytorch/dispatch/basic.py
+5
-6
blockwise.py
pytensor/link/pytorch/dispatch/blockwise.py
+3
-3
linker.py
pytensor/link/pytorch/linker.py
+62
-2
utils.py
pytensor/link/utils.py
+8
-1
test_basic.py
tests/link/pytorch/test_basic.py
+32
-1
test_blockwise.py
tests/link/pytorch/test_blockwise.py
+0
-1
没有找到文件。
pytensor/link/pytorch/dispatch/basic.py
浏览文件 @
7300a687
...
...
@@ -54,14 +54,16 @@ def pytorch_funcify_FunctionGraph(
fgraph
,
node
=
None
,
fgraph_name
=
"pytorch_funcified_fgraph"
,
conversion_func
=
pytorch_funcify
,
**
kwargs
,
):
built_kwargs
=
{
"conversion_func"
:
conversion_func
,
**
kwargs
}
return
fgraph_to_python
(
fgraph
,
pytorch_funcify
,
conversion_func
,
type_conversion_fn
=
pytorch_typify
,
fgraph_name
=
fgraph_name
,
**
kwargs
,
**
built_
kwargs
,
)
...
...
@@ -173,11 +175,8 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs):
# Apply inner rewrites
PYTORCH
.
optimizer
(
op
.
fgraph
)
fgraph_fn
=
pytorch_funcify
(
op
.
fgraph
,
**
kwargs
,
squeeze_output
=
True
)
# Disable one step inlining to prevent torch from trying to import local functions
# defined in `pytorch_funcify`
return
torch
.
compiler
.
disable
(
fgraph_fn
,
recursive
=
False
)
return
fgraph_fn
@pytorch_funcify.register
(
TensorFromScalar
)
...
...
pytensor/link/pytorch/dispatch/blockwise.py
浏览文件 @
7300a687
import
torch
import
torch.compiler
from
pytensor.graph
import
FunctionGraph
from
pytensor.link.pytorch.dispatch
import
pytorch_funcify
...
...
@@ -11,12 +10,13 @@ def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
batched_dims
=
op
.
batch_ndim
(
node
)
core_node
=
op
.
_create_dummy_core_node
(
node
.
inputs
)
core_fgraph
=
FunctionGraph
(
inputs
=
core_node
.
inputs
,
outputs
=
core_node
.
outputs
)
inner_func
=
pytorch_funcify
(
core_fgraph
,
squeeze_output
=
len
(
node
.
outputs
)
==
1
)
inner_func
=
pytorch_funcify
(
core_fgraph
,
squeeze_output
=
len
(
node
.
outputs
)
==
1
,
**
kwargs
)
for
_
in
range
(
batched_dims
):
inner_func
=
torch
.
vmap
(
inner_func
)
@torch.compiler.disable
(
recursive
=
False
)
def
batcher
(
*
inputs
):
op
.
_check_runtime_broadcast
(
node
,
inputs
)
# broadcast on batched_dims
...
...
pytensor/link/pytorch/linker.py
浏览文件 @
7300a687
import
copy
from
typing
import
Any
from
pytensor.graph.basic
import
Variable
from
pytensor.link.basic
import
JITLinker
from
pytensor.link.utils
import
unique_name_generator
class
PytorchLinker
(
JITLinker
):
"""A `Linker` that compiles NumPy-based operations using torch.compile."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
()
.
__init__
(
*
args
,
**
kwargs
)
self
.
gen_functors
=
[]
def
input_filter
(
self
,
inp
:
Any
)
->
Any
:
from
pytensor.link.pytorch.dispatch
import
pytorch_typify
...
...
@@ -18,14 +24,68 @@ class PytorchLinker(JITLinker):
def
fgraph_convert
(
self
,
fgraph
,
input_storage
,
storage_map
,
**
kwargs
):
from
pytensor.link.pytorch.dispatch
import
pytorch_funcify
# We want to have globally unique names
# across the entire pytensor graph, not
# just the subgraph
generator
=
unique_name_generator
([
"torch_linker"
])
# Ensure that torch is aware of the generated
# code so we can compile without graph breaks
def
conversion_func_register
(
*
args
,
**
kwargs
):
functor
=
pytorch_funcify
(
*
args
,
**
kwargs
)
name
=
kwargs
[
"unique_name"
](
functor
)
self
.
gen_functors
.
append
((
f
"_{name}"
,
functor
))
return
functor
built_kwargs
=
{
"unique_name"
:
generator
,
"conversion_func"
:
conversion_func_register
,
**
kwargs
,
}
return
pytorch_funcify
(
fgraph
,
input_storage
=
input_storage
,
storage_map
=
storage_map
,
**
kwargs
fgraph
,
input_storage
=
input_storage
,
storage_map
=
storage_map
,
**
built_
kwargs
)
def
jit_compile
(
self
,
fn
):
import
torch
return
torch
.
compile
(
fn
)
class
wrapper
:
"""
Pytorch would fail compiling our method when trying
to resolve some of the methods returned from dispatch
calls. We want to be careful to not leak the methods,
so this class just holds them and provisions the expected
location accordingly
https://discuss.pytorch.org/t/closures-are-being-gcd-and-causing-failures-to-compile/213319
"""
def
__init__
(
self
,
fn
,
gen_functors
):
self
.
fn
=
torch
.
compile
(
fn
)
self
.
gen_functors
=
copy
.
copy
(
gen_functors
)
def
__call__
(
self
,
*
args
,
**
kwargs
):
import
pytensor.link.utils
# set attrs
for
n
,
fn
in
self
.
gen_functors
:
setattr
(
pytensor
.
link
.
utils
,
n
[
1
:],
fn
)
res
=
self
.
fn
(
*
args
,
**
kwargs
)
# unset attrs
for
n
,
_
in
self
.
gen_functors
:
if
getattr
(
pytensor
.
link
.
utils
,
n
[
1
:],
False
):
delattr
(
pytensor
.
link
.
utils
,
n
[
1
:])
return
res
def
__del__
(
self
):
del
self
.
gen_functors
res
=
wrapper
(
fn
,
self
.
gen_functors
)
self
.
gen_functors
=
[]
return
res
def
create_thunk_inputs
(
self
,
storage_map
):
thunk_inputs
=
[]
...
...
pytensor/link/utils.py
浏览文件 @
7300a687
...
...
@@ -675,6 +675,7 @@ def fgraph_to_python(
local_env
:
dict
[
Any
,
Any
]
|
None
=
None
,
get_name_for_object
:
Callable
[[
Any
],
str
]
=
get_name_for_object
,
squeeze_output
:
bool
=
False
,
unique_name
:
Callable
|
None
=
None
,
**
kwargs
,
)
->
Callable
:
"""Convert a `FunctionGraph` into a regular Python function.
...
...
@@ -706,6 +707,8 @@ def fgraph_to_python(
get_name_for_object
A function used to provide names for the objects referenced within the
generated function.
unique_name
A function to make random function names for generated code
squeeze_output
If the `FunctionGraph` has only one output and this option is
``True``, return the single output instead of a tuple with the output.
...
...
@@ -719,7 +722,11 @@ def fgraph_to_python(
if
storage_map
is
None
:
storage_map
=
{}
unique_name
=
unique_name_generator
([
fgraph_name
])
if
not
unique_name
:
unique_name
=
unique_name_generator
([
fgraph_name
])
# make sure we plumb this through
kwargs
[
"unique_name"
]
=
unique_name
if
global_env
is
None
:
global_env
=
{}
...
...
tests/link/pytorch/test_basic.py
浏览文件 @
7300a687
...
...
@@ -22,6 +22,7 @@ from pytensor.tensor.type import matrices, matrix, scalar, vector
torch
=
pytest
.
importorskip
(
"torch"
)
torch_dispatch
=
pytest
.
importorskip
(
"pytensor.link.pytorch.dispatch.basic"
)
optimizer
=
RewriteDatabaseQuery
(
...
...
@@ -335,7 +336,7 @@ def test_pytorch_OpFromGraph():
ofg_2
=
OpFromGraph
([
x
,
y
],
[
x
*
y
,
x
-
y
])
o1
,
o2
=
ofg_2
(
y
,
z
)
out
=
ofg_1
(
x
,
o1
)
+
o2
out
=
ofg_1
(
x
,
o1
)
/
o2
xv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
yv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
*
3
...
...
@@ -343,3 +344,33 @@ def test_pytorch_OpFromGraph():
f
=
FunctionGraph
([
x
,
y
,
z
],
[
out
])
compare_pytorch_and_py
(
f
,
[
xv
,
yv
,
zv
])
def
test_pytorch_link_references
():
import
pytensor.link.utils
as
m
class
BasicOp
(
Op
):
def
__init__
(
self
):
super
()
.
__init__
()
def
make_node
(
self
,
*
x
):
return
Apply
(
self
,
list
(
x
),
[
xi
.
type
()
for
xi
in
x
])
def
perform
(
self
,
*
_
):
raise
RuntimeError
(
"In perform"
)
@torch_dispatch.pytorch_funcify.register
(
BasicOp
)
def
fn
(
op
,
node
,
**
kwargs
):
def
inner_fn
(
x
):
assert
"inner_fn"
in
dir
(
m
),
"not available during dispatch"
return
x
return
inner_fn
x
=
vector
(
"x"
)
op
=
BasicOp
()
out
=
op
(
x
)
f
=
function
([
x
],
out
,
mode
=
"PYTORCH"
)
f
(
torch
.
ones
(
3
))
assert
"inner_fn"
not
in
dir
(
m
),
"function call reference leaked"
tests/link/pytorch/test_blockwise.py
浏览文件 @
7300a687
...
...
@@ -29,7 +29,6 @@ class TestOp(Op):
@basic.pytorch_funcify.register
(
TestOp
)
def
evaluate_test_op
(
op
,
**
_
):
@torch.compiler.disable
(
recursive
=
False
)
def
func
(
a
,
b
):
op
.
call_shapes
.
extend
(
map
(
torch
.
Tensor
.
size
,
[
a
,
b
]))
return
a
@
b
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论