Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
61c40a8d
提交
61c40a8d
authored
12月 06, 2022
作者:
Rémi Louf
提交者:
Thomas Wiecki
12月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rewrite `size` input of `RandomVariable`s in JAX backend
上级
0d1f65f8
隐藏空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
110 行增加
和
7 行删除
+110
-7
mode.py
pytensor/compile/mode.py
+1
-1
random.py
pytensor/link/jax/dispatch/random.py
+3
-2
shape.py
pytensor/link/jax/dispatch/shape.py
+23
-0
__init__.py
pytensor/tensor/random/rewriting/__init__.py
+8
-0
jax.py
pytensor/tensor/random/rewriting/jax.py
+52
-0
test_basic.py
tests/link/jax/test_basic.py
+1
-1
test_random.py
tests/link/jax/test_random.py
+22
-3
没有找到文件。
pytensor/compile/mode.py
浏览文件 @
61c40a8d
...
...
@@ -449,7 +449,7 @@ else:
JAX
=
Mode
(
JAXLinker
(),
RewriteDatabaseQuery
(
include
=
[
"fast_run"
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
]),
RewriteDatabaseQuery
(
include
=
[
"fast_run"
,
"jax"
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
]),
)
NUMBA
=
Mode
(
NumbaLinker
(),
...
...
pytensor/link/jax/dispatch/random.py
浏览文件 @
61c40a8d
...
...
@@ -8,6 +8,7 @@ from numpy.random.bit_generator import ( # type: ignore[attr-defined]
import
pytensor.tensor.random.basic
as
aer
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
,
jax_typify
from
pytensor.link.jax.dispatch.shape
import
JAXShapeTuple
from
pytensor.tensor.shape
import
Shape
,
Shape_i
...
...
@@ -28,7 +29,7 @@ or the shape of an array:
def
assert_size_argument_jax_compatible
(
node
):
"""Assert whether the current node can be
compiled
.
"""Assert whether the current node can be
JIT-compiled by JAX
.
JAX can JIT-compile `jax.random` functions when the `size` argument
is a concrete value, i.e. either a constant or the shape of any
...
...
@@ -37,7 +38,7 @@ def assert_size_argument_jax_compatible(node):
"""
size
=
node
.
inputs
[
1
]
size_op
=
size
.
owner
.
op
if
not
isinstance
(
size_op
,
(
Shape
,
Shape_i
)):
if
not
isinstance
(
size_op
,
(
Shape
,
Shape_i
,
JAXShapeTuple
)):
raise
NotImplementedError
(
SIZE_NOT_COMPATIBLE
)
...
...
pytensor/link/jax/dispatch/shape.py
浏览文件 @
61c40a8d
import
jax.numpy
as
jnp
from
pytensor.graph
import
Constant
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.op
import
Op
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
from
pytensor.tensor.shape
import
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
,
Unbroadcast
from
pytensor.tensor.type
import
TensorType
class JAXShapeTuple(Op):
    """Dummy Op that represents a `size` specified as a tuple."""

    def make_node(self, *inputs):
        # The output dtype follows the first element; the (static) output
        # shape is simply the number of elements in the tuple.
        first_dtype = inputs[0].type.dtype
        out_type = TensorType(first_dtype, shape=(len(inputs),))
        return Apply(self, inputs, [out_type()])

    def perform(self, *inputs):
        # Repackage the received arguments as a plain tuple.
        return tuple(inputs)
@jax_funcify.register(JAXShapeTuple)
def jax_funcify_JAXShapeTuple(op, **kwargs):
    """Return a JAX implementation that packs its arguments into a tuple."""

    def as_tuple(*elems):
        return tuple(elems)

    return as_tuple
@jax_funcify.register
(
Reshape
)
...
...
pytensor/tensor/random/rewriting/__init__.py
浏览文件 @
61c40a8d
# TODO: This is for backward-compatibility; remove when reasonable.
from
pytensor.tensor.random.rewriting.basic
import
*
# isort: off
# Register JAX specializations
import
pytensor.tensor.random.rewriting.jax
# isort: on
pytensor/tensor/random/rewriting/jax.py
0 → 100644
浏览文件 @
61c40a8d
from
pytensor.compile
import
optdb
from
pytensor.graph.rewriting.basic
import
in2out
,
node_rewriter
from
pytensor.tensor.basic
import
MakeVector
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.random.op
import
RandomVariable
@node_rewriter([RandomVariable])
def size_parameter_as_tuple(fgraph, node):
    """Replace `MakeVector` and `DimShuffle` (when used to transform a scalar
    into a 1d vector) when they are found as the input of a `size` or `shape`
    parameter by `JAXShapeTuple` during transpilation.

    The JAX implementations of `MakeVector` and `DimShuffle` always return JAX
    `TracedArrays`, but JAX only accepts concrete values as inputs for the `size`
    or `shape` parameter. When these `Op`s are used to convert scalar or tuple
    inputs, however, we can avoid tracing by making them return a tuple of their
    inputs instead.

    Note that JAX does not accept scalar inputs for the `size` or `shape`
    parameters, and this rewrite also ensures that scalar inputs are turned into
    tuples during transpilation.

    """
    # Imported locally to avoid requiring JAX at module import time.
    from pytensor.link.jax.dispatch.shape import JAXShapeTuple

    size_var = node.inputs[1]
    producer = size_var.owner
    # Nothing to do for a graph input, or if the rewrite already ran.
    if producer is None or isinstance(producer.op, JAXShapeTuple):
        return None

    producer_op = producer.op
    # `DimShuffle(input_broadcastable=(), new_order=("x",))` is the pattern
    # PyTensor emits when lifting a scalar `size` into a 1d vector.
    scalar_lifted_to_vector = (
        isinstance(producer_op, DimShuffle)
        and producer_op.input_broadcastable == ()
        and producer_op.new_order == ("x",)
    )
    if not (isinstance(producer_op, MakeVector) or scalar_lifted_to_vector):
        return None

    # Here PyTensor converted a tuple or list to a tensor; hand the raw
    # elements to `JAXShapeTuple` so JAX receives concrete values.
    tuple_size = JAXShapeTuple()(*producer.inputs)
    replacement_inputs = list(node.inputs)
    replacement_inputs[1] = tuple_size
    return node.clone_with_new_inputs(replacement_inputs).outputs
# Register the rewrite under the "jax" tag so it is applied (via `in2out`,
# i.e. a single input-to-output pass) when compiling with the JAX backend.
optdb.register(
    "jax_size_parameter_as_tuple", in2out(size_parameter_as_tuple), "jax", position=100
)
tests/link/jax/test_basic.py
浏览文件 @
61c40a8d
...
...
@@ -27,7 +27,7 @@ def set_pytensor_flags():
jax
=
pytest
.
importorskip
(
"jax"
)
opts
=
RewriteDatabaseQuery
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
opts
=
RewriteDatabaseQuery
(
include
=
[
"jax"
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
jax_mode
=
Mode
(
JAXLinker
(),
opts
)
py_mode
=
Mode
(
"py"
,
opts
)
...
...
tests/link/jax/test_random.py
浏览文件 @
61c40a8d
...
...
@@ -454,8 +454,18 @@ def test_random_concrete_shape():
assert
jax_fn
(
np
.
ones
((
2
,
3
)))
.
shape
==
(
2
,
3
)
@pytest.mark.xfail
(
reason
=
"size argument specified as a tuple is a `DimShuffle` node"
)
def
test_random_concrete_shape_subtensor
():
"""JAX should compile when a concrete value is passed for the `size` parameter.
This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
rewrite.
JAX does not accept scalars as `size` or `shape` arguments, so this is a
slight improvement over their API.
"""
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
x_at
=
at
.
dmatrix
()
out
=
at
.
random
.
normal
(
0
,
1
,
size
=
x_at
.
shape
[
1
],
rng
=
rng
)
...
...
@@ -463,8 +473,15 @@ def test_random_concrete_shape_subtensor():
assert
jax_fn
(
np
.
ones
((
2
,
3
)))
.
shape
==
(
3
,)
@pytest.mark.xfail
(
reason
=
"size argument specified as a tuple is a `MakeVector` node"
)
def
test_random_concrete_shape_subtensor_tuple
():
"""JAX should compile when a tuple of concrete values is passed for the `size` parameter.
This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
scalar inputs into tuples of concrete values using the
`jax_size_parameter_as_tuple` rewrite.
"""
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
x_at
=
at
.
dmatrix
()
out
=
at
.
random
.
normal
(
0
,
1
,
size
=
(
x_at
.
shape
[
0
],),
rng
=
rng
)
...
...
@@ -472,7 +489,9 @@ def test_random_concrete_shape_subtensor_tuple():
assert
jax_fn
(
np
.
ones
((
2
,
3
)))
.
shape
==
(
2
,)
@pytest.mark.xfail
(
reason
=
"`size_at` should be specified as a static argument"
)
@pytest.mark.xfail
(
reason
=
"`size_at` should be specified as a static argument"
,
strict
=
True
)
def
test_random_concrete_shape_graph_input
():
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
size_at
=
at
.
scalar
()
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论