Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
a0c64b5f
提交
a0c64b5f
authored
11月 21, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
11月 29, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Reduce overhead of JITLinker
上级
d1c5ae27
显示空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
50 行增加
和
89 行删除
+50
-89
basic.py
pytensor/link/basic.py
+10
-42
linker.py
pytensor/link/numba/linker.py
+0
-17
linker.py
pytensor/link/pytorch/linker.py
+11
-15
test_basic.py
tests/link/numba/test_basic.py
+17
-0
test_basic.py
tests/link/pytorch/test_basic.py
+11
-14
test_elemwise.py
tests/link/pytorch/test_elemwise.py
+1
-1
没有找到文件。
pytensor/link/basic.py
浏览文件 @
a0c64b5f
...
@@ -653,41 +653,36 @@ class JITLinker(PerformLinker):
...
@@ -653,41 +653,36 @@ class JITLinker(PerformLinker):
)
)
thunk_inputs
=
self
.
create_thunk_inputs
(
storage_map
)
thunk_inputs
=
self
.
create_thunk_inputs
(
storage_map
)
thunks
=
[]
thunk_outputs
=
[
storage_map
[
n
]
for
n
in
self
.
fgraph
.
outputs
]
thunk_outputs
=
[
storage_map
[
n
]
for
n
in
self
.
fgraph
.
outputs
]
fgraph_jit
=
self
.
jit_compile
(
converted_fgraph
)
fgraph_jit
=
self
.
jit_compile
(
converted_fgraph
)
def
thunk
(
def
thunk
(
fgraph
=
self
.
fgraph
,
fgraph_jit
=
fgraph_jit
,
fgraph_jit
=
fgraph_jit
,
thunk_inputs
=
thunk_inputs
,
thunk_inputs
=
thunk_inputs
,
thunk_outputs
=
thunk_outputs
,
thunk_outputs
=
thunk_outputs
,
):
):
outputs
=
fgraph_jit
(
*
[
self
.
input_filter
(
x
[
0
])
for
x
in
thunk_inputs
])
try
:
outputs
=
fgraph_jit
(
*
(
x
[
0
]
for
x
in
thunk_inputs
))
except
Exception
:
# TODO: Should we add a fake node that combines all outputs,
# since the error may come from any of them?
raise_with_op
(
self
.
fgraph
,
output_nodes
[
0
],
thunk
)
# strict=False because we are in a hot loop
# strict=False because we are in a hot loop
for
o_var
,
o_storage
,
o_val
in
zip
(
for
o_storage
,
o_val
in
zip
(
thunk_outputs
,
outputs
,
strict
=
False
):
fgraph
.
outputs
,
thunk_outputs
,
outputs
,
strict
=
False
o_storage
[
0
]
=
o_val
):
compute_map
[
o_var
][
0
]
=
True
o_storage
[
0
]
=
self
.
output_filter
(
o_var
,
o_val
)
return
outputs
thunk
.
inputs
=
thunk_inputs
thunk
.
inputs
=
thunk_inputs
thunk
.
outputs
=
thunk_outputs
thunk
.
outputs
=
thunk_outputs
thunk
.
lazy
=
False
thunk
.
lazy
=
False
thunks
.
append
(
thunk
)
thunks
=
[
thunk
]
return
thunks
,
output_nodes
,
fgraph_jit
return
thunks
,
output_nodes
,
fgraph_jit
def
make_all
(
self
,
input_storage
=
None
,
output_storage
=
None
,
storage_map
=
None
):
def
make_all
(
self
,
input_storage
=
None
,
output_storage
=
None
,
storage_map
=
None
):
fgraph
=
self
.
fgraph
fgraph
=
self
.
fgraph
nodes
=
self
.
schedule
(
fgraph
)
nodes
=
self
.
schedule
(
fgraph
)
no_recycling
=
self
.
no_recycling
input_storage
,
output_storage
,
storage_map
=
map_storage
(
input_storage
,
output_storage
,
storage_map
=
map_storage
(
fgraph
,
nodes
,
input_storage
,
output_storage
,
storage_map
fgraph
,
nodes
,
input_storage
,
output_storage
,
storage_map
...
@@ -701,34 +696,7 @@ class JITLinker(PerformLinker):
...
@@ -701,34 +696,7 @@ class JITLinker(PerformLinker):
compute_map
,
nodes
,
input_storage
,
output_storage
,
storage_map
compute_map
,
nodes
,
input_storage
,
output_storage
,
storage_map
)
)
computed
,
last_user
=
gc_helper
(
nodes
)
[
fn
]
=
thunks
if
self
.
allow_gc
:
post_thunk_old_storage
=
[
[
storage_map
[
input
]
for
input
in
node
.
inputs
if
(
input
in
computed
)
and
(
input
not
in
fgraph
.
outputs
)
and
(
node
==
last_user
[
input
])
]
for
node
in
nodes
]
else
:
post_thunk_old_storage
=
None
if
no_recycling
is
True
:
no_recycling
=
list
(
storage_map
.
values
())
no_recycling
=
difference
(
no_recycling
,
input_storage
)
else
:
no_recycling
=
[
storage_map
[
r
]
for
r
in
no_recycling
if
r
not
in
fgraph
.
inputs
]
fn
=
streamline
(
fgraph
,
thunks
,
nodes
,
post_thunk_old_storage
,
no_recycling
=
no_recycling
)
fn
.
jit_fn
=
jit_fn
fn
.
jit_fn
=
jit_fn
fn
.
allow_gc
=
self
.
allow_gc
fn
.
allow_gc
=
self
.
allow_gc
fn
.
storage_map
=
storage_map
fn
.
storage_map
=
storage_map
...
...
pytensor/link/numba/linker.py
浏览文件 @
a0c64b5f
from
typing
import
TYPE_CHECKING
,
Any
import
numpy
as
np
import
pytensor
from
pytensor.link.basic
import
JITLinker
from
pytensor.link.basic
import
JITLinker
if
TYPE_CHECKING
:
from
pytensor.graph.basic
import
Variable
class
NumbaLinker
(
JITLinker
):
class
NumbaLinker
(
JITLinker
):
"""A `Linker` that JIT-compiles NumPy-based operations using Numba."""
"""A `Linker` that JIT-compiles NumPy-based operations using Numba."""
def
output_filter
(
self
,
var
:
"Variable"
,
out
:
Any
)
->
Any
:
if
not
isinstance
(
var
,
np
.
ndarray
)
and
isinstance
(
var
.
type
,
pytensor
.
tensor
.
TensorType
):
return
var
.
type
.
filter
(
out
,
allow_downcast
=
True
)
return
out
def
fgraph_convert
(
self
,
fgraph
,
**
kwargs
):
def
fgraph_convert
(
self
,
fgraph
,
**
kwargs
):
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch
import
numba_funcify
...
...
pytensor/link/pytorch/linker.py
浏览文件 @
a0c64b5f
import
copy
from
typing
import
Any
from
pytensor.graph.basic
import
Variable
from
pytensor.link.basic
import
JITLinker
from
pytensor.link.basic
import
JITLinker
from
pytensor.link.utils
import
unique_name_generator
from
pytensor.link.utils
import
unique_name_generator
...
@@ -13,14 +9,6 @@ class PytorchLinker(JITLinker):
...
@@ -13,14 +9,6 @@ class PytorchLinker(JITLinker):
super
()
.
__init__
(
*
args
,
**
kwargs
)
super
()
.
__init__
(
*
args
,
**
kwargs
)
self
.
gen_functors
=
[]
self
.
gen_functors
=
[]
def
input_filter
(
self
,
inp
:
Any
)
->
Any
:
from
pytensor.link.pytorch.dispatch
import
pytorch_typify
return
pytorch_typify
(
inp
)
def
output_filter
(
self
,
var
:
Variable
,
out
:
Any
)
->
Any
:
return
out
.
cpu
()
def
fgraph_convert
(
self
,
fgraph
,
input_storage
,
storage_map
,
**
kwargs
):
def
fgraph_convert
(
self
,
fgraph
,
input_storage
,
storage_map
,
**
kwargs
):
from
pytensor.link.pytorch.dispatch
import
pytorch_funcify
from
pytensor.link.pytorch.dispatch
import
pytorch_funcify
...
@@ -49,6 +37,8 @@ class PytorchLinker(JITLinker):
...
@@ -49,6 +37,8 @@ class PytorchLinker(JITLinker):
def
jit_compile
(
self
,
fn
):
def
jit_compile
(
self
,
fn
):
import
torch
import
torch
from
pytensor.link.pytorch.dispatch
import
pytorch_typify
class
wrapper
:
class
wrapper
:
"""
"""
Pytorch would fail compiling our method when trying
Pytorch would fail compiling our method when trying
...
@@ -62,7 +52,7 @@ class PytorchLinker(JITLinker):
...
@@ -62,7 +52,7 @@ class PytorchLinker(JITLinker):
def
__init__
(
self
,
fn
,
gen_functors
):
def
__init__
(
self
,
fn
,
gen_functors
):
self
.
fn
=
torch
.
compile
(
fn
)
self
.
fn
=
torch
.
compile
(
fn
)
self
.
gen_functors
=
copy
.
copy
(
gen_functors
)
self
.
gen_functors
=
gen_functors
.
copy
(
)
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
import
pytensor.link.utils
import
pytensor.link.utils
...
@@ -83,9 +73,15 @@ class PytorchLinker(JITLinker):
...
@@ -83,9 +73,15 @@ class PytorchLinker(JITLinker):
def
__del__
(
self
):
def
__del__
(
self
):
del
self
.
gen_functors
del
self
.
gen_functors
res
=
wrapper
(
fn
,
self
.
gen_functors
)
inner_fn
=
wrapper
(
fn
,
self
.
gen_functors
)
self
.
gen_functors
=
[]
self
.
gen_functors
=
[]
return
res
# Torch does not accept numpy inputs and may return GPU objects
def
fn
(
*
inputs
,
inner_fn
=
inner_fn
):
outs
=
inner_fn
(
*
(
pytorch_typify
(
inp
)
for
inp
in
inputs
))
return
tuple
(
out
.
cpu
()
.
numpy
()
for
out
in
outs
)
return
fn
def
create_thunk_inputs
(
self
,
storage_map
):
def
create_thunk_inputs
(
self
,
storage_map
):
thunk_inputs
=
[]
thunk_inputs
=
[]
...
...
tests/link/numba/test_basic.py
浏览文件 @
a0c64b5f
...
@@ -889,3 +889,20 @@ def test_cache_warning_suppressed():
...
@@ -889,3 +889,20 @@ def test_cache_warning_suppressed():
x_test
=
np
.
random
.
uniform
(
size
=
5
)
x_test
=
np
.
random
.
uniform
(
size
=
5
)
np
.
testing
.
assert_allclose
(
fn
(
x_test
),
scipy
.
special
.
psi
(
x_test
)
*
2
)
np
.
testing
.
assert_allclose
(
fn
(
x_test
),
scipy
.
special
.
psi
(
x_test
)
*
2
)
@pytest.mark.parametrize
(
"mode"
,
(
"default"
,
"trust_input"
,
"direct"
))
def
test_function_overhead
(
mode
,
benchmark
):
x
=
pt
.
vector
(
"x"
)
out
=
pt
.
exp
(
x
)
fn
=
function
([
x
],
out
,
mode
=
"NUMBA"
)
if
mode
==
"trust_input"
:
fn
.
trust_input
=
True
elif
mode
==
"direct"
:
fn
=
fn
.
vm
.
jit_fn
test_x
=
np
.
zeros
(
1000
)
assert
np
.
sum
(
fn
(
test_x
))
==
1000
benchmark
(
fn
,
test_x
)
tests/link/pytorch/test_basic.py
浏览文件 @
a0c64b5f
...
@@ -53,8 +53,6 @@ def compare_pytorch_and_py(
...
@@ -53,8 +53,6 @@ def compare_pytorch_and_py(
assert_fn: func, opt
assert_fn: func, opt
Assert function used to check for equality between python and pytorch. If not
Assert function used to check for equality between python and pytorch. If not
provided uses np.testing.assert_allclose
provided uses np.testing.assert_allclose
must_be_device_array: Bool
Checks if torch.device.type is cuda
"""
"""
...
@@ -66,20 +64,19 @@ def compare_pytorch_and_py(
...
@@ -66,20 +64,19 @@ def compare_pytorch_and_py(
pytensor_torch_fn
=
function
(
fn_inputs
,
fgraph
.
outputs
,
mode
=
pytorch_mode
)
pytensor_torch_fn
=
function
(
fn_inputs
,
fgraph
.
outputs
,
mode
=
pytorch_mode
)
pytorch_res
=
pytensor_torch_fn
(
*
test_inputs
)
pytorch_res
=
pytensor_torch_fn
(
*
test_inputs
)
if
must_be_device_array
:
if
isinstance
(
pytorch_res
,
list
):
if
isinstance
(
pytorch_res
,
list
):
assert
all
(
isinstance
(
res
,
torch
.
Tensor
)
for
res
in
pytorch_res
)
assert
all
(
isinstance
(
res
,
np
.
ndarray
)
for
res
in
pytorch_res
)
else
:
else
:
assert
pytorch_res
.
device
.
type
==
"cuda"
assert
isinstance
(
pytorch_res
,
np
.
ndarray
)
pytensor_py_fn
=
function
(
fn_inputs
,
fgraph
.
outputs
,
mode
=
py_mode
)
pytensor_py_fn
=
function
(
fn_inputs
,
fgraph
.
outputs
,
mode
=
py_mode
)
py_res
=
pytensor_py_fn
(
*
test_inputs
)
py_res
=
pytensor_py_fn
(
*
test_inputs
)
if
len
(
fgraph
.
outputs
)
>
1
:
if
len
(
fgraph
.
outputs
)
>
1
:
for
pytorch_res_i
,
py_res_i
in
zip
(
pytorch_res
,
py_res
,
strict
=
True
):
for
pytorch_res_i
,
py_res_i
in
zip
(
pytorch_res
,
py_res
,
strict
=
True
):
assert_fn
(
pytorch_res_i
.
detach
()
.
cpu
()
.
numpy
()
,
py_res_i
)
assert_fn
(
pytorch_res_i
,
py_res_i
)
else
:
else
:
assert_fn
(
pytorch_res
[
0
]
.
detach
()
.
cpu
()
.
numpy
()
,
py_res
[
0
])
assert_fn
(
pytorch_res
[
0
],
py_res
[
0
])
return
pytensor_torch_fn
,
pytorch_res
return
pytensor_torch_fn
,
pytorch_res
...
@@ -162,23 +159,23 @@ def test_shared(device):
...
@@ -162,23 +159,23 @@ def test_shared(device):
pytensor_torch_fn
=
function
([],
a
,
mode
=
"PYTORCH"
)
pytensor_torch_fn
=
function
([],
a
,
mode
=
"PYTORCH"
)
pytorch_res
=
pytensor_torch_fn
()
pytorch_res
=
pytensor_torch_fn
()
assert
isinstance
(
pytorch_res
,
torch
.
Tensor
)
assert
isinstance
(
pytorch_res
,
np
.
ndarray
)
assert
isinstance
(
a
.
get_value
(),
np
.
ndarray
)
assert
isinstance
(
a
.
get_value
(),
np
.
ndarray
)
np
.
testing
.
assert_allclose
(
pytorch_res
.
cpu
()
,
a
.
get_value
())
np
.
testing
.
assert_allclose
(
pytorch_res
,
a
.
get_value
())
pytensor_torch_fn
=
function
([],
a
*
2
,
mode
=
"PYTORCH"
)
pytensor_torch_fn
=
function
([],
a
*
2
,
mode
=
"PYTORCH"
)
pytorch_res
=
pytensor_torch_fn
()
pytorch_res
=
pytensor_torch_fn
()
assert
isinstance
(
pytorch_res
,
torch
.
Tensor
)
assert
isinstance
(
pytorch_res
,
np
.
ndarray
)
assert
isinstance
(
a
.
get_value
(),
np
.
ndarray
)
assert
isinstance
(
a
.
get_value
(),
np
.
ndarray
)
np
.
testing
.
assert_allclose
(
pytorch_res
.
cpu
()
,
a
.
get_value
()
*
2
)
np
.
testing
.
assert_allclose
(
pytorch_res
,
a
.
get_value
()
*
2
)
new_a_value
=
np
.
array
([
3
,
4
,
5
],
dtype
=
config
.
floatX
)
new_a_value
=
np
.
array
([
3
,
4
,
5
],
dtype
=
config
.
floatX
)
a
.
set_value
(
new_a_value
)
a
.
set_value
(
new_a_value
)
pytorch_res
=
pytensor_torch_fn
()
pytorch_res
=
pytensor_torch_fn
()
assert
isinstance
(
pytorch_res
,
torch
.
Tensor
)
assert
isinstance
(
pytorch_res
,
np
.
ndarray
)
np
.
testing
.
assert_allclose
(
pytorch_res
.
cpu
()
,
new_a_value
*
2
)
np
.
testing
.
assert_allclose
(
pytorch_res
,
new_a_value
*
2
)
@pytest.mark.parametrize
(
"device"
,
[
"cpu"
,
"cuda"
])
@pytest.mark.parametrize
(
"device"
,
[
"cpu"
,
"cuda"
])
...
@@ -225,7 +222,7 @@ def test_alloc_and_empty():
...
@@ -225,7 +222,7 @@ def test_alloc_and_empty():
fn
=
function
([
dim1
],
out
,
mode
=
pytorch_mode
)
fn
=
function
([
dim1
],
out
,
mode
=
pytorch_mode
)
res
=
fn
(
7
)
res
=
fn
(
7
)
assert
res
.
shape
==
(
5
,
7
,
3
)
assert
res
.
shape
==
(
5
,
7
,
3
)
assert
res
.
dtype
==
torch
.
float32
assert
res
.
dtype
==
np
.
float32
v
=
vector
(
"v"
,
shape
=
(
3
,),
dtype
=
"float64"
)
v
=
vector
(
"v"
,
shape
=
(
3
,),
dtype
=
"float64"
)
out
=
alloc
(
v
,
dim0
,
dim1
,
3
)
out
=
alloc
(
v
,
dim0
,
dim1
,
3
)
...
...
tests/link/pytorch/test_elemwise.py
浏览文件 @
a0c64b5f
...
@@ -152,7 +152,7 @@ def test_cast():
...
@@ -152,7 +152,7 @@ def test_cast():
_
,
[
res
]
=
compare_pytorch_and_py
(
_
,
[
res
]
=
compare_pytorch_and_py
(
fgraph
,
[
np
.
arange
(
6
,
dtype
=
"float32"
)
.
reshape
(
2
,
3
)]
fgraph
,
[
np
.
arange
(
6
,
dtype
=
"float32"
)
.
reshape
(
2
,
3
)]
)
)
assert
res
.
dtype
==
torch
.
int32
assert
res
.
dtype
==
np
.
int32
def
test_vmap_elemwise
():
def
test_vmap_elemwise
():
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论