Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
7eafb6c6
提交
7eafb6c6
authored
9月 24, 2023
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
11月 17, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Vectorize shape operations
上级
7f0567a8
显示空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
172 行增加
和
24 行删除
+172
-24
replace.py
pytensor/graph/replace.py
+5
-1
blockwise.py
pytensor/tensor/blockwise.py
+17
-15
elemwise.py
pytensor/tensor/elemwise.py
+2
-3
shape.py
pytensor/tensor/shape.py
+59
-1
test_shape.py
tests/tensor/test_shape.py
+89
-4
没有找到文件。
pytensor/graph/replace.py
浏览文件 @
7eafb6c6
...
@@ -204,7 +204,7 @@ def graph_replace(
...
@@ -204,7 +204,7 @@ def graph_replace(
@singledispatch
@singledispatch
def
_vectorize_node
(
op
:
Op
,
node
:
Apply
,
*
bached_inputs
)
->
Apply
:
def
_vectorize_node
(
op
:
Op
,
node
:
Apply
,
*
ba
t
ched_inputs
)
->
Apply
:
# Default implementation is provided in pytensor.tensor.blockwise
# Default implementation is provided in pytensor.tensor.blockwise
raise
NotImplementedError
raise
NotImplementedError
...
@@ -215,6 +215,10 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply:
...
@@ -215,6 +215,10 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply:
return
_vectorize_node
(
op
,
node
,
*
batched_inputs
)
return
_vectorize_node
(
op
,
node
,
*
batched_inputs
)
def
_vectorize_not_needed
(
op
,
node
,
*
batched_inputs
):
return
op
.
make_node
(
*
batched_inputs
)
@overload
@overload
def
vectorize_graph
(
def
vectorize_graph
(
outputs
:
Variable
,
outputs
:
Variable
,
...
...
pytensor/tensor/blockwise.py
浏览文件 @
7eafb6c6
...
@@ -8,7 +8,11 @@ from pytensor.gradient import DisconnectedType
...
@@ -8,7 +8,11 @@ from pytensor.gradient import DisconnectedType
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.graph.null_type
import
NullType
from
pytensor.graph.null_type
import
NullType
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.graph.replace
import
_vectorize_node
,
vectorize_graph
from
pytensor.graph.replace
import
(
_vectorize_node
,
_vectorize_not_needed
,
vectorize_graph
,
)
from
pytensor.tensor
import
as_tensor_variable
from
pytensor.tensor
import
as_tensor_variable
from
pytensor.tensor.shape
import
shape_padleft
from
pytensor.tensor.shape
import
shape_padleft
from
pytensor.tensor.type
import
continuous_dtypes
,
discrete_dtypes
,
tensor
from
pytensor.tensor.type
import
continuous_dtypes
,
discrete_dtypes
,
tensor
...
@@ -37,17 +41,6 @@ def safe_signature(
...
@@ -37,17 +41,6 @@ def safe_signature(
return
f
"{inputs_sig}->{outputs_sig}"
return
f
"{inputs_sig}->{outputs_sig}"
@_vectorize_node.register
(
Op
)
def
vectorize_node_fallback
(
op
:
Op
,
node
:
Apply
,
*
bached_inputs
)
->
Apply
:
if
hasattr
(
op
,
"gufunc_signature"
):
signature
=
op
.
gufunc_signature
else
:
# TODO: This is pretty bad for shape inference and merge optimization!
# Should get better as we add signatures to our Ops
signature
=
safe_signature
(
node
.
inputs
,
node
.
outputs
)
return
cast
(
Apply
,
Blockwise
(
op
,
signature
=
signature
)
.
make_node
(
*
bached_inputs
))
class
Blockwise
(
Op
):
class
Blockwise
(
Op
):
"""Generalizes a core `Op` to work with batched dimensions.
"""Generalizes a core `Op` to work with batched dimensions.
...
@@ -361,6 +354,15 @@ class Blockwise(Op):
...
@@ -361,6 +354,15 @@ class Blockwise(Op):
return
self
.
name
return
self
.
name
@_vectorize_node.register
(
Blockwise
)
@_vectorize_node.register
(
Op
)
def
vectorize_not_needed
(
op
,
node
,
*
batch_inputs
):
def
vectorize_node_fallback
(
op
:
Op
,
node
:
Apply
,
*
bached_inputs
)
->
Apply
:
return
op
.
make_node
(
*
batch_inputs
)
if
hasattr
(
op
,
"gufunc_signature"
):
signature
=
op
.
gufunc_signature
else
:
# TODO: This is pretty bad for shape inference and merge optimization!
# Should get better as we add signatures to our Ops
signature
=
safe_signature
(
node
.
inputs
,
node
.
outputs
)
return
cast
(
Apply
,
Blockwise
(
op
,
signature
=
signature
)
.
make_node
(
*
bached_inputs
))
_vectorize_node
.
register
(
Blockwise
,
_vectorize_not_needed
)
pytensor/tensor/elemwise.py
浏览文件 @
7eafb6c6
...
@@ -7,7 +7,7 @@ from pytensor.configdefaults import config
...
@@ -7,7 +7,7 @@ from pytensor.configdefaults import config
from
pytensor.gradient
import
DisconnectedType
from
pytensor.gradient
import
DisconnectedType
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.null_type
import
NullType
from
pytensor.graph.null_type
import
NullType
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.graph.replace
import
_vectorize_node
,
_vectorize_not_needed
from
pytensor.graph.utils
import
MethodNotDefined
from
pytensor.graph.utils
import
MethodNotDefined
from
pytensor.link.c.basic
import
failure_code
from
pytensor.link.c.basic
import
failure_code
from
pytensor.link.c.op
import
COp
,
ExternalCOp
,
OpenMPOp
from
pytensor.link.c.op
import
COp
,
ExternalCOp
,
OpenMPOp
...
@@ -22,7 +22,6 @@ from pytensor.scalar.basic import transfer_type, upcast
...
@@ -22,7 +22,6 @@ from pytensor.scalar.basic import transfer_type, upcast
from
pytensor.tensor
import
elemwise_cgen
as
cgen
from
pytensor.tensor
import
elemwise_cgen
as
cgen
from
pytensor.tensor
import
get_vector_length
from
pytensor.tensor
import
get_vector_length
from
pytensor.tensor.basic
import
_get_vector_length
,
as_tensor_variable
from
pytensor.tensor.basic
import
_get_vector_length
,
as_tensor_variable
from
pytensor.tensor.blockwise
import
vectorize_not_needed
from
pytensor.tensor.type
import
(
from
pytensor.tensor.type
import
(
TensorType
,
TensorType
,
continuous_dtypes
,
continuous_dtypes
,
...
@@ -1741,7 +1740,7 @@ def _get_vector_length_Elemwise(op, var):
...
@@ -1741,7 +1740,7 @@ def _get_vector_length_Elemwise(op, var):
raise
ValueError
(
f
"Length of {var} cannot be determined"
)
raise
ValueError
(
f
"Length of {var} cannot be determined"
)
_vectorize_node
.
register
(
Elemwise
,
vectorize_not_needed
)
_vectorize_node
.
register
(
Elemwise
,
_
vectorize_not_needed
)
@_vectorize_node.register
(
DimShuffle
)
@_vectorize_node.register
(
DimShuffle
)
...
...
pytensor/tensor/shape.py
浏览文件 @
7eafb6c6
...
@@ -8,6 +8,7 @@ import numpy as np
...
@@ -8,6 +8,7 @@ import numpy as np
import
pytensor
import
pytensor
from
pytensor.gradient
import
DisconnectedType
from
pytensor.gradient
import
DisconnectedType
from
pytensor.graph.basic
import
Apply
,
Variable
from
pytensor.graph.basic
import
Apply
,
Variable
from
pytensor.graph.replace
import
_vectorize_node
,
_vectorize_not_needed
from
pytensor.graph.type
import
HasShape
from
pytensor.graph.type
import
HasShape
from
pytensor.link.c.op
import
COp
from
pytensor.link.c.op
import
COp
from
pytensor.link.c.params_type
import
ParamsType
from
pytensor.link.c.params_type
import
ParamsType
...
@@ -154,6 +155,9 @@ def _get_vector_length_Shape(op, var):
...
@@ -154,6 +155,9 @@ def _get_vector_length_Shape(op, var):
return
var
.
owner
.
inputs
[
0
]
.
type
.
ndim
return
var
.
owner
.
inputs
[
0
]
.
type
.
ndim
_vectorize_node
.
register
(
Shape
,
_vectorize_not_needed
)
def
shape_tuple
(
x
:
TensorVariable
)
->
tuple
[
Variable
,
...
]:
def
shape_tuple
(
x
:
TensorVariable
)
->
tuple
[
Variable
,
...
]:
r"""Get a tuple of symbolic shape values.
r"""Get a tuple of symbolic shape values.
...
@@ -580,6 +584,32 @@ def _get_vector_length_SpecifyShape(op, var):
...
@@ -580,6 +584,32 @@ def _get_vector_length_SpecifyShape(op, var):
raise
ValueError
(
f
"Length of {var} cannot be determined"
)
raise
ValueError
(
f
"Length of {var} cannot be determined"
)
@_vectorize_node.register
(
SpecifyShape
)
def
_vectorize_specify_shape
(
op
,
node
,
x
,
*
shape
):
old_x
,
*
old_shape
=
node
.
inputs
batched_ndims
=
x
.
type
.
ndim
-
old_x
.
type
.
ndim
if
any
(
as_tensor_variable
(
dim
)
.
type
.
ndim
!=
0
for
dim
in
shape
if
not
(
NoneConst
.
equals
(
dim
)
or
dim
is
None
)
):
raise
NotImplementedError
(
"It is not possible to vectorize the shape argument of SpecifyShape"
)
if
len
(
shape
)
==
len
(
old_shape
):
new_shape
=
tuple
([
None
]
*
batched_ndims
)
+
shape
elif
len
(
shape
)
==
(
len
(
old_shape
)
+
batched_ndims
):
new_shape
=
shape
else
:
raise
ValueError
(
"Invalid number of shape arguments passed into vectorize node of SpecifyShape"
)
return
specify_shape
(
x
,
new_shape
)
.
owner
class
Reshape
(
COp
):
class
Reshape
(
COp
):
"""Perform a reshape operation of the input x to the new shape shp.
"""Perform a reshape operation of the input x to the new shape shp.
The number of dimensions to which to reshape to (ndim) must be
The number of dimensions to which to reshape to (ndim) must be
...
@@ -638,7 +668,7 @@ class Reshape(COp):
...
@@ -638,7 +668,7 @@ class Reshape(COp):
return
Apply
(
self
,
[
x
,
shp
],
[
tensor
(
dtype
=
x
.
type
.
dtype
,
shape
=
out_shape
)])
return
Apply
(
self
,
[
x
,
shp
],
[
tensor
(
dtype
=
x
.
type
.
dtype
,
shape
=
out_shape
)])
def
perform
(
self
,
node
,
inp
,
out_
,
params
):
def
perform
(
self
,
node
,
inp
,
out_
,
params
=
None
):
x
,
shp
=
inp
x
,
shp
=
inp
(
out
,)
=
out_
(
out
,)
=
out_
if
len
(
shp
)
!=
self
.
ndim
:
if
len
(
shp
)
!=
self
.
ndim
:
...
@@ -770,6 +800,26 @@ class Reshape(COp):
...
@@ -770,6 +800,26 @@ class Reshape(COp):
"""
"""
@_vectorize_node.register
(
Reshape
)
def
_vectorize_reshape
(
op
,
node
,
x
,
shape
):
old_x
,
old_shape
=
node
.
inputs
batched_ndims
=
x
.
type
.
ndim
-
old_x
.
type
.
ndim
if
as_tensor_variable
(
shape
)
.
type
.
ndim
!=
1
:
raise
NotImplementedError
(
"It is not possible to vectorize the shape argument of Reshape"
)
if
len
(
tuple
(
old_shape
))
==
len
(
tuple
(
shape
)):
new_shape
=
[
*
x
.
shape
[:
batched_ndims
],
*
shape
]
elif
len
(
tuple
(
old_shape
))
==
(
len
(
tuple
(
shape
))
-
batched_ndims
):
new_shape
=
shape
else
:
raise
ValueError
(
"Invalid shape length passed into vectorize node of Reshape"
)
return
reshape
(
x
,
new_shape
,
ndim
=
len
(
new_shape
))
.
owner
def
reshape
(
x
,
newshape
,
ndim
=
None
):
def
reshape
(
x
,
newshape
,
ndim
=
None
):
if
ndim
is
None
:
if
ndim
is
None
:
newshape
=
at
.
as_tensor_variable
(
newshape
)
newshape
=
at
.
as_tensor_variable
(
newshape
)
...
@@ -1034,3 +1084,11 @@ def unbroadcast(x, *axes):
...
@@ -1034,3 +1084,11 @@ def unbroadcast(x, *axes):
if
not
unbroadcasted_axes
:
if
not
unbroadcasted_axes
:
return
x
return
x
return
Unbroadcast
(
*
unbroadcasted_axes
)(
x
)
return
Unbroadcast
(
*
unbroadcasted_axes
)(
x
)
@_vectorize_node.register
(
Unbroadcast
)
def
_vectorize_unbroadcast
(
op
:
Unbroadcast
,
node
:
Apply
,
x
:
TensorVariable
)
->
Apply
:
batched_ndims
=
x
.
type
.
ndim
-
node
.
inputs
[
0
]
.
type
.
ndim
old_axes
=
op
.
axes
new_axes
=
(
old_axis
+
batched_ndims
for
old_axis
in
old_axes
)
return
unbroadcast
(
x
,
*
new_axes
)
.
owner
tests/tensor/test_shape.py
浏览文件 @
7eafb6c6
...
@@ -5,14 +5,14 @@ import pytensor
...
@@ -5,14 +5,14 @@ import pytensor
from
pytensor
import
Mode
,
function
,
grad
from
pytensor
import
Mode
,
function
,
grad
from
pytensor.compile.ops
import
DeepCopyOp
from
pytensor.compile.ops
import
DeepCopyOp
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
Variable
from
pytensor.graph.basic
import
Variable
,
equal_computations
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.replace
import
clone_replace
from
pytensor.graph.replace
import
clone_replace
,
vectorize_node
from
pytensor.graph.type
import
Type
from
pytensor.graph.type
import
Type
from
pytensor.misc.safe_asarray
import
_asarray
from
pytensor.misc.safe_asarray
import
_asarray
from
pytensor.scalar.basic
import
ScalarConstant
from
pytensor.scalar.basic
import
ScalarConstant
from
pytensor.tensor
import
as_tensor_variable
,
get_vector_length
,
row
from
pytensor.tensor
import
as_tensor_variable
,
broadcast_to
,
get_vector_length
,
row
from
pytensor.tensor.basic
import
MakeVector
,
constant
from
pytensor.tensor.basic
import
MakeVector
,
as_tensor
,
constant
from
pytensor.tensor.elemwise
import
DimShuffle
,
Elemwise
from
pytensor.tensor.elemwise
import
DimShuffle
,
Elemwise
from
pytensor.tensor.rewriting.shape
import
ShapeFeature
from
pytensor.tensor.rewriting.shape
import
ShapeFeature
from
pytensor.tensor.shape
import
(
from
pytensor.tensor.shape
import
(
...
@@ -706,3 +706,88 @@ def test_shape_tuple():
...
@@ -706,3 +706,88 @@ def test_shape_tuple():
assert
isinstance
(
res
[
1
],
ScalarConstant
)
assert
isinstance
(
res
[
1
],
ScalarConstant
)
assert
res
[
1
]
.
data
==
2
assert
res
[
1
]
.
data
==
2
assert
not
isinstance
(
res
[
2
],
ScalarConstant
)
assert
not
isinstance
(
res
[
2
],
ScalarConstant
)
class
TestVectorize
:
def
test_shape
(
self
):
vec
=
tensor
(
shape
=
(
None
,))
mat
=
tensor
(
shape
=
(
None
,
None
))
node
=
shape
(
vec
)
.
owner
vect_node
=
vectorize_node
(
node
,
mat
)
assert
equal_computations
(
vect_node
.
outputs
,
[
shape
(
mat
)])
def
test_reshape
(
self
):
x
=
scalar
(
"x"
,
dtype
=
int
)
vec
=
tensor
(
shape
=
(
None
,))
mat
=
tensor
(
shape
=
(
None
,
None
))
shape
=
(
2
,
x
)
node
=
reshape
(
vec
,
shape
)
.
owner
vect_node
=
vectorize_node
(
node
,
mat
,
shape
)
assert
equal_computations
(
vect_node
.
outputs
,
[
reshape
(
mat
,
(
*
mat
.
shape
[:
1
],
2
,
x
))]
)
new_shape
=
(
5
,
2
,
x
)
vect_node
=
vectorize_node
(
node
,
mat
,
new_shape
)
assert
equal_computations
(
vect_node
.
outputs
,
[
reshape
(
mat
,
new_shape
)])
with
pytest
.
raises
(
NotImplementedError
):
vectorize_node
(
node
,
vec
,
broadcast_to
(
as_tensor
([
5
,
2
,
x
]),
(
2
,
3
)))
with
pytest
.
raises
(
ValueError
,
match
=
"Invalid shape length passed into vectorize node of Reshape"
,
):
vectorize_node
(
node
,
vec
,
(
5
,
2
,
x
))
with
pytest
.
raises
(
ValueError
,
match
=
"Invalid shape length passed into vectorize node of Reshape"
,
):
vectorize_node
(
node
,
mat
,
(
5
,
3
,
2
,
x
))
def
test_specify_shape
(
self
):
x
=
scalar
(
"x"
,
dtype
=
int
)
mat
=
tensor
(
shape
=
(
None
,
None
))
tns
=
tensor
(
shape
=
(
None
,
None
,
None
))
shape
=
(
x
,
None
)
node
=
specify_shape
(
mat
,
shape
)
.
owner
vect_node
=
vectorize_node
(
node
,
tns
,
*
shape
)
assert
equal_computations
(
vect_node
.
outputs
,
[
specify_shape
(
tns
,
(
None
,
x
,
None
))]
)
new_shape
=
(
5
,
2
,
x
)
vect_node
=
vectorize_node
(
node
,
tns
,
*
new_shape
)
assert
equal_computations
(
vect_node
.
outputs
,
[
specify_shape
(
tns
,
(
5
,
2
,
x
))])
with
pytest
.
raises
(
NotImplementedError
):
vectorize_node
(
node
,
mat
,
*
([
x
,
x
],
None
))
with
pytest
.
raises
(
ValueError
,
match
=
"Invalid number of shape arguments passed into vectorize node of SpecifyShape"
,
):
vectorize_node
(
node
,
mat
,
*
(
5
,
2
,
x
))
with
pytest
.
raises
(
ValueError
,
match
=
"Invalid number of shape arguments passed into vectorize node of SpecifyShape"
,
):
vectorize_node
(
node
,
tns
,
*
(
5
,
3
,
2
,
x
))
def
test_unbroadcast
(
self
):
mat
=
tensor
(
shape
=
(
1
,
1
,
)
)
tns
=
tensor
(
shape
=
(
4
,
1
,
1
,
1
))
node
=
unbroadcast
(
mat
,
0
)
.
owner
vect_node
=
vectorize_node
(
node
,
tns
)
assert
equal_computations
(
vect_node
.
outputs
,
[
unbroadcast
(
tns
,
2
)])
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论