Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
071eadd8
提交
071eadd8
authored
9月 23, 2023
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
9月 24, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove Matmul Operator in favor of Blockwise Dot
上级
7c58661b
显示空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
62 行增加
和
184 行删除
+62
-184
math.py
pytensor/tensor/math.py
+19
-89
blockwise.py
pytensor/tensor/rewriting/blockwise.py
+9
-0
variable.py
pytensor/tensor/variable.py
+8
-4
test_math.py
tests/tensor/test_math.py
+7
-86
test_variable.py
tests/tensor/test_variable.py
+19
-5
没有找到文件。
pytensor/tensor/math.py
浏览文件 @
071eadd8
...
@@ -25,11 +25,11 @@ from pytensor.tensor.basic import (
...
@@ -25,11 +25,11 @@ from pytensor.tensor.basic import (
stack
,
stack
,
switch
,
switch
,
)
)
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.elemwise
import
CAReduce
,
DimShuffle
,
Elemwise
,
scalar_elemwise
from
pytensor.tensor.elemwise
import
CAReduce
,
DimShuffle
,
Elemwise
,
scalar_elemwise
from
pytensor.tensor.shape
import
shape
,
specify_broadcastable
from
pytensor.tensor.shape
import
shape
,
specify_broadcastable
from
pytensor.tensor.type
import
(
from
pytensor.tensor.type
import
(
DenseTensorType
,
DenseTensorType
,
TensorType
,
complex_dtypes
,
complex_dtypes
,
continuous_dtypes
,
continuous_dtypes
,
discrete_dtypes
,
discrete_dtypes
,
...
@@ -2868,93 +2868,7 @@ def logsumexp(x, axis=None, keepdims=False):
...
@@ -2868,93 +2868,7 @@ def logsumexp(x, axis=None, keepdims=False):
return
log
(
sum
(
exp
(
x
),
axis
=
axis
,
keepdims
=
keepdims
))
return
log
(
sum
(
exp
(
x
),
axis
=
axis
,
keepdims
=
keepdims
))
class
MatMul
(
Op
):
_matrix_matrix_matmul
=
Blockwise
(
_dot
,
signature
=
"(n,k),(k,m)->(n,m)"
)
__props__
=
(
"dtype"
,)
def
__init__
(
self
,
dtype
=
None
):
self
.
dtype
=
dtype
@classmethod
def
_get_output_shape
(
cls
,
x1
,
x2
,
shapes
,
validate
=
False
):
x1_shape
,
x2_shape
=
shapes
if
x1
.
ndim
==
1
and
x2
.
ndim
==
1
:
if
validate
and
x1_shape
[
0
]
!=
x2_shape
[
0
]:
raise
ValueError
(
"1d inputs must have the same length."
)
return
()
elif
x1
.
ndim
==
1
and
x2
.
ndim
>
1
:
if
validate
and
x1_shape
[
0
]
!=
x2_shape
[
-
2
]:
raise
ValueError
(
"length of input 1 must be equal the length "
"of the 2nd-last dimension of input 2"
)
return
x2_shape
[:
-
2
]
+
x2_shape
[
-
1
:]
elif
x1
.
ndim
>
1
and
x2
.
ndim
==
1
:
if
validate
and
x1_shape
[
-
1
]
!=
x2_shape
[
0
]:
raise
ValueError
(
"length of input 2 must be equal the length "
"of the last dimension of input 1"
)
return
x1_shape
[:
-
1
]
elif
x1
.
ndim
==
2
and
x2
.
ndim
==
2
:
if
validate
and
x1_shape
[
-
1
]
!=
x2_shape
[
0
]:
raise
ValueError
(
"number of columns of input 1 must be equal to "
"the number of rows of input 2"
)
return
x1_shape
[:
-
1
]
+
x2_shape
[
-
1
:]
elif
x1
.
ndim
>
2
and
x2
.
ndim
==
2
:
if
validate
and
x1_shape
[
-
1
]
!=
x2_shape
[
0
]:
raise
ValueError
(
"number of rows of input 2 must be equal to "
"the length of the last dimension of input 1"
)
return
x1_shape
[:
-
2
]
+
x1_shape
[
-
2
:
-
1
]
+
x2_shape
[
-
1
:]
elif
x1
.
ndim
==
2
and
x2
.
ndim
>
2
:
if
validate
and
x1_shape
[
-
1
]
!=
x2_shape
[
-
2
]:
raise
ValueError
(
"number of columns of input 1 must be equal "
"the length of the 2nd-last dimension of input 2"
)
return
x2_shape
[:
-
2
]
+
x1_shape
[
-
2
:
-
1
]
+
x2_shape
[
-
1
:]
else
:
if
validate
:
from
pytensor.tensor.random.basic
import
broadcast_shapes
bshape
=
broadcast_shapes
(
x1_shape
[:
-
2
],
x2_shape
[:
-
2
])
if
x1_shape
[
-
1
]
!=
x2_shape
[
-
2
]:
raise
ValueError
(
"length of the last dimension of input 1 must be equal "
"to the length of the 2nd-last dimension of input 2"
)
else
:
from
pytensor.tensor.extra_ops
import
broadcast_shape
bshape
=
broadcast_shape
(
x1_shape
[:
-
2
],
x2_shape
[:
-
2
],
arrays_are_shapes
=
True
)
return
bshape
+
x1_shape
[
-
2
:
-
1
]
+
x2_shape
[
-
1
:]
def
make_node
(
self
,
a
,
b
):
a
=
as_tensor_variable
(
a
)
b
=
as_tensor_variable
(
b
)
if
0
in
{
a
.
ndim
,
b
.
ndim
}:
raise
ValueError
(
"inputs to `matmul` cannot be scalar."
)
out_shape
=
self
.
_get_output_shape
(
a
,
b
,
(
a
.
type
.
shape
,
b
.
type
.
shape
),
validate
=
True
)
out
=
TensorType
(
dtype
=
self
.
dtype
,
shape
=
out_shape
)()
return
Apply
(
self
,
[
a
,
b
],
[
out
])
def
perform
(
self
,
node
,
inputs
,
outputs
):
x1
,
x2
=
inputs
outputs
[
0
][
0
]
=
np
.
matmul
(
x1
,
x2
,
dtype
=
self
.
dtype
)
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
x1
,
x2
=
node
.
inputs
return
[
self
.
_get_output_shape
(
x1
,
x2
,
shapes
)]
def
matmul
(
x1
:
"ArrayLike"
,
x2
:
"ArrayLike"
,
dtype
:
Optional
[
"DTypeLike"
]
=
None
):
def
matmul
(
x1
:
"ArrayLike"
,
x2
:
"ArrayLike"
,
dtype
:
Optional
[
"DTypeLike"
]
=
None
):
...
@@ -2999,7 +2913,23 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
...
@@ -2999,7 +2913,23 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
- Stacks of matrices are broadcast together as if the matrices were elements,
- Stacks of matrices are broadcast together as if the matrices were elements,
respecting the signature ``(n, k), (k, m) -> (n, m)``:
respecting the signature ``(n, k), (k, m) -> (n, m)``:
"""
"""
return
MatMul
(
dtype
=
dtype
)(
x1
,
x2
)
x1
=
as_tensor_variable
(
x1
)
x2
=
as_tensor_variable
(
x2
)
if
x1
.
type
.
ndim
==
0
or
x2
.
type
.
ndim
==
0
:
raise
ValueError
(
"matmul operand cannot be scalar"
)
if
x1
.
type
.
ndim
==
1
and
x2
.
type
.
ndim
==
1
:
out
=
_dot
(
x1
,
x2
)
elif
x1
.
type
.
ndim
==
1
:
out
=
_matrix_matrix_matmul
(
x1
[
None
],
x2
)
.
squeeze
(
-
2
)
elif
x2
.
type
.
ndim
==
1
:
out
=
_matrix_matrix_matmul
(
x1
,
x2
[:,
None
])
.
squeeze
(
-
1
)
else
:
out
=
_matrix_matrix_matmul
(
x1
,
x2
)
if
dtype
is
not
None
:
out
=
out
.
astype
(
dtype
)
return
out
__all__
=
[
__all__
=
[
...
...
pytensor/tensor/rewriting/blockwise.py
浏览文件 @
071eadd8
...
@@ -3,6 +3,8 @@ from pytensor.graph import node_rewriter
...
@@ -3,6 +3,8 @@ from pytensor.graph import node_rewriter
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
out2in
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
out2in
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.math
import
_matrix_matrix_matmul
from
pytensor.tensor.rewriting.basic
import
register_canonicalize
@node_rewriter
([
Blockwise
])
@node_rewriter
([
Blockwise
])
...
@@ -40,3 +42,10 @@ optdb.register(
...
@@ -40,3 +42,10 @@ optdb.register(
"blockwise"
,
"blockwise"
,
position
=
49
,
position
=
49
,
)
)
# Avoid redundant cases early on for Ops whose default form is not Blockwised
@register_canonicalize
@node_rewriter
(
tracks
=
[
_matrix_matrix_matmul
])
def
local_eager_useless_unbatched_blockwise
(
fgraph
,
node
):
return
local_useless_unbatched_blockwise
.
fn
(
fgraph
,
node
)
pytensor/tensor/variable.py
浏览文件 @
071eadd8
...
@@ -647,8 +647,12 @@ class _tensor_py_operators:
...
@@ -647,8 +647,12 @@ class _tensor_py_operators:
return
at
.
math
.
dense_dot
(
left
,
right
)
return
at
.
math
.
dense_dot
(
left
,
right
)
dot
=
__dot__
dot
=
__dot__
__matmul__
=
__dot__
__rmatmul__
=
__rdot__
def
__matmul__
(
left
,
right
):
return
at
.
math
.
matmul
(
left
,
right
)
def
__rmatmul__
(
right
,
left
):
return
at
.
math
.
matmul
(
right
,
left
)
def
sum
(
self
,
axis
=
None
,
dtype
=
None
,
keepdims
=
False
,
acc_dtype
=
None
):
def
sum
(
self
,
axis
=
None
,
dtype
=
None
,
keepdims
=
False
,
acc_dtype
=
None
):
"""See :func:`pytensor.tensor.math.sum`."""
"""See :func:`pytensor.tensor.math.sum`."""
...
@@ -797,7 +801,7 @@ class _tensor_py_operators:
...
@@ -797,7 +801,7 @@ class _tensor_py_operators:
"""
"""
return
at
.
basic
.
choose
(
self
,
choices
,
mode
=
"raise"
)
return
at
.
basic
.
choose
(
self
,
choices
,
mode
=
"raise"
)
def
squeeze
(
self
):
def
squeeze
(
self
,
axis
=
None
):
"""
"""
Remove broadcastable dimensions from the shape of an array.
Remove broadcastable dimensions from the shape of an array.
...
@@ -805,7 +809,7 @@ class _tensor_py_operators:
...
@@ -805,7 +809,7 @@ class _tensor_py_operators:
removed. This is always `x` itself or a view into `x`.
removed. This is always `x` itself or a view into `x`.
"""
"""
return
at
.
extra_ops
.
squeeze
(
self
)
return
at
.
extra_ops
.
squeeze
(
self
,
axis
=
axis
)
def
compress
(
self
,
a
,
axis
=
None
):
def
compress
(
self
,
a
,
axis
=
None
):
"""Return selected slices only."""
"""Return selected slices only."""
...
...
tests/tensor/test_math.py
浏览文件 @
071eadd8
...
@@ -30,11 +30,11 @@ from pytensor.tensor.basic import (
...
@@ -30,11 +30,11 @@ from pytensor.tensor.basic import (
get_underlying_scalar_constant_value
,
get_underlying_scalar_constant_value
,
switch
,
switch
,
)
)
from
pytensor.tensor.blas
import
Dot22
from
pytensor.tensor.elemwise
import
CAReduce
,
Elemwise
from
pytensor.tensor.elemwise
import
CAReduce
,
Elemwise
from
pytensor.tensor.math
import
(
from
pytensor.tensor.math
import
(
Argmax
,
Argmax
,
Dot
,
Dot
,
MatMul
,
MaxAndArgmax
,
MaxAndArgmax
,
Mean
,
Mean
,
Prod
,
Prod
,
...
@@ -3412,12 +3412,10 @@ def test_log1mexp_grad_lim():
...
@@ -3412,12 +3412,10 @@ def test_log1mexp_grad_lim():
assert
grad_x_fn
(
-
1e-308
)
!=
-
np
.
inf
assert
grad_x_fn
(
-
1e-308
)
!=
-
np
.
inf
class
TestMatMul
(
utt
.
InferShapeTester
)
:
class
TestMatMul
:
def
setup_method
(
self
):
def
setup_method
(
self
):
super
()
.
setup_method
()
self
.
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
self
.
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
self
.
op
=
matmul
self
.
op
=
matmul
self
.
op_class
=
MatMul
def
_validate_output
(
self
,
a
,
b
):
def
_validate_output
(
self
,
a
,
b
):
pytensor_sol
=
self
.
op
(
a
,
b
)
.
eval
()
pytensor_sol
=
self
.
op
(
a
,
b
)
.
eval
()
...
@@ -3467,85 +3465,8 @@ class TestMatMul(utt.InferShapeTester):
...
@@ -3467,85 +3465,8 @@ class TestMatMul(utt.InferShapeTester):
sol
=
self
.
op
([
1
,
2
,
3
],
[
3
,
2
,
1
],
dtype
=
dtype
)
sol
=
self
.
op
([
1
,
2
,
3
],
[
3
,
2
,
1
],
dtype
=
dtype
)
assert
sol
.
eval
()
.
dtype
==
dtype
assert
sol
.
eval
()
.
dtype
==
dtype
@pytest.mark.parametrize
(
def
test_dot22_opt
(
self
):
"x1_shape,x2_shape,exp_res,error_regex"
,
x
,
y
=
matrices
(
"xy"
)
[
fn
=
function
([
x
,
y
],
x
@
y
,
mode
=
"FAST_RUN"
)
((
1
,),
(
3
,),
None
,
"inputs must have the same length"
),
[
node
]
=
fn
.
maker
.
fgraph
.
apply_nodes
((
2
,),
(
3
,
1
),
None
,
"length of input 1.*2nd-last dimension of input 2"
),
assert
isinstance
(
node
.
op
,
Dot22
)
((
2
,
5
),
(
3
,),
None
,
"length of input 2.*of the last dimension of input 1"
),
(
(
2
,
5
),
(
3
,
4
),
None
,
"number of columns of input 1 .* number of rows of input 2"
,
),
(
(
2
,
1
,
3
),
(
5
,
4
),
None
,
"number of rows of input 2 .* last dimension of input 1"
,
),
(
(
2
,
5
),
(
2
,
4
,
3
),
None
,
"number of columns of input 1 .* 2nd-last dimension of input 2"
,
),
(
(
3
,
2
,
4
,
5
),
(
1
,
6
,
7
),
None
,
"length of the last dimension of input 1 .* 2nd-last dimension of input 2"
,
),
(
(
4
,
5
,
4
),
(
3
,
2
,
2
),
None
,
"cannot be broadcast to a single shape"
,
),
(
(
4
,
None
,
2
),
(
4
,
2
,
None
),
(
4
,
None
,
None
),
None
,
),
],
)
def
test_get_output_shape
(
self
,
x1_shape
,
x2_shape
,
exp_res
,
error_regex
):
x1
=
tensor
(
dtype
=
np
.
float64
,
shape
=
x1_shape
)
x2
=
tensor
(
dtype
=
np
.
float64
,
shape
=
x2_shape
)
if
error_regex
is
not
None
:
with
pytest
.
raises
(
ValueError
,
match
=
error_regex
):
self
.
op_class
.
_get_output_shape
(
x1
,
x2
,
(
x1_shape
,
x2_shape
),
validate
=
True
)
else
:
assert
(
self
.
op_class
.
_get_output_shape
(
x1
,
x2
,
(
x1_shape
,
x2_shape
),
validate
=
True
)
==
exp_res
)
def
test_infer_shape
(
self
):
for
shape_x1
,
shape_x2
in
[
((
5
,),
(
5
,)),
((
5
,),
(
2
,
5
,
3
)),
((
2
,
5
,
3
),
(
3
,)),
((
2
,
5
),
(
5
,
4
)),
((
2
,
5
),
(
2
,
5
,
3
)),
((
2
,
1
,
3
),
(
3
,
4
)),
((
3
,
2
,
4
,
5
),
(
1
,
5
,
7
)),
]:
a
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
shape_x1
)
b
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
shape_x2
)
x1
=
self
.
rng
.
random
(
shape_x1
)
.
astype
(
config
.
floatX
)
x2
=
self
.
rng
.
random
(
shape_x2
)
.
astype
(
config
.
floatX
)
self
.
_compile_and_check
(
[
a
,
b
],
[
self
.
op
(
a
,
b
)],
[
x1
,
x2
],
self
.
op_class
,
)
tests/tensor/test_variable.py
浏览文件 @
071eadd8
...
@@ -10,9 +10,9 @@ from pytensor.compile import DeepCopyOp
...
@@ -10,9 +10,9 @@ from pytensor.compile import DeepCopyOp
from
pytensor.compile.mode
import
get_default_mode
from
pytensor.compile.mode
import
get_default_mode
from
pytensor.graph.basic
import
Constant
,
equal_computations
from
pytensor.graph.basic
import
Constant
,
equal_computations
from
pytensor.tensor
import
get_vector_length
from
pytensor.tensor
import
get_vector_length
from
pytensor.tensor.basic
import
constant
from
pytensor.tensor.basic
import
as_tensor
,
constant
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.math
import
dot
,
eq
from
pytensor.tensor.math
import
dot
,
eq
,
matmul
from
pytensor.tensor.shape
import
Shape
from
pytensor.tensor.shape
import
Shape
from
pytensor.tensor.subtensor
import
AdvancedSubtensor
,
Subtensor
from
pytensor.tensor.subtensor
import
AdvancedSubtensor
,
Subtensor
from
pytensor.tensor.type
import
(
from
pytensor.tensor.type
import
(
...
@@ -79,16 +79,30 @@ def test_infix_dot_method():
...
@@ -79,16 +79,30 @@ def test_infix_dot_method():
X
=
dmatrix
(
"X"
)
X
=
dmatrix
(
"X"
)
y
=
dvector
(
"y"
)
y
=
dvector
(
"y"
)
res
=
X
@
y
res
=
X
.
dot
(
y
)
exp_res
=
X
.
dot
(
y
)
exp_res
=
dot
(
X
,
y
)
assert
equal_computations
([
res
],
[
exp_res
])
assert
equal_computations
([
res
],
[
exp_res
])
X_val
=
np
.
arange
(
2
*
3
)
.
reshape
((
2
,
3
))
X_val
=
np
.
arange
(
2
*
3
)
.
reshape
((
2
,
3
))
res
=
X_val
@
y
res
=
as_tensor
(
X_val
)
.
dot
(
y
)
exp_res
=
dot
(
X_val
,
y
)
exp_res
=
dot
(
X_val
,
y
)
assert
equal_computations
([
res
],
[
exp_res
])
assert
equal_computations
([
res
],
[
exp_res
])
def
test_infix_matmul_method
():
X
=
dmatrix
(
"X"
)
y
=
dvector
(
"y"
)
res
=
X
@
y
exp_res
=
matmul
(
X
,
y
)
assert
equal_computations
([
res
],
[
exp_res
])
X_val
=
np
.
arange
(
2
*
3
)
.
reshape
((
2
,
3
))
res
=
as_tensor
(
X_val
)
@
y
exp_res
=
matmul
(
X_val
,
y
)
assert
equal_computations
([
res
],
[
exp_res
])
def
test_empty_list_indexing
():
def
test_empty_list_indexing
():
ynp
=
np
.
zeros
((
2
,
2
))[:,
[]]
ynp
=
np
.
zeros
((
2
,
2
))[:,
[]]
znp
=
np
.
zeros
((
2
,
2
))[:,
()]
znp
=
np
.
zeros
((
2
,
2
))[:,
()]
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论