Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
22465044
提交
22465044
authored
1月 09, 2022
作者:
Zolisa Bleki
提交者:
Brandon T. Willard
7月 30, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add an Op for numpy.matmul
上级
bbb8bddb
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
283 行增加
和
4 行删除
+283
-4
math.py
aesara/tensor/math.py
+141
-0
nlinalg.py
aesara/tensor/nlinalg.py
+0
-4
test_math.py
tests/tensor/test_math.py
+142
-0
没有找到文件。
aesara/tensor/math.py
浏览文件 @
22465044
import
builtins
import
builtins
import
warnings
import
warnings
from
typing
import
TYPE_CHECKING
,
Optional
import
numpy
as
np
import
numpy
as
np
...
@@ -34,6 +35,7 @@ from aesara.tensor.elemwise import (
...
@@ -34,6 +35,7 @@ from aesara.tensor.elemwise import (
from
aesara.tensor.shape
import
shape
,
specify_broadcastable
from
aesara.tensor.shape
import
shape
,
specify_broadcastable
from
aesara.tensor.type
import
(
from
aesara.tensor.type
import
(
DenseTensorType
,
DenseTensorType
,
TensorType
,
complex_dtypes
,
complex_dtypes
,
continuous_dtypes
,
continuous_dtypes
,
discrete_dtypes
,
discrete_dtypes
,
...
@@ -47,6 +49,9 @@ from aesara.tensor.utils import as_list
...
@@ -47,6 +49,9 @@ from aesara.tensor.utils import as_list
from
aesara.tensor.var
import
TensorConstant
,
_tensor_py_operators
from
aesara.tensor.var
import
TensorConstant
,
_tensor_py_operators
if
TYPE_CHECKING
:
from
numpy.typing
import
ArrayLike
,
DTypeLike
# We capture the builtins that we are going to replace to follow the numpy API
# We capture the builtins that we are going to replace to follow the numpy API
_abs
=
builtins
.
abs
_abs
=
builtins
.
abs
...
@@ -2851,9 +2856,145 @@ def logsumexp(x, axis=None, keepdims=False):
...
@@ -2851,9 +2856,145 @@ def logsumexp(x, axis=None, keepdims=False):
return
log
(
sum
(
exp
(
x
),
axis
=
axis
,
keepdims
=
keepdims
))
return
log
(
sum
(
exp
(
x
),
axis
=
axis
,
keepdims
=
keepdims
))
class
MatMul
(
Op
):
__props__
=
(
"dtype"
,)
def
__init__
(
self
,
dtype
=
None
):
self
.
dtype
=
dtype
@classmethod
def
_get_output_shape
(
cls
,
x1
,
x2
,
shapes
,
validate
=
False
):
x1_shape
,
x2_shape
=
shapes
if
x1
.
ndim
==
1
and
x2
.
ndim
==
1
:
if
validate
and
x1_shape
[
0
]
!=
x2_shape
[
0
]:
raise
ValueError
(
"1d inputs must have the same length."
)
return
()
elif
x1
.
ndim
==
1
and
x2
.
ndim
>
1
:
if
validate
and
x1_shape
[
0
]
!=
x2_shape
[
-
2
]:
raise
ValueError
(
"length of input 1 must be equal the length "
"of the 2nd-last dimension of input 2"
)
return
x2_shape
[:
-
2
]
+
x2_shape
[
-
1
:]
elif
x1
.
ndim
>
1
and
x2
.
ndim
==
1
:
if
validate
and
x1_shape
[
-
1
]
!=
x2_shape
[
0
]:
raise
ValueError
(
"length of input 2 must be equal the length "
"of the last dimension of input 1"
)
return
x1_shape
[:
-
1
]
elif
x1
.
ndim
==
2
and
x2
.
ndim
==
2
:
if
validate
and
x1_shape
[
-
1
]
!=
x2_shape
[
0
]:
raise
ValueError
(
"number of columns of input 1 must be equal to "
"the number of rows of input 2"
)
return
x1_shape
[:
-
1
]
+
x2_shape
[
-
1
:]
elif
x1
.
ndim
>
2
and
x2
.
ndim
==
2
:
if
validate
and
x1_shape
[
-
1
]
!=
x2_shape
[
0
]:
raise
ValueError
(
"number of rows of input 2 must be equal to "
"the length of the last dimension of input 1"
)
return
x1_shape
[:
-
2
]
+
x1_shape
[
-
2
:
-
1
]
+
x2_shape
[
-
1
:]
elif
x1
.
ndim
==
2
and
x2
.
ndim
>
2
:
if
validate
and
x1_shape
[
-
1
]
!=
x2_shape
[
-
2
]:
raise
ValueError
(
"number of columns of input 1 must be equal "
"the length of the 2nd-last dimension of input 2"
)
return
x2_shape
[:
-
2
]
+
x1_shape
[
-
2
:
-
1
]
+
x2_shape
[
-
1
:]
else
:
if
validate
:
from
aesara.tensor.random.basic
import
broadcast_shapes
bshape
=
broadcast_shapes
(
x1_shape
[:
-
2
],
x2_shape
[:
-
2
])
if
x1_shape
[
-
1
]
!=
x2_shape
[
-
2
]:
raise
ValueError
(
"length of the last dimension of input 1 must be equal "
"to the length of the 2nd-last dimension of input 2"
)
else
:
from
aesara.tensor.extra_ops
import
broadcast_shape
bshape
=
broadcast_shape
(
x1_shape
[:
-
2
],
x2_shape
[:
-
2
],
arrays_are_shapes
=
True
)
return
bshape
+
x1_shape
[
-
2
:
-
1
]
+
x2_shape
[
-
1
:]
def
make_node
(
self
,
a
,
b
):
a
=
as_tensor_variable
(
a
)
b
=
as_tensor_variable
(
b
)
if
0
in
{
a
.
ndim
,
b
.
ndim
}:
raise
ValueError
(
"inputs to `matmul` cannot be scalar."
)
out_shape
=
self
.
_get_output_shape
(
a
,
b
,
(
a
.
type
.
shape
,
b
.
type
.
shape
),
validate
=
True
)
out
=
TensorType
(
dtype
=
self
.
dtype
,
shape
=
out_shape
)()
return
Apply
(
self
,
[
a
,
b
],
[
out
])
def
perform
(
self
,
node
,
inputs
,
outputs
):
x1
,
x2
=
inputs
outputs
[
0
][
0
]
=
np
.
matmul
(
x1
,
x2
,
dtype
=
self
.
dtype
)
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
x1
,
x2
=
node
.
inputs
return
[
self
.
_get_output_shape
(
x1
,
x2
,
shapes
)]
def
matmul
(
x1
:
"ArrayLike"
,
x2
:
"ArrayLike"
,
dtype
:
Optional
[
"DTypeLike"
]
=
None
):
"""Compute the matrix product of two tensor variables.
Parameters
----------
x1, x2
Input arrays, scalars not allowed.
dtype
The desired data-type for the array. If not given, then the type will
be determined as the minimum type required to hold the objects in the
sequence.
Returns
-------
out : ndarray
The matrix product of the inputs. This is a scalar only when both
`x1`, `x2` are 1-d vectors.
Raises
------
ValueError
If the last dimension of `x1` is not the same size as the second-to-last
dimension of `x2`. If a scalar value is passed in.
Notes
-----
The behavior depends on the arguments in the following way.
- If both arguments are 2-D they are multiplied like conventional matrices.
- If either argument is N-D, N > 2, it is treated as a stack of matrices
residing in the last two indexes and broadcast accordingly.
- If the first argument is 1-D, it is promoted to a matrix by prepending a
1 to its dimensions. After matrix multiplication the prepended 1 is removed.
- If the second argument is 1-D, it is promoted to a matrix by appending a
1 to its dimensions. After matrix multiplication the appended 1 is removed.
`matmul` differs from `dot` in two important ways:
- Multiplication by scalars is not allowed, use `mul` instead.
- Stacks of matrices are broadcast together as if the matrices were elements,
respecting the signature ``(n, k), (k, m) -> (n, m)``:
"""
return
MatMul
(
dtype
=
dtype
)(
x1
,
x2
)
__all__
=
[
__all__
=
[
"max_and_argmax"
,
"max_and_argmax"
,
"max"
,
"max"
,
"matmul"
,
"argmax"
,
"argmax"
,
"min"
,
"min"
,
"argmin"
,
"argmin"
,
...
...
aesara/tensor/nlinalg.py
浏览文件 @
22465044
import
logging
from
functools
import
partial
from
functools
import
partial
from
typing
import
Tuple
,
Union
from
typing
import
Tuple
,
Union
...
@@ -14,9 +13,6 @@ from aesara.tensor.basic import as_tensor_variable, extract_diag
...
@@ -14,9 +13,6 @@ from aesara.tensor.basic import as_tensor_variable, extract_diag
from
aesara.tensor.type
import
dvector
,
lscalar
,
matrix
,
scalar
,
vector
from
aesara.tensor.type
import
dvector
,
lscalar
,
matrix
,
scalar
,
vector
logger
=
logging
.
getLogger
(
__name__
)
class
MatrixPinv
(
Op
):
class
MatrixPinv
(
Op
):
__props__
=
(
"hermitian"
,)
__props__
=
(
"hermitian"
,)
...
...
tests/tensor/test_math.py
浏览文件 @
22465044
...
@@ -35,11 +35,13 @@ from aesara.tensor.elemwise import CAReduce, Elemwise
...
@@ -35,11 +35,13 @@ from aesara.tensor.elemwise import CAReduce, Elemwise
from
aesara.tensor.math
import
(
from
aesara.tensor.math
import
(
Argmax
,
Argmax
,
Dot
,
Dot
,
MatMul
,
MaxAndArgmax
,
MaxAndArgmax
,
Mean
,
Mean
,
Prod
,
Prod
,
ProdWithoutZeros
,
ProdWithoutZeros
,
Sum
,
Sum
,
_allclose
,
_dot
,
_dot
,
abs
,
abs
,
add
,
add
,
...
@@ -80,6 +82,7 @@ from aesara.tensor.math import (
...
@@ -80,6 +82,7 @@ from aesara.tensor.math import (
log10
,
log10
,
logaddexp
,
logaddexp
,
logsumexp
,
logsumexp
,
matmul
,
max
,
max
,
max_and_argmax
,
max_and_argmax
,
maximum
,
maximum
,
...
@@ -3382,3 +3385,142 @@ def test_log1mexp_grad_lim():
...
@@ -3382,3 +3385,142 @@ def test_log1mexp_grad_lim():
assert
grad_x_fn
(
-
0.0
)
==
-
np
.
inf
assert
grad_x_fn
(
-
0.0
)
==
-
np
.
inf
assert
grad_x_fn
(
-
1e-309
)
==
-
np
.
inf
assert
grad_x_fn
(
-
1e-309
)
==
-
np
.
inf
assert
grad_x_fn
(
-
1e-308
)
!=
-
np
.
inf
assert
grad_x_fn
(
-
1e-308
)
!=
-
np
.
inf
class
TestMatMul
(
utt
.
InferShapeTester
):
def
setup_method
(
self
):
super
()
.
setup_method
()
self
.
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
self
.
op
=
matmul
self
.
op_class
=
MatMul
def
_validate_output
(
self
,
a
,
b
):
aesara_sol
=
self
.
op
(
a
,
b
)
.
eval
()
numpy_sol
=
np
.
matmul
(
a
,
b
)
assert
_allclose
(
numpy_sol
,
aesara_sol
)
@pytest.mark.parametrize
(
"x1, x2"
,
[
# test output when both inputs are vectors
(
np
.
arange
(
3
)
.
astype
(
config
.
floatX
),
np
.
arange
(
3
)
.
astype
(
config
.
floatX
)),
# test output when both inputs are matrices
(
np
.
arange
(
3
*
5
)
.
reshape
((
5
,
3
))
.
astype
(
config
.
floatX
),
np
.
arange
(
2
*
3
)
.
reshape
((
3
,
2
))
.
astype
(
config
.
floatX
),
),
# test behaviour when one of the inputs is has dimension > 2
(
np
.
arange
(
3
*
5
)
.
reshape
((
5
,
3
))
.
astype
(
config
.
floatX
),
np
.
arange
(
2
*
3
*
5
)
.
reshape
((
2
,
3
,
5
))
.
astype
(
config
.
floatX
),
),
# test behaviour when one of the inputs is a vector
(
np
.
arange
(
3
*
5
)
.
reshape
((
5
,
3
))
.
astype
(
config
.
floatX
),
np
.
arange
(
3
)
.
astype
(
config
.
floatX
),
),
(
np
.
arange
(
5
)
.
astype
(
config
.
floatX
),
np
.
arange
(
3
*
5
)
.
reshape
((
5
,
3
))
.
astype
(
config
.
floatX
),
),
# check if behaviour is correct N-D arrays where N > 2.
(
np
.
arange
(
2
*
2
*
4
)
.
reshape
((
2
,
2
,
4
))
.
astype
(
config
.
floatX
),
np
.
arange
(
2
*
2
*
4
)
.
reshape
((
2
,
4
,
2
))
.
astype
(
config
.
floatX
),
),
],
)
def
test_op
(
self
,
x1
,
x2
):
self
.
_validate_output
(
x1
,
x2
)
def
test_scalar_error
(
self
):
with
pytest
.
raises
(
ValueError
,
match
=
"cannot be scalar"
):
self
.
op
(
4
,
[
4
,
1
])
@pytest.mark.parametrize
(
"dtype"
,
(
np
.
float16
,
np
.
float32
,
np
.
float64
))
def
test_dtype_param
(
self
,
dtype
):
sol
=
self
.
op
([
1
,
2
,
3
],
[
3
,
2
,
1
],
dtype
=
dtype
)
assert
sol
.
eval
()
.
dtype
==
dtype
@pytest.mark.parametrize
(
"x1_shape,x2_shape,exp_res,error_regex"
,
[
((
1
,),
(
3
,),
None
,
"inputs must have the same length"
),
((
2
,),
(
3
,
1
),
None
,
"length of input 1.*2nd-last dimension of input 2"
),
((
2
,
5
),
(
3
,),
None
,
"length of input 2.*of the last dimension of input 1"
),
(
(
2
,
5
),
(
3
,
4
),
None
,
"number of columns of input 1 .* number of rows of input 2"
,
),
(
(
2
,
1
,
3
),
(
5
,
4
),
None
,
"number of rows of input 2 .* last dimension of input 1"
,
),
(
(
2
,
5
),
(
2
,
4
,
3
),
None
,
"number of columns of input 1 .* 2nd-last dimension of input 2"
,
),
(
(
3
,
2
,
4
,
5
),
(
1
,
6
,
7
),
None
,
"length of the last dimension of input 1 .* 2nd-last dimension of input 2"
,
),
(
(
4
,
5
,
4
),
(
3
,
2
,
2
),
None
,
"cannot be broadcast to a single shape"
,
),
(
(
4
,
None
,
2
),
(
4
,
2
,
None
),
(
4
,
None
,
None
),
None
,
),
],
)
def
test_get_output_shape
(
self
,
x1_shape
,
x2_shape
,
exp_res
,
error_regex
):
x1
=
tensor
(
dtype
=
np
.
float64
,
shape
=
x1_shape
)
x2
=
tensor
(
dtype
=
np
.
float64
,
shape
=
x2_shape
)
if
error_regex
is
not
None
:
with
pytest
.
raises
(
ValueError
,
match
=
error_regex
):
self
.
op_class
.
_get_output_shape
(
x1
,
x2
,
(
x1_shape
,
x2_shape
),
validate
=
True
)
else
:
assert
(
self
.
op_class
.
_get_output_shape
(
x1
,
x2
,
(
x1_shape
,
x2_shape
),
validate
=
True
)
==
exp_res
)
def
test_infer_shape
(
self
):
for
shape_x1
,
shape_x2
in
[
((
5
,),
(
5
,)),
((
5
,),
(
2
,
5
,
3
)),
((
2
,
5
,
3
),
(
3
,)),
((
2
,
5
),
(
5
,
4
)),
((
2
,
5
),
(
2
,
5
,
3
)),
((
2
,
1
,
3
),
(
3
,
4
)),
((
3
,
2
,
4
,
5
),
(
1
,
5
,
7
)),
]:
a
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
shape_x1
)
b
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
shape_x2
)
x1
=
self
.
rng
.
random
(
shape_x1
)
.
astype
(
config
.
floatX
)
x2
=
self
.
rng
.
random
(
shape_x2
)
.
astype
(
config
.
floatX
)
self
.
_compile_and_check
(
[
a
,
b
],
[
self
.
op
(
a
,
b
)],
[
x1
,
x2
],
self
.
op_class
,
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论