Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9859c799
提交
9859c799
authored
4月 13, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
4月 14, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Generalize FunctionGraph conversion function with aesara.link.utils.fgraph_to_python
上级
e6914913
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
162 行增加
和
99 行删除
+162
-99
dispatch.py
aesara/link/jax/dispatch.py
+14
-96
linker.py
aesara/link/jax/linker.py
+3
-1
utils.py
aesara/link/utils.py
+145
-2
没有找到文件。
aesara/link/jax/dispatch.py
浏览文件 @
9859c799
import
ast
import
re
import
warnings
import
warnings
from
collections
import
Counter
from
functools
import
reduce
,
singledispatch
from
functools
import
reduce
,
singledispatch
from
keyword
import
iskeyword
from
tempfile
import
NamedTemporaryFile
from
textwrap
import
indent
from
types
import
FunctionType
from
warnings
import
warn
from
warnings
import
warn
import
jax
import
jax
...
@@ -17,10 +10,9 @@ from numpy.random import RandomState
...
@@ -17,10 +10,9 @@ from numpy.random import RandomState
from
aesara.compile.ops
import
DeepCopyOp
,
ViewOp
from
aesara.compile.ops
import
DeepCopyOp
,
ViewOp
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Constant
,
Variable
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.ifelse
import
IfElse
from
aesara.ifelse
import
IfElse
from
aesara.link.utils
import
map_storage
from
aesara.link.utils
import
fgraph_to_python
from
aesara.scalar.basic
import
Cast
,
Clip
,
Composite
,
Identity
,
ScalarOp
,
Second
from
aesara.scalar.basic
import
Cast
,
Clip
,
Composite
,
Identity
,
ScalarOp
,
Second
from
aesara.scan.op
import
Scan
from
aesara.scan.op
import
Scan
from
aesara.scan.utils
import
scan_args
as
ScanArgs
from
aesara.scan.utils
import
scan_args
as
ScanArgs
...
@@ -104,7 +96,7 @@ incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1)
...
@@ -104,7 +96,7 @@ incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1)
@singledispatch
@singledispatch
def
jax_typify
(
data
,
dtype
):
def
jax_typify
(
data
,
dtype
=
None
,
**
kwargs
):
"""Convert instances of Aesara `Type`s to JAX types."""
"""Convert instances of Aesara `Type`s to JAX types."""
if
dtype
is
None
:
if
dtype
is
None
:
return
data
return
data
...
@@ -113,12 +105,12 @@ def jax_typify(data, dtype):
...
@@ -113,12 +105,12 @@ def jax_typify(data, dtype):
@jax_typify.register
(
np
.
ndarray
)
@jax_typify.register
(
np
.
ndarray
)
def
jax_typify_ndarray
(
data
,
dtype
):
def
jax_typify_ndarray
(
data
,
dtype
=
None
,
**
kwargs
):
return
jnp
.
array
(
data
,
dtype
=
dtype
)
return
jnp
.
array
(
data
,
dtype
=
dtype
)
@jax_typify.register
(
RandomState
)
@jax_typify.register
(
RandomState
)
def
jax_typify_RandomState
(
state
,
dtype
):
def
jax_typify_RandomState
(
state
,
**
kwargs
):
state
=
state
.
get_state
(
legacy
=
False
)
state
=
state
.
get_state
(
legacy
=
False
)
state
[
"bit_generator"
]
=
numpy_bit_gens
[
state
[
"bit_generator"
]]
state
[
"bit_generator"
]
=
numpy_bit_gens
[
state
[
"bit_generator"
]]
return
state
return
state
...
@@ -608,92 +600,18 @@ def jax_funcify_FunctionGraph(
...
@@ -608,92 +600,18 @@ def jax_funcify_FunctionGraph(
storage_map
=
None
,
storage_map
=
None
,
**
kwargs
,
**
kwargs
,
):
):
return
fgraph_to_python
(
if
order
is
None
:
fgraph
,
order
=
fgraph
.
toposort
()
jax_funcify
,
input_storage
,
output_storage
,
storage_map
=
map_storage
(
jax_typify
,
fgraph
,
order
,
input_storage
,
output_storage
,
storage_map
order
,
input_storage
,
output_storage
,
storage_map
,
fgraph_name
=
"jax_funcified_fgraph"
,
**
kwargs
,
)
)
global_env
=
{}
fgraph_name
=
"jax_funcified_fgraph"
def
unique_name
(
x
,
names_counter
=
Counter
([
fgraph_name
]),
obj_to_names
=
{}):
if
x
in
obj_to_names
:
return
obj_to_names
[
x
]
if
isinstance
(
x
,
Variable
):
name
=
re
.
sub
(
"[^0-9a-zA-Z]+"
,
"_"
,
x
.
name
)
if
x
.
name
else
""
name
=
(
name
if
(
name
.
isidentifier
()
and
not
iskeyword
(
name
))
else
x
.
auto_name
)
elif
isinstance
(
x
,
FunctionType
):
name
=
x
.
__name__
else
:
name
=
type
(
x
)
.
__name__
name_suffix
=
names_counter
.
get
(
name
,
""
)
local_name
=
f
"{name}{name_suffix}"
names_counter
.
update
((
name
,))
obj_to_names
[
x
]
=
local_name
return
local_name
body_assigns
=
[]
for
node
in
order
:
jax_func
=
jax_funcify
(
node
.
op
,
node
=
node
,
**
kwargs
)
# Create a local alias with a unique name
local_jax_func_name
=
unique_name
(
jax_func
)
global_env
[
local_jax_func_name
]
=
jax_func
node_input_names
=
[]
for
i
in
node
.
inputs
:
local_input_name
=
unique_name
(
i
)
if
storage_map
[
i
][
0
]
is
not
None
or
isinstance
(
i
,
Constant
):
# Constants need to be assigned locally and referenced
global_env
[
local_input_name
]
=
jax_typify
(
storage_map
[
i
][
0
],
None
)
# TODO: We could attempt to use the storage arrays directly
# E.g. `local_input_name = f"{local_input_name}[0]"`
node_input_names
.
append
(
local_input_name
)
node_output_names
=
[
unique_name
(
v
)
for
v
in
node
.
outputs
]
body_assigns
.
append
(
f
"{', '.join(node_output_names)} = {local_jax_func_name}({', '.join(node_input_names)})"
)
fgraph_input_names
=
[
unique_name
(
v
)
for
v
in
fgraph
.
inputs
]
fgraph_output_names
=
[
unique_name
(
v
)
for
v
in
fgraph
.
outputs
]
joined_body_assigns
=
indent
(
"
\n
"
.
join
(
body_assigns
),
" "
)
if
len
(
fgraph_output_names
)
==
1
:
fgraph_return_src
=
f
"({fgraph_output_names[0]},)"
else
:
fgraph_return_src
=
", "
.
join
(
fgraph_output_names
)
fgraph_def_src
=
f
"""
def {fgraph_name}({", ".join(fgraph_input_names)}):
{joined_body_assigns}
return {fgraph_return_src}
"""
fgraph_def_ast
=
ast
.
parse
(
fgraph_def_src
)
# Create source code to be (at least temporarily) associated with the
# compiled function (e.g. for easier debugging)
with
NamedTemporaryFile
(
delete
=
False
)
as
f
:
filename
=
f
.
name
f
.
write
(
fgraph_def_src
.
encode
())
mod_code
=
compile
(
fgraph_def_ast
,
filename
,
mode
=
"exec"
)
exec
(
mod_code
,
global_env
,
locals
())
fgraph_def
=
locals
()[
fgraph_name
]
return
fgraph_def
@jax_funcify.register
(
CAReduce
)
@jax_funcify.register
(
CAReduce
)
def
jax_funcify_CAReduce
(
op
,
**
kwargs
):
def
jax_funcify_CAReduce
(
op
,
**
kwargs
):
...
...
aesara/link/jax/linker.py
浏览文件 @
9859c799
...
@@ -69,7 +69,9 @@ class JAXLinker(PerformLinker):
...
@@ -69,7 +69,9 @@ class JAXLinker(PerformLinker):
for
n
in
self
.
fgraph
.
inputs
:
for
n
in
self
.
fgraph
.
inputs
:
sinput
=
storage_map
[
n
]
sinput
=
storage_map
[
n
]
if
isinstance
(
sinput
[
0
],
RandomState
):
if
isinstance
(
sinput
[
0
],
RandomState
):
new_value
=
jax_typify
(
sinput
[
0
],
getattr
(
sinput
[
0
],
"dtype"
,
None
))
new_value
=
jax_typify
(
sinput
[
0
],
dtype
=
getattr
(
sinput
[
0
],
"dtype"
,
None
)
)
# We need to remove the reference-based connection to the
# We need to remove the reference-based connection to the
# original `RandomState`/shared variable's storage, because
# original `RandomState`/shared variable's storage, because
# subsequent attempts to use the same shared variable within
# subsequent attempts to use the same shared variable within
...
...
aesara/link/utils.py
浏览文件 @
9859c799
import
ast
import
io
import
io
import
re
import
sys
import
sys
import
traceback
import
traceback
import
warnings
import
warnings
from
collections
import
Counter
from
keyword
import
iskeyword
from
operator
import
itemgetter
from
operator
import
itemgetter
from
typing
import
Callable
,
Dict
,
Iterable
,
List
,
NoReturn
,
Optional
,
Tuple
,
Union
from
tempfile
import
NamedTemporaryFile
from
textwrap
import
indent
from
types
import
FunctionType
from
typing
import
Any
,
Callable
,
Dict
,
Iterable
,
List
,
NoReturn
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
from
aesara
import
utils
from
aesara
import
utils
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Apply
,
Constant
from
aesara.graph.basic
import
Apply
,
Constant
,
Variable
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
...
@@ -564,3 +571,139 @@ def register_thunk_trace_excepthook(handler: io.TextIOWrapper = sys.stdout):
...
@@ -564,3 +571,139 @@ def register_thunk_trace_excepthook(handler: io.TextIOWrapper = sys.stdout):
register_thunk_trace_excepthook
()
register_thunk_trace_excepthook
()
def
fgraph_to_python
(
fgraph
:
FunctionGraph
,
op_conversion_fn
:
Callable
,
type_conversion_fn
:
Optional
[
Callable
]
=
lambda
x
,
**
kwargs
:
x
,
order
:
Optional
[
List
[
Variable
]]
=
None
,
input_storage
:
Optional
[
List
[
Any
]]
=
None
,
output_storage
:
Optional
[
List
[
Any
]]
=
None
,
storage_map
:
Optional
[
Dict
[
Variable
,
List
[
Any
]]]
=
None
,
fgraph_name
:
str
=
"fgraph_to_python"
,
global_env
:
Optional
[
Dict
[
Any
,
Any
]]
=
None
,
local_env
:
Optional
[
Dict
[
Any
,
Any
]]
=
None
,
**
kwargs
,
)
->
FunctionType
:
"""Convert a ``FunctionGraph`` into a regular Python function.
Parameters
==========
fgraph
The ``FunctionGraph`` to convert.
op_conversion_fn
A callable used to convert nodes inside `fgraph` based on their ``Op``
types. It must have the signature ``(Op, **kwargs)``. One of the
keyword arguments will be ``node``, which provides the ``Apply`` node.
type_conversion_fn
A callable used to convert the values in `storage_map`.
order
The ``order`` argument to ``map_storage``.
input_storage
The ``input_storage`` argument to ``map_storage``.
output_storage
The ``output_storage`` argument to ``map_storage``.
storage_map
The ``storage_map`` argument to ``map_storage``.
fgraph_name
The name used for the resulting function.
global_env
The global environment used when the function is constructed.
The default is an empty ``dict``.
local_env
The local environment used when the function is constructed.
The default is ``locals()``.
**kwargs
The remaining keywords are passed to `python_conversion_fn`
"""
if
order
is
None
:
order
=
fgraph
.
toposort
()
input_storage
,
output_storage
,
storage_map
=
map_storage
(
fgraph
,
order
,
input_storage
,
output_storage
,
storage_map
)
if
global_env
is
None
:
global_env
=
{}
def
unique_name
(
x
,
names_counter
=
Counter
([
fgraph_name
]),
obj_to_names
=
{}):
if
x
in
obj_to_names
:
return
obj_to_names
[
x
]
if
isinstance
(
x
,
Variable
):
name
=
re
.
sub
(
"[^0-9a-zA-Z]+"
,
"_"
,
x
.
name
)
if
x
.
name
else
""
name
=
(
name
if
(
name
.
isidentifier
()
and
not
iskeyword
(
name
))
else
x
.
auto_name
)
elif
isinstance
(
x
,
FunctionType
):
name
=
x
.
__name__
else
:
name
=
type
(
x
)
.
__name__
name_suffix
=
names_counter
.
get
(
name
,
""
)
local_name
=
f
"{name}{name_suffix}"
names_counter
.
update
((
name
,))
obj_to_names
[
x
]
=
local_name
return
local_name
body_assigns
=
[]
for
node
in
order
:
jax_func
=
op_conversion_fn
(
node
.
op
,
node
=
node
,
**
kwargs
)
# Create a local alias with a unique name
local_jax_func_name
=
unique_name
(
jax_func
)
global_env
[
local_jax_func_name
]
=
jax_func
node_input_names
=
[]
for
i
in
node
.
inputs
:
local_input_name
=
unique_name
(
i
)
if
storage_map
[
i
][
0
]
is
not
None
or
isinstance
(
i
,
Constant
):
# Constants need to be assigned locally and referenced
global_env
[
local_input_name
]
=
type_conversion_fn
(
storage_map
[
i
][
0
],
node
=
None
,
**
kwargs
)
# TODO: We could attempt to use the storage arrays directly
# E.g. `local_input_name = f"{local_input_name}[0]"`
node_input_names
.
append
(
local_input_name
)
node_output_names
=
[
unique_name
(
v
)
for
v
in
node
.
outputs
]
body_assigns
.
append
(
f
"{', '.join(node_output_names)} = {local_jax_func_name}({', '.join(node_input_names)})"
)
fgraph_input_names
=
[
unique_name
(
v
)
for
v
in
fgraph
.
inputs
]
fgraph_output_names
=
[
unique_name
(
v
)
for
v
in
fgraph
.
outputs
]
joined_body_assigns
=
indent
(
"
\n
"
.
join
(
body_assigns
),
" "
)
if
len
(
fgraph_output_names
)
==
1
:
fgraph_return_src
=
f
"({fgraph_output_names[0]},)"
else
:
fgraph_return_src
=
", "
.
join
(
fgraph_output_names
)
fgraph_def_src
=
f
"""
def {fgraph_name}({", ".join(fgraph_input_names)}):
{joined_body_assigns}
return {fgraph_return_src}
"""
fgraph_def_ast
=
ast
.
parse
(
fgraph_def_src
)
# Create source code to be (at least temporarily) associated with the
# compiled function (e.g. for easier debugging)
with
NamedTemporaryFile
(
delete
=
False
)
as
f
:
filename
=
f
.
name
f
.
write
(
fgraph_def_src
.
encode
())
if
local_env
is
None
:
local_env
=
locals
()
mod_code
=
compile
(
fgraph_def_ast
,
filename
,
mode
=
"exec"
)
exec
(
mod_code
,
global_env
,
local_env
)
fgraph_def
=
local_env
[
fgraph_name
]
return
fgraph_def
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论