提交 8e5e8a40 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Only do reshapes in `tensordot` when needed

上级 65b96c1c
......@@ -2158,13 +2158,11 @@ def tensordot(
a = as_tensor_variable(a)
b = as_tensor_variable(b)
runtime_shape_a = a.shape
bcast_a = a.broadcastable
static_shape_a = a.type.shape
ndim_a = a.ndim
ndim_a = a.type.ndim
runtime_shape_b = b.shape
bcast_b = b.broadcastable
static_shape_b = b.type.shape
ndim_b = b.ndim
ndim_b = b.type.ndim
if na != nb:
raise ValueError(
"The number of axes supplied for tensordot must be equal for each tensor. "
......@@ -2172,48 +2170,67 @@ def tensordot(
)
axes_a = list(normalize_axis_tuple(axes_a, ndim_a))
axes_b = list(normalize_axis_tuple(axes_b, ndim_b))
# The operation is only valid if the original dimensions match in length
# The ravelling of the dimensions to coerce the operation into a single dot
# could mask such errors, so we add an Assert if needed.
must_assert_runtime = False
for k in range(na):
ax_a = axes_a[k]
ax_b = axes_b[k]
if (bcast_a[ax_a] != bcast_b[ax_b]) or (
for ax_a, ax_b in zip(axes_a, axes_b, strict=True):
if (
static_shape_a[ax_a] is not None
and static_shape_b[ax_b] is not None
and static_shape_a[ax_a] != static_shape_b[ax_b]
):
raise ValueError(
"Input arrays have inconsistent broadcastable pattern or type shape along the axes "
"Input arrays have inconsistent type shape along the axes "
"that are to be reduced with tensordot."
)
elif static_shape_a[ax_a] is None or static_shape_b[ax_b] is None:
if must_assert_runtime:
a = Assert(
"Input array shape along reduced axes of tensordot are not equal"
)(a, eq(a.shape[ax_a], b.shape[ax_b]))
)(a, eq(runtime_shape_a[ax_a], runtime_shape_b[ax_b]))
must_assert_runtime = True
# Move the axes to sum over to the end of "a"
# and to the front of "b"
notin = [k for k in range(ndim_a) if k not in axes_a]
newaxes_a = notin + axes_a
N2 = 1
for axis in axes_a:
N2 *= runtime_shape_a[axis]
newshape_a = (-1, N2)
olda = [runtime_shape_a[axis] for axis in notin]
notin = [k for k in range(ndim_b) if k not in axes_b]
newaxes_b = axes_b + notin
N2 = 1
for axis in axes_b:
N2 *= runtime_shape_b[axis]
newshape_b = (N2, -1)
oldb = [runtime_shape_b[axis] for axis in notin]
at = a.transpose(newaxes_a).reshape(newshape_a)
bt = b.transpose(newaxes_b).reshape(newshape_b)
res = _dot(at, bt)
return res.reshape(olda + oldb)
# Convert tensordot into a stacked dot product.
# We stack the summed axes and the non-summed axes of each tensor separately,
# and place the summed axes at the end of a and the beginning of b
non_summed_axes_a = [k for k in range(ndim_a) if k not in axes_a]
non_summed_dims_a = [runtime_shape_a[axis] for axis in non_summed_axes_a]
transpose_axes_a = non_summed_axes_a + axes_a
# We only need a reshape when we need to combine summed or non-summed dims
# or introduce a new dimension (expand_dims) when doing a non-scalar outer product (len(axes) = 0)
a_needs_reshape = (ndim_a != 0) and (
(len(non_summed_axes_a) > 1) or (len(axes_a) != 1)
)
non_summed_axes_b = [k for k in range(ndim_b) if k not in axes_b]
non_summed_dims_b = [runtime_shape_b[axis] for axis in non_summed_axes_b]
transpose_axes_b = axes_b + non_summed_axes_b
b_needs_reshape = (ndim_b != 0) and (
(len(non_summed_axes_b) > 1) or (len(axes_b) != 1)
)
# summed_size_a and summed_size_b must be the same,
# but to facilitate reasoning about useless reshapes we compute both from their shapes
at = a.transpose(transpose_axes_a)
if a_needs_reshape:
non_summed_size_a = variadic_mul(*non_summed_dims_a)
summed_size_a = variadic_mul(*[runtime_shape_a[axis] for axis in axes_a])
at = at.reshape((non_summed_size_a, summed_size_a))
bt = b.transpose(transpose_axes_b)
if b_needs_reshape:
non_summed_size_b = variadic_mul(*non_summed_dims_b)
summed_size_b = variadic_mul(*[runtime_shape_b[axis] for axis in axes_b])
bt = bt.reshape((summed_size_b, non_summed_size_b))
res = dot(at, bt)
if a_needs_reshape or b_needs_reshape:
res = res.reshape(non_summed_dims_a + non_summed_dims_b)
return res
def outer(x, y):
......
......@@ -19,7 +19,7 @@ from pytensor.compile.mode import get_default_mode
from pytensor.compile.sharedvalue import shared
from pytensor.configdefaults import config
from pytensor.gradient import NullTypeGradError, grad, numeric_grad
from pytensor.graph.basic import Variable, ancestors, applys_between
from pytensor.graph.basic import Variable, ancestors, applys_between, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import vectorize_node
from pytensor.link.c.basic import DualLinker
......@@ -2278,7 +2278,7 @@ class TestTensordot:
with pytest.raises(
ValueError,
match="Input arrays have inconsistent broadcastable pattern or type shape",
match="Input arrays have inconsistent type shape",
):
tensordot(ones(shape=(7, 4)), ones(shape=(7, 4)), axes=1)
......@@ -2323,6 +2323,41 @@ class TestTensordot:
else:
assert np.allclose(np.tensordot(xv, yv, axes=axes), z.eval({x: xv, y: yv}))
def test_eager_simplification(self):
    """Check that `tensordot` lowers eagerly to a plain graph when possible.

    When no axis stacking/reshaping is required, the result should be
    exactly an elementwise multiply (scalar case) or a single `dot`,
    possibly on transposed inputs — with no extra reshape nodes.
    """
    scalar_in = tensor(shape=())
    vector_in = tensor(shape=(None,))
    matrix_in = tensor(shape=(None, None))

    # Contracting over no axes of two scalars is just elementwise multiplication
    assert equal_computations(
        [tensordot(scalar_in, scalar_in, axes=[[], []])],
        [scalar_in * scalar_in],
    )

    # Vector-vector contraction reduces to an inner product
    assert equal_computations(
        [tensordot(vector_in, vector_in, axes=[[-1], [-1]])],
        [dot(vector_in, vector_in)],
    )

    # Matrix-vector contraction: plain dot, or dot on the transposed matrix
    assert equal_computations(
        [tensordot(matrix_in, vector_in, axes=[[-1], [-1]])],
        [dot(matrix_in, vector_in)],
    )
    assert equal_computations(
        [tensordot(matrix_in, vector_in, axes=[[-2], [-1]])],
        [dot(matrix_in.T, vector_in)],
    )

    # Vector-matrix contraction: plain dot, or dot on the transposed matrix
    assert equal_computations(
        [tensordot(vector_in, matrix_in, axes=[[-1], [-2]])],
        [dot(vector_in, matrix_in)],
    )
    assert equal_computations(
        [tensordot(vector_in, matrix_in, axes=[[-1], [-1]])],
        [dot(vector_in, matrix_in.T)],
    )

    # Matrix-matrix contraction: plain matmul, or matmul with a transposed rhs
    assert equal_computations(
        [tensordot(matrix_in, matrix_in, axes=[[-1], [-2]])],
        [dot(matrix_in, matrix_in)],
    )
    assert equal_computations(
        [tensordot(matrix_in, matrix_in, axes=[[-1], [-1]])],
        [dot(matrix_in, matrix_in.T)],
    )
def test_smallest():
x = dvector()
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论