Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
8cf3b20f
提交
8cf3b20f
authored
12月 12, 2022
作者:
Rémi Louf
提交者:
Ricardo Vieira
2月 07, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor the JAX implementation of `Reshape`
上级
4235ccc3
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
113 行增加
和
24 行删除
+113
-24
shape.py
pytensor/link/jax/dispatch/shape.py
+29
-1
jax.py
pytensor/tensor/rewriting/jax.py
+63
-6
test_shape.py
tests/link/jax/test_shape.py
+21
-17
没有找到文件。
pytensor/link/jax/dispatch/shape.py
浏览文件 @
8cf3b20f
...
...
@@ -28,11 +28,38 @@ def jax_funcify_JAXShapeTuple(op, **kwargs):
return
shape_tuple_fn
SHAPE_NOT_COMPATIBLE
=
"""JAX requires concrete values for the `shape` parameter of `jax.numpy.reshape`.
Concrete values are either constants:
>>> import pytensor.tensor as at
>>> x = at.ones(6)
>>> y = x.reshape((2, 3))
Or the shape of an array:
>>> mat = at.matrix('mat')
>>> y = x.reshape(mat.shape)
"""
def
assert_shape_argument_jax_compatible
(
shape
):
"""Assert whether the current node can be JIT-compiled by JAX.
JAX can JIT-compile functions with a `shape` or `size` argument if it is
given a concrete value, i.e. either a constant or the shape of any traced
value.
"""
shape_op
=
shape
.
owner
.
op
if
not
isinstance
(
shape_op
,
(
Shape
,
Shape_i
,
JAXShapeTuple
)):
raise
NotImplementedError
(
SHAPE_NOT_COMPATIBLE
)
@jax_funcify.register
(
Reshape
)
def
jax_funcify_Reshape
(
op
,
node
,
**
kwargs
):
# JAX reshape only works with constant inputs, otherwise JIT fails
shape
=
node
.
inputs
[
1
]
if
isinstance
(
shape
,
Constant
):
constant_shape
=
shape
.
data
...
...
@@ -40,6 +67,7 @@ def jax_funcify_Reshape(op, node, **kwargs):
return
jnp
.
reshape
(
x
,
constant_shape
)
else
:
assert_shape_argument_jax_compatible
(
shape
)
def
reshape
(
x
,
shape
):
return
jnp
.
reshape
(
x
,
shape
)
...
...
pytensor/tensor/rewriting/jax.py
浏览文件 @
8cf3b20f
import
pytensor.tensor
as
at
from
pytensor.compile
import
optdb
from
pytensor.graph.rewriting.basic
import
in2out
,
node_rewriter
from
pytensor.tensor.var
import
TensorVariable
import
pytensor.tensor
as
at
from
pytensor.tensor.subtensor
import
AdvancedIncSubtensor
,
AdvancedSubtensor
from
pytensor.tensor.basic
import
MakeVector
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.math
import
Sum
from
pytensor.tensor.shape
import
Reshape
from
pytensor.tensor.subtensor
import
AdvancedIncSubtensor
,
AdvancedSubtensor
from
pytensor.tensor.var
import
TensorVariable
@node_rewriter
([
AdvancedIncSubtensor
])
...
...
@@ -24,7 +27,7 @@ def boolean_indexing_set_or_inc(fgraph, node):
if
not
isinstance
(
cond
,
TensorVariable
):
return
if
not
cond
.
type
.
dtype
==
'bool'
:
if
not
cond
.
type
.
dtype
==
"bool"
:
return
if
op
.
set_instead_of_inc
:
...
...
@@ -36,7 +39,10 @@ def boolean_indexing_set_or_inc(fgraph, node):
optdb
.
register
(
"jax_boolean_indexing_set_or_inc"
,
in2out
(
boolean_indexing_set_or_inc
),
"jax"
,
position
=
100
"jax_boolean_indexing_set_or_inc"
,
in2out
(
boolean_indexing_set_or_inc
),
"jax"
,
position
=
100
,
)
...
...
@@ -67,12 +73,63 @@ def boolean_indexing_sum(fgraph, node):
if
not
isinstance
(
cond
,
TensorVariable
):
return
if
not
cond
.
type
.
dtype
==
'bool'
:
if
not
cond
.
type
.
dtype
==
"bool"
:
return
out
=
at
.
sum
(
at
.
where
(
cond
,
x
,
0
))
return
out
.
owner
.
outputs
optdb
.
register
(
"jax_boolean_indexing_sum"
,
in2out
(
boolean_indexing_sum
),
"jax"
,
position
=
100
)
@node_rewriter
([
Reshape
])
def
shape_parameter_as_tuple
(
fgraph
,
node
):
"""Replace `MakeVector` and `DimShuffle` (when used to transform a scalar
into a 1d vector) when they are found as the input of a `shape`
parameter by `JAXShapeTuple` during transpilation.
The JAX implementations of `MakeVector` and `DimShuffle` always return JAX
`TracedArrays`, but JAX only accepts concrete values as inputs for the `size`
or `shape` parameter. When these `Op`s are used to convert scalar or tuple
inputs, however, we can avoid tracing by making them return a tuple of their
inputs instead.
Note that JAX does not accept scalar inputs for the `size` or `shape`
parameters, and this rewrite also ensures that scalar inputs are turned into
tuples during transpilation.
"""
from
pytensor.link.jax.dispatch.shape
import
JAXShapeTuple
shape_arg
=
node
.
inputs
[
1
]
shape_node
=
shape_arg
.
owner
if
shape_node
is
None
:
return
if
isinstance
(
shape_node
.
op
,
JAXShapeTuple
):
return
if
isinstance
(
shape_node
.
op
,
MakeVector
)
or
(
isinstance
(
shape_node
.
op
,
DimShuffle
)
and
shape_node
.
op
.
input_broadcastable
==
()
and
shape_node
.
op
.
new_order
==
(
"x"
,)
):
# Here PyTensor converted a tuple or list to a tensor
new_shape_args
=
JAXShapeTuple
()(
*
shape_node
.
inputs
)
new_inputs
=
list
(
node
.
inputs
)
new_inputs
[
1
]
=
new_shape_args
new_node
=
node
.
clone_with_new_inputs
(
new_inputs
)
return
new_node
.
outputs
optdb
.
register
(
"jax_shape_parameter_as_tuple"
,
in2out
(
shape_parameter_as_tuple
),
"jax"
,
position
=
100
,
)
tests/link/jax/test_shape.py
浏览文件 @
8cf3b20f
...
...
@@ -45,30 +45,34 @@ def test_jax_specify_shape():
compare_jax_and_py
(
x_fg
,
[])
def
test_jax_Reshape
():
def
test_jax_Reshape
_constant
():
a
=
vector
(
"a"
)
x
=
reshape
(
a
,
(
2
,
2
))
x_fg
=
FunctionGraph
([
a
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
# Test breaking "omnistaging" changes in JAX.
# See https://github.com/tensorflow/probability/commit/782d0c64eb774b9aac54a1c8488e4f1f96fbbc68
def
test_jax_Reshape_concrete_shape
():
"""JAX should compile when a concrete value is passed for the `shape` parameter."""
a
=
vector
(
"a"
)
x
=
reshape
(
a
,
a
.
shape
)
x_fg
=
FunctionGraph
([
a
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
x
=
reshape
(
a
,
(
a
.
shape
[
0
]
//
2
,
a
.
shape
[
0
]
//
2
))
x_fg
=
FunctionGraph
([
a
],
[
x
])
with
pytest
.
raises
(
TypeError
,
match
=
"Shapes must be 1D sequences of concrete values of integer type"
,
):
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
b
=
iscalar
(
"b"
)
x
=
reshape
(
a
,
(
b
,
b
))
x_fg
=
FunctionGraph
([
a
,
b
],
[
x
])
with
pytest
.
raises
(
TypeError
,
match
=
"Shapes must be 1D sequences of concrete values of integer type"
,
):
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
),
2
])
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
@pytest.mark.xfail
(
reason
=
"`shape_at` should be specified as a static argument"
,
strict
=
True
)
def
test_jax_Reshape_shape_graph_input
():
a
=
vector
(
"a"
)
shape_at
=
iscalar
(
"b"
)
x
=
reshape
(
a
,
(
shape_at
,
shape_at
))
x_fg
=
FunctionGraph
([
a
,
shape_at
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
),
2
])
def
test_jax_compile_ops
():
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论