Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
1549649f
提交
1549649f
authored
4月 13, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
4月 14, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Create a generalized JITLinker
上级
9859c799
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
186 行增加
和
172 行删除
+186
-172
basic.py
aesara/link/basic.py
+171
-4
linker.py
aesara/link/jax/linker.py
+15
-168
没有找到文件。
aesara/link/basic.py
浏览文件 @
1549649f
from
abc
import
ABC
,
abstractmethod
from
copy
import
copy
,
deepcopy
from
copy
import
copy
,
deepcopy
from
typing
import
(
from
typing
import
(
TYPE_CHECKING
,
TYPE_CHECKING
,
...
@@ -146,7 +147,7 @@ class Container:
...
@@ -146,7 +147,7 @@ class Container:
return
r
return
r
class
Linker
:
class
Linker
(
ABC
)
:
"""
"""
Base type for all linkers.
Base type for all linkers.
...
@@ -189,6 +190,7 @@ class Linker:
...
@@ -189,6 +190,7 @@ class Linker:
new
.
_allow_gc
=
allow_gc
new
.
_allow_gc
=
allow_gc
return
new
return
new
@abstractmethod
def
make_thunk
(
self
,
**
kwargs
)
->
ThunkType
:
def
make_thunk
(
self
,
**
kwargs
)
->
ThunkType
:
"""
"""
This function must return a triplet (function, input_variables,
This function must return a triplet (function, input_variables,
...
@@ -211,9 +213,6 @@ class Linker:
...
@@ -211,9 +213,6 @@ class Linker:
print e.data # 3.0 iff inplace == True (else unknown)
print e.data # 3.0 iff inplace == True (else unknown)
"""
"""
raise
NotImplementedError
(
f
"make_thunk method of {type(self)} is not implemented."
)
@deprecated
(
"Marked for deletion. Only tests use it."
)
@deprecated
(
"Marked for deletion. Only tests use it."
)
def
make_function
(
self
,
unpack_single
:
bool
=
True
,
**
kwargs
)
->
Callable
:
def
make_function
(
self
,
unpack_single
:
bool
=
True
,
**
kwargs
)
->
Callable
:
...
@@ -630,3 +629,171 @@ def WrapLinkerMany(
...
@@ -630,3 +629,171 @@ def WrapLinkerMany(
f
(
*
args
)
f
(
*
args
)
return
WrapLinker
(
linkers
,
wrapper
)
return
WrapLinker
(
linkers
,
wrapper
)
class
JITLinker
(
PerformLinker
):
"""A ``Linker`` that JIT compiles a ``FunctionGraph`` into a single runnable thunk.
The entirety of ``Linker.fgraph`` is converted into a single JIT compiled
thunk that is run by an Aesara ``VM``.
"""
@abstractmethod
def
fgraph_convert
(
self
,
fgraph
,
order
,
input_storage
,
output_storage
,
storage_map
,
**
kwargs
):
"""Convert a ``FunctionGraph`` into a JIT-able function."""
@abstractmethod
def
create_thunk_inputs
(
self
,
storage_map
:
Dict
[
Variable
,
List
[
Any
]])
->
List
[
Any
]:
"""Pre-process inputs for the generated thunk.
Parameters
==========
storage_map
A ``dict`` mapping ``Variable``s to their storage lists.
Returns
=======
A list of thunk inputs
"""
@abstractmethod
def
jit_compile
(
self
,
fn
:
Callable
)
->
Callable
:
"""JIT compile a converted ``FunctionGraph``."""
def
create_jitable_thunk
(
self
,
compute_map
,
order
,
input_storage
,
output_storage
,
storage_map
):
"""Create a thunk for each output of the `Linker`s `FunctionGraph`.
This is differs from the other thunk-making function in that it only
produces thunks for the `FunctionGraph` output nodes.
Parameters
----------
compute_map: dict
The compute map dictionary.
order
input_storage
output_storage
storage_map: dict
The storage map dictionary.
Returns
-------
thunks: list
A tuple containing the thunks.
output_nodes: list and their
A tuple containing the output nodes.
"""
output_nodes
=
[
o
.
owner
for
o
in
self
.
fgraph
.
outputs
]
converted_fgraph
=
self
.
fgraph_convert
(
self
.
fgraph
,
order
=
order
,
input_storage
=
input_storage
,
output_storage
=
output_storage
,
storage_map
=
storage_map
,
)
thunk_inputs
=
self
.
create_thunk_inputs
(
storage_map
)
thunks
=
[]
thunk_outputs
=
[
storage_map
[
n
]
for
n
in
self
.
fgraph
.
outputs
]
fgraph_jit
=
self
.
jit_compile
(
converted_fgraph
)
def
thunk
(
fgraph
=
self
.
fgraph
,
fgraph_jit
=
fgraph_jit
,
thunk_inputs
=
thunk_inputs
,
thunk_outputs
=
thunk_outputs
,
):
outputs
=
fgraph_jit
(
*
[
x
[
0
]
for
x
in
thunk_inputs
])
for
o_node
,
o_storage
,
o_val
in
zip
(
fgraph
.
outputs
,
thunk_outputs
,
outputs
):
compute_map
[
o_node
][
0
]
=
True
if
len
(
o_storage
)
>
1
:
assert
len
(
o_storage
)
==
len
(
o_val
)
for
i
,
o_sub_val
in
enumerate
(
o_val
):
o_storage
[
i
]
=
o_sub_val
else
:
o_storage
[
0
]
=
o_val
return
outputs
thunk
.
inputs
=
thunk_inputs
thunk
.
outputs
=
thunk_outputs
thunk
.
lazy
=
False
thunks
.
append
(
thunk
)
# This is a bit hackish, but we only return one of the output nodes
return
thunks
,
output_nodes
[:
1
]
def
make_all
(
self
,
input_storage
=
None
,
output_storage
=
None
,
storage_map
=
None
):
fgraph
=
self
.
fgraph
nodes
=
self
.
schedule
(
fgraph
)
no_recycling
=
self
.
no_recycling
input_storage
,
output_storage
,
storage_map
=
map_storage
(
fgraph
,
nodes
,
input_storage
,
output_storage
,
storage_map
)
compute_map
=
{}
for
k
in
storage_map
:
compute_map
[
k
]
=
[
k
.
owner
is
None
]
thunks
,
nodes
=
self
.
create_jitable_thunk
(
compute_map
,
nodes
,
input_storage
,
output_storage
,
storage_map
)
computed
,
last_user
=
gc_helper
(
nodes
)
if
self
.
allow_gc
:
post_thunk_old_storage
=
[]
for
node
in
nodes
:
post_thunk_old_storage
.
append
(
[
storage_map
[
input
]
for
input
in
node
.
inputs
if
(
input
in
computed
)
and
(
input
not
in
fgraph
.
outputs
)
and
(
node
==
last_user
[
input
])
]
)
else
:
post_thunk_old_storage
=
None
if
no_recycling
is
True
:
no_recycling
=
list
(
storage_map
.
values
())
no_recycling
=
difference
(
no_recycling
,
input_storage
)
else
:
no_recycling
=
[
storage_map
[
r
]
for
r
in
no_recycling
if
r
not
in
fgraph
.
inputs
]
fn
=
streamline
(
fgraph
,
thunks
,
nodes
,
post_thunk_old_storage
,
no_recycling
=
no_recycling
)
fn
.
allow_gc
=
self
.
allow_gc
fn
.
storage_map
=
storage_map
return
(
fn
,
[
Container
(
input
,
storage
)
for
input
,
storage
in
zip
(
fgraph
.
inputs
,
input_storage
)
],
[
Container
(
output
,
storage
,
readonly
=
True
)
for
output
,
storage
in
zip
(
fgraph
.
outputs
,
output_storage
)
],
thunks
,
nodes
,
)
aesara/link/jax/linker.py
浏览文件 @
1549649f
from
warnings
import
warn
from
numpy.random
import
RandomState
from
numpy.random
import
RandomState
from
aesara.graph.basic
import
Constant
from
aesara.graph.basic
import
Constant
from
aesara.link.basic
import
Container
,
PerformLinker
from
aesara.link.basic
import
JITLinker
from
aesara.link.utils
import
gc_helper
,
map_storage
,
streamline
from
aesara.utils
import
difference
class
JAXLinker
(
PerformLinker
):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX.
Attributes
----------
allow_non_jax: bool
A boolean indicating whether or not an exception is thrown when the
graph cannot be JAX compiled (e.g. the graph has an unsupported operator).
If `allow_non_jax` is `True`, the fallback is currently Python compilation.
"""
class
JAXLinker
(
JITLinker
):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX."""
allow_non_jax
=
False
def
fgraph_convert
(
self
,
fgraph
,
order
,
input_storage
,
output_storage
,
storage_map
,
**
kwargs
def
create_jax_thunks
(
self
,
compute_map
,
order
,
input_storage
,
output_storage
,
storage_map
):
):
"""Create a thunk for each output of the `Linker`s `FunctionGraph`.
from
aesara.link.jax.dispatch
import
jax_funcify
This is differs from the other thunk-making function in that it only
produces thunks for the `FunctionGraph` output nodes.
Parameters
return
jax_funcify
(
----------
fgraph
,
order
,
input_storage
,
output_storage
,
storage_map
,
**
kwargs
compute_map: dict
)
The compute map dictionary.
storage_map: dict
The storage map dictionary.
Returns
-------
thunks: list
A tuple containing the thunks.
output_nodes: list and their
A tuple containing the output nodes.
"""
def
jit_compile
(
self
,
fn
):
import
jax
import
jax
from
aesara.link.jax.dispatch
import
jax_funcify
,
jax_typify
output_nodes
=
[
o
.
owner
for
o
in
self
.
fgraph
.
outputs
]
# Create a JAX-compilable function from our `FunctionGraph`
jaxed_fgraph
=
jax_funcify
(
self
.
fgraph
,
input_storage
=
input_storage
,
output_storage
=
output_storage
,
storage_map
=
storage_map
,
)
# I suppose we can consider `Constant`s to be "static" according to
# I suppose we can consider `Constant`s to be "static" according to
# JAX.
# JAX.
static_argnums
=
[
static_argnums
=
[
n
for
n
,
i
in
enumerate
(
self
.
fgraph
.
inputs
)
if
isinstance
(
i
,
Constant
)
n
for
n
,
i
in
enumerate
(
self
.
fgraph
.
inputs
)
if
isinstance
(
i
,
Constant
)
]
]
return
jax
.
jit
(
fn
,
static_argnums
)
def
create_thunk_inputs
(
self
,
storage_map
):
from
aesara.link.jax.dispatch
import
jax_typify
thunk_inputs
=
[]
thunk_inputs
=
[]
for
n
in
self
.
fgraph
.
inputs
:
for
n
in
self
.
fgraph
.
inputs
:
...
@@ -79,121 +43,4 @@ class JAXLinker(PerformLinker):
...
@@ -79,121 +43,4 @@ class JAXLinker(PerformLinker):
sinput
=
[
new_value
]
sinput
=
[
new_value
]
thunk_inputs
.
append
(
sinput
)
thunk_inputs
.
append
(
sinput
)
thunks
=
[]
return
thunk_inputs
thunk_outputs
=
[
storage_map
[
n
]
for
n
in
self
.
fgraph
.
outputs
]
fgraph_jit
=
jax
.
jit
(
jaxed_fgraph
,
static_argnums
)
def
thunk
(
fgraph
=
self
.
fgraph
,
fgraph_jit
=
fgraph_jit
,
thunk_inputs
=
thunk_inputs
,
thunk_outputs
=
thunk_outputs
,
):
outputs
=
fgraph_jit
(
*
[
x
[
0
]
for
x
in
thunk_inputs
])
for
o_node
,
o_storage
,
o_val
in
zip
(
fgraph
.
outputs
,
thunk_outputs
,
outputs
):
compute_map
[
o_node
][
0
]
=
True
if
len
(
o_storage
)
>
1
:
assert
len
(
o_storage
)
==
len
(
o_val
)
for
i
,
o_sub_val
in
enumerate
(
o_val
):
o_storage
[
i
]
=
o_sub_val
else
:
o_storage
[
0
]
=
o_val
return
outputs
thunk
.
inputs
=
thunk_inputs
thunk
.
outputs
=
thunk_outputs
thunk
.
lazy
=
False
thunks
.
append
(
thunk
)
# This is a bit hackish, but we only return one of the output nodes
return
thunks
,
output_nodes
[:
1
]
def
make_all
(
self
,
input_storage
=
None
,
output_storage
=
None
,
storage_map
=
None
):
fgraph
=
self
.
fgraph
nodes
=
self
.
schedule
(
fgraph
)
no_recycling
=
self
.
no_recycling
input_storage
,
output_storage
,
storage_map
=
map_storage
(
fgraph
,
nodes
,
input_storage
,
output_storage
,
storage_map
)
compute_map
=
{}
for
k
in
storage_map
:
compute_map
[
k
]
=
[
k
.
owner
is
None
]
try
:
# We need to create thunk functions that will populate the output
# storage arrays with the JAX-computed values.
thunks
,
nodes
=
self
.
create_jax_thunks
(
compute_map
,
nodes
,
input_storage
,
output_storage
,
storage_map
)
except
NotImplementedError
as
e
:
if
not
self
.
allow_non_jax
:
raise
warn
(
f
"JaxLinker could not JAXify graph: {e}"
)
thunks
=
[]
for
node
in
nodes
:
thunk
=
node
.
op
.
make_thunk
(
node
,
storage_map
,
compute_map
,
no_recycling
,
"py"
)
thunk_inputs
=
[
storage_map
[
v
]
for
v
in
node
.
inputs
]
thunk_outputs
=
[
storage_map
[
v
]
for
v
in
node
.
outputs
]
thunk
.
inputs
=
thunk_inputs
thunk
.
outputs
=
thunk_outputs
thunks
.
append
(
thunk
)
computed
,
last_user
=
gc_helper
(
nodes
)
if
self
.
allow_gc
:
post_thunk_old_storage
=
[]
for
node
in
nodes
:
post_thunk_old_storage
.
append
(
[
storage_map
[
input
]
for
input
in
node
.
inputs
if
(
input
in
computed
)
and
(
input
not
in
fgraph
.
outputs
)
and
(
node
==
last_user
[
input
])
]
)
else
:
post_thunk_old_storage
=
None
if
no_recycling
is
True
:
no_recycling
=
list
(
storage_map
.
values
())
no_recycling
=
difference
(
no_recycling
,
input_storage
)
else
:
no_recycling
=
[
storage_map
[
r
]
for
r
in
no_recycling
if
r
not
in
fgraph
.
inputs
]
fn
=
streamline
(
fgraph
,
thunks
,
nodes
,
post_thunk_old_storage
,
no_recycling
=
no_recycling
)
fn
.
allow_gc
=
self
.
allow_gc
fn
.
storage_map
=
storage_map
return
(
fn
,
[
Container
(
input
,
storage
)
for
input
,
storage
in
zip
(
fgraph
.
inputs
,
input_storage
)
],
[
Container
(
output
,
storage
,
readonly
=
True
)
for
output
,
storage
in
zip
(
fgraph
.
outputs
,
output_storage
)
],
thunks
,
nodes
,
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论