Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
65967fe2
提交
65967fe2
authored
1月 05, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
1月 07, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement vectorize_node for Softmax and Argmax Ops
Also refactors shared logic for other batch axed Ops
上级
08a9ba35
隐藏空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
169 行增加
和
29 行删除
+169
-29
basic.py
pytensor/tensor/basic.py
+16
-6
elemwise.py
pytensor/tensor/elemwise.py
+35
-15
math.py
pytensor/tensor/math.py
+21
-2
shape.py
pytensor/tensor/shape.py
+8
-5
special.py
pytensor/tensor/special.py
+28
-0
test_math.py
tests/tensor/test_math.py
+30
-0
test_special.py
tests/tensor/test_special.py
+31
-1
没有找到文件。
pytensor/tensor/basic.py
浏览文件 @
65967fe2
...
@@ -43,7 +43,12 @@ from pytensor.tensor import (
...
@@ -43,7 +43,12 @@ from pytensor.tensor import (
get_vector_length
,
get_vector_length
,
)
)
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.elemwise
import
DimShuffle
,
Elemwise
,
scalar_elemwise
from
pytensor.tensor.elemwise
import
(
DimShuffle
,
Elemwise
,
get_normalized_batch_axes
,
scalar_elemwise
,
)
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.shape
import
(
from
pytensor.tensor.shape
import
(
Shape
,
Shape
,
...
@@ -3614,13 +3619,18 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
...
@@ -3614,13 +3619,18 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
@_vectorize_node.register
(
ExtractDiag
)
@_vectorize_node.register
(
ExtractDiag
)
def
vectorize_extract_diag
(
op
:
ExtractDiag
,
node
,
batched_x
):
def
vectorize_extract_diag
(
op
:
ExtractDiag
,
node
,
batch_x
):
batched_ndims
=
batched_x
.
type
.
ndim
-
node
.
inputs
[
0
]
.
type
.
ndim
core_ndim
=
node
.
inputs
[
0
]
.
type
.
ndim
batch_ndim
=
batch_x
.
type
.
ndim
-
core_ndim
batch_axis1
,
batch_axis2
=
get_normalized_batch_axes
(
(
op
.
axis1
,
op
.
axis2
),
core_ndim
,
batch_ndim
)
return
diagonal
(
return
diagonal
(
batch
ed
_x
,
batch_x
,
offset
=
op
.
offset
,
offset
=
op
.
offset
,
axis1
=
op
.
axis1
+
batched_ndims
,
axis1
=
batch_axis1
,
axis2
=
op
.
axis2
+
batched_ndims
,
axis2
=
batch_axis2
,
)
.
owner
)
.
owner
...
...
pytensor/tensor/elemwise.py
浏览文件 @
65967fe2
from
copy
import
copy
from
copy
import
copy
from
typing
import
Union
import
numpy
as
np
import
numpy
as
np
from
numpy.core.numeric
import
normalize_axis_tuple
import
pytensor.tensor.basic
import
pytensor.tensor.basic
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
...
@@ -1399,7 +1401,7 @@ class CAReduce(COp):
...
@@ -1399,7 +1401,7 @@ class CAReduce(COp):
# scalar inputs are treated as 1D regarding axis in this `Op`
# scalar inputs are treated as 1D regarding axis in this `Op`
if
axis
is
not
None
:
if
axis
is
not
None
:
try
:
try
:
axis
=
n
p
.
core
.
numeric
.
n
ormalize_axis_tuple
(
axis
,
ndim
=
max
(
1
,
inp_dims
))
axis
=
normalize_axis_tuple
(
axis
,
ndim
=
max
(
1
,
inp_dims
))
except
np
.
AxisError
:
except
np
.
AxisError
:
raise
np
.
AxisError
(
axis
,
ndim
=
inp_dims
)
raise
np
.
AxisError
(
axis
,
ndim
=
inp_dims
)
...
@@ -1757,18 +1759,36 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Appl
...
@@ -1757,18 +1759,36 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Appl
return
DimShuffle
(
input_broadcastable
,
new_order
)
.
make_node
(
x
)
return
DimShuffle
(
input_broadcastable
,
new_order
)
.
make_node
(
x
)
@_vectorize_node.register
(
CAReduce
)
def
get_normalized_batch_axes
(
def
vectorize_careduce
(
op
:
CAReduce
,
node
:
Apply
,
x
:
TensorVariable
)
->
Apply
:
core_axes
:
Union
[
None
,
int
,
tuple
[
int
,
...
]],
batched_ndims
=
x
.
type
.
ndim
-
node
.
inputs
[
0
]
.
type
.
ndim
core_ndim
:
int
,
if
not
batched_ndims
:
batch_ndim
:
int
,
return
node
.
op
.
make_node
(
x
)
)
->
tuple
[
int
,
...
]:
axes
=
op
.
axis
"""Compute batch axes for a batched operation, from the core input ndim and axes.
# e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3))
# e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,))
e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3))
if
axes
is
None
:
batch_axes(None, 2, 4) -> (2, 3)
axes
=
list
(
range
(
node
.
inputs
[
0
]
.
type
.
ndim
))
e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,))
batch_axes(0, 2, 4) -> (2,)
e.g., sum(tensor3, axis=(0, -1)) -> sum(tensor4, axis=(1, 3))
batch_axes((0, -1), 3, 4) -> (1, 3)
"""
if
core_axes
is
None
:
core_axes
=
tuple
(
range
(
core_ndim
))
else
:
else
:
axes
=
list
(
axes
)
core_axes
=
normalize_axis_tuple
(
core_axes
,
core_ndim
)
new_axes
=
[
axis
+
batched_ndims
for
axis
in
axes
]
return
tuple
(
core_axis
+
batch_ndim
for
core_axis
in
core_axes
)
new_op
=
op
.
clone
(
axis
=
new_axes
)
return
new_op
.
make_node
(
x
)
@_vectorize_node.register
(
CAReduce
)
def
vectorize_careduce
(
op
:
CAReduce
,
node
:
Apply
,
batch_x
:
TensorVariable
)
->
Apply
:
core_ndim
=
node
.
inputs
[
0
]
.
type
.
ndim
batch_ndim
=
batch_x
.
type
.
ndim
-
core_ndim
if
not
batch_ndim
:
return
node
.
op
.
make_node
(
batch_x
)
batch_axes
=
get_normalized_batch_axes
(
op
.
axis
,
core_ndim
,
batch_ndim
)
return
op
.
clone
(
axis
=
batch_axes
)
.
make_node
(
batch_x
)
pytensor/tensor/math.py
浏览文件 @
65967fe2
...
@@ -27,7 +27,13 @@ from pytensor.tensor.basic import (
...
@@ -27,7 +27,13 @@ from pytensor.tensor.basic import (
switch
,
switch
,
)
)
from
pytensor.tensor.blockwise
import
Blockwise
,
vectorize_node_fallback
from
pytensor.tensor.blockwise
import
Blockwise
,
vectorize_node_fallback
from
pytensor.tensor.elemwise
import
CAReduce
,
DimShuffle
,
Elemwise
,
scalar_elemwise
from
pytensor.tensor.elemwise
import
(
CAReduce
,
DimShuffle
,
Elemwise
,
get_normalized_batch_axes
,
scalar_elemwise
,
)
from
pytensor.tensor.shape
import
shape
,
specify_broadcastable
from
pytensor.tensor.shape
import
shape
,
specify_broadcastable
from
pytensor.tensor.type
import
(
from
pytensor.tensor.type
import
(
DenseTensorType
,
DenseTensorType
,
...
@@ -134,7 +140,7 @@ class MaxAndArgmax(COp):
...
@@ -134,7 +140,7 @@ class MaxAndArgmax(COp):
_f16_ok
=
True
_f16_ok
=
True
def
__init__
(
self
,
axis
):
def
__init__
(
self
,
axis
):
assert
isinstance
(
axis
,
list
)
assert
isinstance
(
axis
,
(
tuple
,
list
)
)
self
.
axis
=
tuple
(
axis
)
self
.
axis
=
tuple
(
axis
)
def
get_params
(
self
,
node
):
def
get_params
(
self
,
node
):
...
@@ -465,6 +471,19 @@ class Argmax(COp):
...
@@ -465,6 +471,19 @@ class Argmax(COp):
return
[
x
.
zeros_like
()]
return
[
x
.
zeros_like
()]
@_vectorize_node.register
(
Argmax
)
@_vectorize_node.register
(
MaxAndArgmax
)
def
vectorize_argmax_node
(
op
,
node
,
batch_x
):
core_ndim
=
node
.
inputs
[
0
]
.
type
.
ndim
batch_ndim
=
batch_x
.
type
.
ndim
-
core_ndim
if
not
batch_ndim
:
return
node
.
op
.
make_node
(
batch_x
)
batch_axes
=
get_normalized_batch_axes
(
op
.
axis
,
core_ndim
,
batch_ndim
)
return
type
(
op
)(
axis
=
batch_axes
)
.
make_node
(
batch_x
)
def
makeKeepDims
(
x
,
y
,
axis
):
def
makeKeepDims
(
x
,
y
,
axis
):
"""
"""
Reintroduces in y with length one the axes of x which have been left out
Reintroduces in y with length one the axes of x which have been left out
...
...
pytensor/tensor/shape.py
浏览文件 @
65967fe2
...
@@ -18,6 +18,7 @@ from pytensor.scalar import int32
...
@@ -18,6 +18,7 @@ from pytensor.scalar import int32
from
pytensor.tensor
import
_get_vector_length
,
as_tensor_variable
from
pytensor.tensor
import
_get_vector_length
,
as_tensor_variable
from
pytensor.tensor
import
basic
as
ptb
from
pytensor.tensor
import
basic
as
ptb
from
pytensor.tensor
import
get_vector_length
from
pytensor.tensor
import
get_vector_length
from
pytensor.tensor.elemwise
import
get_normalized_batch_axes
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.type
import
DenseTensorType
,
TensorType
,
int_dtypes
,
tensor
from
pytensor.tensor.type
import
DenseTensorType
,
TensorType
,
int_dtypes
,
tensor
from
pytensor.tensor.type_other
import
NoneConst
from
pytensor.tensor.type_other
import
NoneConst
...
@@ -1103,8 +1104,10 @@ def unbroadcast(x, *axes):
...
@@ -1103,8 +1104,10 @@ def unbroadcast(x, *axes):
@_vectorize_node.register
(
Unbroadcast
)
@_vectorize_node.register
(
Unbroadcast
)
def
_vectorize_unbroadcast
(
op
:
Unbroadcast
,
node
:
Apply
,
x
:
TensorVariable
)
->
Apply
:
def
_vectorize_unbroadcast
(
batched_ndims
=
x
.
type
.
ndim
-
node
.
inputs
[
0
]
.
type
.
ndim
op
:
Unbroadcast
,
node
:
Apply
,
batch_x
:
TensorVariable
old_axes
=
op
.
axes
)
->
Apply
:
new_axes
=
(
old_axis
+
batched_ndims
for
old_axis
in
old_axes
)
core_ndim
=
node
.
inputs
[
0
]
.
type
.
ndim
return
cast
(
Apply
,
unbroadcast
(
x
,
*
new_axes
)
.
owner
)
batch_ndim
=
batch_x
.
type
.
ndim
-
core_ndim
batch_axes
=
get_normalized_batch_axes
(
op
.
axes
,
core_ndim
,
batch_ndim
)
return
cast
(
Apply
,
unbroadcast
(
batch_x
,
*
batch_axes
)
.
owner
)
pytensor/tensor/special.py
浏览文件 @
65967fe2
...
@@ -4,8 +4,10 @@ import numpy as np
...
@@ -4,8 +4,10 @@ import numpy as np
import
scipy
import
scipy
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.link.c.op
import
COp
from
pytensor.link.c.op
import
COp
from
pytensor.tensor.basic
import
as_tensor_variable
from
pytensor.tensor.basic
import
as_tensor_variable
from
pytensor.tensor.elemwise
import
get_normalized_batch_axes
from
pytensor.tensor.math
import
gamma
,
gammaln
,
neg
,
sum
from
pytensor.tensor.math
import
gamma
,
gammaln
,
neg
,
sum
...
@@ -736,6 +738,32 @@ def log_softmax(c, axis=None):
...
@@ -736,6 +738,32 @@ def log_softmax(c, axis=None):
return
LogSoftmax
(
axis
=
axis
)(
c
)
return
LogSoftmax
(
axis
=
axis
)(
c
)
@_vectorize_node.register
(
Softmax
)
@_vectorize_node.register
(
LogSoftmax
)
def
vectorize_softmax_node
(
op
,
node
,
batched_x
):
"""
Vectorize Softmax and LogSoftmax nodes.
"""
core_ndim
=
node
.
inputs
[
0
]
.
type
.
ndim
batch_ndim
=
batched_x
.
type
.
ndim
-
core_ndim
if
not
batch_ndim
:
return
op
.
make_node
(
batched_x
)
batch_axes
=
get_normalized_batch_axes
(
op
.
axis
,
core_ndim
,
batch_ndim
)
if
len
(
batch_axes
)
>
1
:
from
pytensor.tensor.blockwise
import
vectorize_node_fallback
# The softmax Ops only allow a specific axis (integer) or all axis (None).
# If the vectorized operation requires more than one axis we have to default to a Blockwise
return
vectorize_node_fallback
(
op
,
node
,
batched_x
)
[
batch_axis
]
=
batch_axes
return
type
(
op
)(
axis
=
batch_axis
)
.
make_node
(
batched_x
)
def
poch
(
z
,
m
):
def
poch
(
z
,
m
):
"""
"""
Pochhammer symbol (rising factorial) function.
Pochhammer symbol (rising factorial) function.
...
...
tests/tensor/test_math.py
浏览文件 @
65967fe2
...
@@ -20,6 +20,7 @@ from pytensor.configdefaults import config
...
@@ -20,6 +20,7 @@ from pytensor.configdefaults import config
from
pytensor.gradient
import
NullTypeGradError
,
grad
,
numeric_grad
from
pytensor.gradient
import
NullTypeGradError
,
grad
,
numeric_grad
from
pytensor.graph.basic
import
Variable
,
applys_between
from
pytensor.graph.basic
import
Variable
,
applys_between
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.link.c.basic
import
DualLinker
from
pytensor.link.c.basic
import
DualLinker
from
pytensor.misc.safe_asarray
import
_asarray
from
pytensor.misc.safe_asarray
import
_asarray
from
pytensor.printing
import
pprint
from
pytensor.printing
import
pprint
...
@@ -1010,6 +1011,35 @@ class TestMaxAndArgmax:
...
@@ -1010,6 +1011,35 @@ class TestMaxAndArgmax:
assert
max_pt
.
eval
()
==
3
assert
max_pt
.
eval
()
==
3
assert
argmax_pt
.
eval
()
==
2
assert
argmax_pt
.
eval
()
==
2
@pytest.mark.parametrize
(
"core_axis, batch_axis"
,
[
(
None
,
(
1
,
2
,
3
,
4
)),
(
0
,
(
1
,)),
((
1
,
-
1
),
(
2
,
4
)),
],
)
def
test_vectorize
(
self
,
core_axis
,
batch_axis
):
x
=
tensor
(
shape
=
(
5
,
5
,
5
,
5
))
batch_x
=
tensor
(
shape
=
(
3
,
5
,
5
,
5
,
5
))
# Test MaxAndArgmax
max_x
,
argmax_x
=
max_and_argmax
(
x
,
axis
=
core_axis
)
node
=
max_x
.
owner
assert
isinstance
(
node
.
op
,
MaxAndArgmax
)
new_node
=
vectorize_node
(
node
,
batch_x
)
assert
isinstance
(
new_node
.
op
,
MaxAndArgmax
)
assert
new_node
.
op
.
axis
==
batch_axis
# Test Argmax
# Argmax is not user-facing, so we have to create it manually
node
=
Argmax
(
axis
=
node
.
op
.
axis
)
.
make_node
(
x
)
new_node
=
vectorize_node
(
node
,
batch_x
)
assert
isinstance
(
new_node
.
op
,
Argmax
)
assert
new_node
.
op
.
axis
==
batch_axis
class
TestArgminArgmax
:
class
TestArgminArgmax
:
def
setup_method
(
self
):
def
setup_method
(
self
):
...
...
tests/tensor/test_special.py
浏览文件 @
65967fe2
...
@@ -8,6 +8,8 @@ from scipy.special import softmax as scipy_softmax
...
@@ -8,6 +8,8 @@ from scipy.special import softmax as scipy_softmax
from
pytensor.compile.function
import
function
from
pytensor.compile.function
import
function
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.special
import
(
from
pytensor.tensor.special
import
(
LogSoftmax
,
LogSoftmax
,
Softmax
,
Softmax
,
...
@@ -19,7 +21,7 @@ from pytensor.tensor.special import (
...
@@ -19,7 +21,7 @@ from pytensor.tensor.special import (
poch
,
poch
,
softmax
,
softmax
,
)
)
from
pytensor.tensor.type
import
matrix
,
tensor3
,
tensor4
,
vector
,
vectors
from
pytensor.tensor.type
import
matrix
,
tensor
,
tensor
3
,
tensor4
,
vector
,
vectors
from
tests
import
unittest_tools
as
utt
from
tests
import
unittest_tools
as
utt
from
tests.tensor.utils
import
random_ranged
from
tests.tensor.utils
import
random_ranged
...
@@ -150,6 +152,34 @@ class TestSoftmaxGrad(utt.InferShapeTester):
...
@@ -150,6 +152,34 @@ class TestSoftmaxGrad(utt.InferShapeTester):
SoftmaxGrad
(
-
4
)(
*
x
)
SoftmaxGrad
(
-
4
)(
*
x
)
@pytest.mark.parametrize
(
"core_axis, batch_axis"
,
[
(
None
,
(
1
,
2
,
3
,
4
)),
(
0
,
(
1
,)),
],
)
@pytest.mark.parametrize
(
"op, constructor"
,
[(
Softmax
,
softmax
),
(
LogSoftmax
,
log_softmax
)]
)
def
test_vectorize_softmax
(
op
,
constructor
,
core_axis
,
batch_axis
):
x
=
tensor
(
shape
=
(
5
,
5
,
5
,
5
))
batch_x
=
tensor
(
shape
=
(
3
,
5
,
5
,
5
,
5
))
node
=
constructor
(
x
,
axis
=
core_axis
)
.
owner
assert
isinstance
(
node
.
op
,
op
)
new_node
=
vectorize_node
(
node
,
batch_x
)
if
len
(
batch_axis
)
==
1
:
assert
isinstance
(
new_node
.
op
,
op
)
assert
(
new_node
.
op
.
axis
,)
==
batch_axis
else
:
assert
isinstance
(
new_node
.
op
,
Blockwise
)
and
isinstance
(
new_node
.
op
.
core_op
,
op
)
assert
new_node
.
op
.
core_op
.
axis
==
core_axis
def
test_poch
():
def
test_poch
():
_z
,
_m
=
vectors
(
"z"
,
"m"
)
_z
,
_m
=
vectors
(
"z"
,
"m"
)
actual_fn
=
function
([
_z
,
_m
],
poch
(
_z
,
_m
))
actual_fn
=
function
([
_z
,
_m
],
poch
(
_z
,
_m
))
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论