Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e3fb4985
提交
e3fb4985
authored
1月 18, 2024
作者:
lucianopaz
提交者:
Ricardo Vieira
1月 20, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix tensordot implementation
上级
f799219e
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
214 行增加
和
47 行删除
+214
-47
math.py
pytensor/tensor/math.py
+121
-46
test_math.py
tests/tensor/test_math.py
+93
-1
没有找到文件。
pytensor/tensor/math.py
浏览文件 @
e3fb4985
import
builtins
import
warnings
from
typing
import
TYPE_CHECKING
,
Optional
from
collections.abc
import
Sequence
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
import
numpy
as
np
from
numpy.core.numeric
import
normalize_axis_tuple
from
pytensor
import
config
,
printing
from
pytensor
import
scalar
as
ps
...
...
@@ -15,7 +17,9 @@ from pytensor.link.c.params_type import ParamsType
from
pytensor.link.c.type
import
Generic
from
pytensor.misc.safe_asarray
import
_asarray
from
pytensor.printing
import
pprint
from
pytensor.raise_op
import
Assert
from
pytensor.scalar.basic
import
BinaryScalarOp
from
pytensor.tensor
import
TensorLike
from
pytensor.tensor.basic
import
(
alloc
,
arange
,
...
...
@@ -47,7 +51,11 @@ from pytensor.tensor.type import (
)
from
pytensor.tensor.type_other
import
NoneConst
from
pytensor.tensor.utils
import
as_list
from
pytensor.tensor.variable
import
TensorConstant
,
_tensor_py_operators
from
pytensor.tensor.variable
import
(
TensorConstant
,
TensorVariable
,
_tensor_py_operators
,
)
if
TYPE_CHECKING
:
...
...
@@ -2266,57 +2274,47 @@ def _tensordot_as_dot(a, b, axes, dot, batched):
)
def
tensordot
(
a
,
b
,
axes
=
2
):
def
tensordot
(
a
:
TensorLike
,
b
:
TensorLike
,
axes
:
Union
[
int
,
Sequence
[
Sequence
[
int
]]]
=
2
)
->
TensorVariable
:
"""
Compute a generalized dot product over provided axes.
Compute tensor dot product along specified axes.
Implementation is mostly taken from numpy version 1.26.0
Given two tensors a and b, tensordot computes a generalized dot product over
the provided axes. PyTensor's implementation reduces all expressions to
matrix or vector dot products and is based on code from Tijmen Tieleman's
gnumpy (http://www.cs.toronto.edu/~tijmen/gnumpy.html).
Given two tensors, `a` and `b`, and a sequence object containing
two sequence objects, ``(a_axes, b_axes)``, sum the products of
`a`'s and `b`'s elements (components) over the axes specified by
``a_axes`` and ``b_axes``. The third argument can be a single non-negative
integer_like scalar, ``N``; if it is such, then the last ``N`` dimensions
of `a` and the first ``N`` dimensions of `b` are summed over.
Parameters
----------
a: symbolic tensor
The first tensor variable.
b: symbolic tensor
The second tensor variable
axes: int or array-like of length 2
If an integer, the number of axes to sum over.
If an array, it must have two array elements containing the axes
to sum over in each tensor.
Note that the default value of 2 is not guaranteed to work
for all values of a and b, and an error will be raised if
that is the case. The reason for keeping the default is to
maintain the same signature as numpy's tensordot function
(and np.tensordot raises analogous errors for non-compatible
inputs).
If an integer i, it is converted to an array containing
the last i dimensions of the first tensor and the first
i dimensions of the second tensor:
axes = [list(range(a.ndim - i, b.ndim)), list(range(i))]
If an array, its two elements must contain compatible axes
of the two tensors. For example, [[1, 2], [2, 0]] means sum
over the 2nd and 3rd axes of a and the 3rd and 1st axes of b.
(Remember axes are zero-indexed!) The 2nd axis of a and the
3rd axis of b must have the same shape; the same is true for
the 3rd axis of a and the 1st axis of b.
a, b : tensor_like
Tensors to "dot".
axes : int or (2,) array_like
* integer_like
If an int N, sum over the last N axes of `a` and the first N axes
of `b` in order. The sizes of the corresponding axes must match.
* (2,) array_like
Or, a list of axes to be summed over, first sequence applying to `a`,
second to `b`. Both elements array_like must be of the same length.
Returns
-------
symbolic tensor
A tensor with shape equal to the concatenation of a's shape
(less any dimensions that were summed over) and b's shape
(less any dimensions that were summed over).
output : TensorVariable
The tensor dot product of the input.
Its shape will be equal to the concatenation of `a` and `b` shapes
(ignoring the dimensions that were summed over given in ``a_axes``
and ``b_axes``)
Examples
--------
It may be helpful to consider an example to see what tensordot does.
PyTensor's implementation is identical to NumPy's. Here
a
has shape (2, 3, 4)
and
b
has shape (5, 6, 4, 3). The axes to sum over are [[1, 2], [3, 2]] --
PyTensor's implementation is identical to NumPy's. Here
``a``
has shape (2, 3, 4)
and
``b``
has shape (5, 6, 4, 3). The axes to sum over are [[1, 2], [3, 2]] --
note that a.shape[1] == b.shape[3] and a.shape[2] == b.shape[2]; these axes
are compatible. The resulting tensor will have shape (2, 5, 6) -- the
dimensions that are not being summed:
...
...
@@ -2347,10 +2345,9 @@ def tensordot(a, b, axes=2):
true
This specific implementation avoids a loop by transposing a and b such that
the summed axes of a are last and the summed axes of b are first. The
resulting arrays are reshaped to 2 dimensions (or left as vectors, if
appropriate) and a matrix or vector dot product is taken. The result is
reshaped back to the required output dimensions.
the summed axes of ``a`` are last and the summed axes of ``b`` are first. The
resulting arrays are reshaped to 2 dimensions and a matrix dot product is taken.
The result is reshaped back to the required output dimensions.
In an extreme case, no axes may be specified. The resulting tensor
will have shape equal to the concatenation of the shapes of a and b:
...
...
@@ -2366,7 +2363,85 @@ def tensordot(a, b, axes=2):
See the documentation of numpy.tensordot for more examples.
"""
return
_tensordot_as_dot
(
a
,
b
,
axes
,
dot
=
dot
,
batched
=
False
)
try
:
iter
(
axes
)
except
Exception
:
axes_a
=
list
(
range
(
-
axes
,
0
))
axes_b
=
list
(
range
(
0
,
axes
))
else
:
axes_a
,
axes_b
=
axes
try
:
na
=
len
(
axes_a
)
axes_a
=
list
(
axes_a
)
except
TypeError
:
axes_a
=
[
axes_a
]
na
=
1
try
:
nb
=
len
(
axes_b
)
axes_b
=
list
(
axes_b
)
except
TypeError
:
axes_b
=
[
axes_b
]
nb
=
1
a
=
as_tensor_variable
(
a
)
b
=
as_tensor_variable
(
b
)
runtime_shape_a
=
a
.
shape
bcast_a
=
a
.
broadcastable
static_shape_a
=
a
.
type
.
shape
ndim_a
=
a
.
ndim
runtime_shape_b
=
b
.
shape
bcast_b
=
b
.
broadcastable
static_shape_b
=
b
.
type
.
shape
ndim_b
=
b
.
ndim
if
na
!=
nb
:
raise
ValueError
(
"The number of axes supplied for tensordot must be equal for each tensor. "
f
"Got {na} and {nb} respectively."
)
axes_a
=
list
(
normalize_axis_tuple
(
axes_a
,
ndim_a
))
axes_b
=
list
(
normalize_axis_tuple
(
axes_b
,
ndim_b
))
must_assert_runtime
=
False
for
k
in
range
(
na
):
ax_a
=
axes_a
[
k
]
ax_b
=
axes_b
[
k
]
if
(
bcast_a
[
ax_a
]
!=
bcast_b
[
ax_b
])
or
(
static_shape_a
[
ax_a
]
is
not
None
and
static_shape_b
[
ax_b
]
is
not
None
and
static_shape_a
[
ax_a
]
!=
static_shape_b
[
ax_b
]
):
raise
ValueError
(
"Input arrays have inconsistent broadcastable pattern or type shape along the axes "
"that are to be reduced with tensordot."
)
elif
static_shape_a
[
ax_a
]
is
None
or
static_shape_b
[
ax_b
]
is
None
:
if
must_assert_runtime
:
a
=
Assert
(
"Input array shape along reduced axes of tensordot are not equal"
)(
a
,
eq
(
a
.
shape
[
ax_a
],
b
.
shape
[
ax_b
]))
must_assert_runtime
=
True
# Move the axes to sum over to the end of "a"
# and to the front of "b"
notin
=
[
k
for
k
in
range
(
ndim_a
)
if
k
not
in
axes_a
]
newaxes_a
=
notin
+
axes_a
N2
=
1
for
axis
in
axes_a
:
N2
*=
runtime_shape_a
[
axis
]
newshape_a
=
(
-
1
,
N2
)
olda
=
[
runtime_shape_a
[
axis
]
for
axis
in
notin
]
notin
=
[
k
for
k
in
range
(
ndim_b
)
if
k
not
in
axes_b
]
newaxes_b
=
axes_b
+
notin
N2
=
1
for
axis
in
axes_b
:
N2
*=
runtime_shape_b
[
axis
]
newshape_b
=
(
N2
,
-
1
)
oldb
=
[
runtime_shape_b
[
axis
]
for
axis
in
notin
]
at
=
a
.
transpose
(
newaxes_a
)
.
reshape
(
newshape_a
)
bt
=
b
.
transpose
(
newaxes_b
)
.
reshape
(
newshape_b
)
res
=
_dot
(
at
,
bt
)
return
res
.
reshape
(
olda
+
oldb
)
def
outer
(
x
,
y
):
...
...
tests/tensor/test_math.py
浏览文件 @
e3fb4985
...
...
@@ -18,18 +18,20 @@ from pytensor.compile.mode import get_default_mode
from
pytensor.compile.sharedvalue
import
shared
from
pytensor.configdefaults
import
config
from
pytensor.gradient
import
NullTypeGradError
,
grad
,
numeric_grad
from
pytensor.graph.basic
import
Variable
,
applys_between
from
pytensor.graph.basic
import
Variable
,
a
ncestors
,
a
pplys_between
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.link.c.basic
import
DualLinker
from
pytensor.misc.safe_asarray
import
_asarray
from
pytensor.printing
import
pprint
from
pytensor.raise_op
import
Assert
from
pytensor.tensor
import
blas
,
blas_c
from
pytensor.tensor.basic
import
(
as_tensor_variable
,
constant
,
eye
,
get_underlying_scalar_constant_value
,
ones
,
switch
,
)
from
pytensor.tensor.blas
import
Dot22
...
...
@@ -2208,6 +2210,96 @@ class TestTensordot:
zv
=
f
(
xv
,
yv
)
assert
np
.
allclose
(
np
.
tensordot
(
xv
,
yv
,
axes
=
axes
),
zv
)
def
test_type_shape
(
self
):
x
=
ones
(
shape
=
(
7
,
3
,
2
))
y
=
ones
(
shape
=
(
10
,
2
,
)
)
xv
=
x
.
eval
()
yv
=
y
.
eval
()
sy
=
tensor
(
"sy"
,
shape
=
(
None
,
2
))
axes
=
[[
-
1
],
[
-
1
]]
z
=
tensordot
(
x
,
y
,
axes
=
axes
)
sz
=
tensordot
(
x
,
sy
,
axes
=
axes
)
assert
(
len
(
{
node
for
node
in
ancestors
([
z
])
if
node
.
owner
and
isinstance
(
node
.
owner
.
op
,
Assert
)
}
)
==
0
)
assert
z
.
type
.
shape
==
(
7
,
3
,
10
)
assert
z
.
broadcastable
==
(
False
,
False
,
False
)
assert
np
.
allclose
(
np
.
tensordot
(
xv
,
yv
,
axes
=
axes
),
z
.
eval
())
assert
(
len
(
{
node
for
node
in
ancestors
([
sz
])
if
node
.
owner
and
isinstance
(
node
.
owner
.
op
,
Assert
)
}
)
==
0
)
assert
sz
.
type
.
shape
==
(
7
,
3
,
None
)
assert
z
.
broadcastable
==
(
False
,
False
,
False
)
assert
np
.
allclose
(
np
.
tensordot
(
xv
,
yv
,
axes
=
axes
),
sz
.
eval
({
sy
:
yv
}))
with
pytest
.
raises
(
ValueError
,
match
=
"Input arrays have inconsistent broadcastable pattern or type shape"
,
):
tensordot
(
ones
(
shape
=
(
7
,
4
)),
ones
(
shape
=
(
7
,
4
)),
axes
=
1
)
@pytest.mark.parametrize
(
[
"axes"
,
"has_assert"
,
"values"
,
"expected_fail"
],
[
([[
1
],
[
2
]],
False
,
(
np
.
ones
((
7
,
3
,
2
)),
np
.
ones
((
7
,
2
,
3
))),
False
),
([[
0
,
2
],
[
0
,
1
]],
True
,
(
np
.
ones
((
7
,
3
,
2
)),
np
.
ones
((
7
,
2
,
3
))),
False
),
([[
0
],
[
0
]],
False
,
(
np
.
ones
((
7
,
3
,
1
)),
np
.
ones
((
100
,
1
,
3
))),
True
),
([[
1
,
2
],
[
1
,
2
]],
True
,
(
np
.
ones
((
7
,
3
,
2
)),
np
.
ones
((
7
,
2
,
3
))),
True
),
],
)
def
test_shape_assert
(
self
,
axes
,
has_assert
,
values
,
expected_fail
):
x
=
tensor
(
shape
=
(
7
,
3
,
None
))
y
=
tensor
(
shape
=
(
None
,
None
,
3
))
xv
,
yv
=
values
xv
=
xv
.
astype
(
x
.
dtype
)
yv
=
yv
.
astype
(
x
.
dtype
)
z
=
tensordot
(
x
,
y
,
axes
=
axes
)
found_asserts
=
{
node
for
node
in
ancestors
([
z
])
if
node
.
owner
and
isinstance
(
node
.
owner
.
op
,
Assert
)
}
if
has_assert
:
assert
found_asserts
else
:
assert
not
found_asserts
if
expected_fail
:
if
has_assert
:
with
pytest
.
raises
(
AssertionError
,
match
=
"Input array shape along reduced axes of tensordot are not equal"
,
):
z
.
eval
({
x
:
xv
,
y
:
yv
})
else
:
with
pytest
.
raises
(
ValueError
):
z
.
eval
({
x
:
xv
,
y
:
yv
})
else
:
assert
np
.
allclose
(
np
.
tensordot
(
xv
,
yv
,
axes
=
axes
),
z
.
eval
({
x
:
xv
,
y
:
yv
}))
def
test_smallest
():
x
=
dvector
()
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论