Commit 8e5e8a40 authored by Ricardo Vieira, committed by Ricardo Vieira

Only do reshapes in `tensordot` when needed

Parent 65b96c1c
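For context, `tensordot(a, b, axes)` is computed by transposing the summed axes to the inner edge of each operand, raveling each operand into a matrix, and taking a single dot product; the reshapes this commit makes conditional are exactly those ravels. A minimal NumPy sketch of the equivalence (shapes and axes chosen only for illustration):

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.normal(size=(5, 2, 3))  # sum over the trailing (2, 3) axes
    b = rng.normal(size=(2, 3, 4))  # sum over the leading (2, 3) axes

    # The summed axes are already innermost, so no transpose is needed here,
    # only the ravel of each summed block into a single dimension:
    at = a.reshape(5, 2 * 3)  # (5, 6)
    bt = b.reshape(2 * 3, 4)  # (6, 4)
    assert np.allclose(np.tensordot(a, b, axes=[[1, 2], [0, 1]]), at @ bt)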
@@ -2158,13 +2158,11 @@ def tensordot(
     a = as_tensor_variable(a)
     b = as_tensor_variable(b)
     runtime_shape_a = a.shape
-    bcast_a = a.broadcastable
     static_shape_a = a.type.shape
-    ndim_a = a.ndim
+    ndim_a = a.type.ndim
     runtime_shape_b = b.shape
-    bcast_b = b.broadcastable
     static_shape_b = b.type.shape
-    ndim_b = b.ndim
+    ndim_b = b.type.ndim
     if na != nb:
         raise ValueError(
             "The number of axes supplied for tensordot must be equal for each tensor. "
@@ -2172,48 +2170,67 @@ def tensordot(
         )
     axes_a = list(normalize_axis_tuple(axes_a, ndim_a))
     axes_b = list(normalize_axis_tuple(axes_b, ndim_b))
+
+    # The operation is only valid if the original dimensions match in length
+    # The ravelling of the dimensions to coerce the operation into a single dot
+    # could mask such errors, so we add an Assert if needed.
     must_assert_runtime = False
-    for k in range(na):
-        ax_a = axes_a[k]
-        ax_b = axes_b[k]
-        if (bcast_a[ax_a] != bcast_b[ax_b]) or (
+    for ax_a, ax_b in zip(axes_a, axes_b, strict=True):
+        if (
             static_shape_a[ax_a] is not None
             and static_shape_b[ax_b] is not None
             and static_shape_a[ax_a] != static_shape_b[ax_b]
         ):
             raise ValueError(
-                "Input arrays have inconsistent broadcastable pattern or type shape along the axes "
+                "Input arrays have inconsistent type shape along the axes "
                 "that are to be reduced with tensordot."
             )
         elif static_shape_a[ax_a] is None or static_shape_b[ax_b] is None:
             if must_assert_runtime:
                 a = Assert(
                     "Input array shape along reduced axes of tensordot are not equal"
-                )(a, eq(a.shape[ax_a], b.shape[ax_b]))
+                )(a, eq(runtime_shape_a[ax_a], runtime_shape_b[ax_b]))
             must_assert_runtime = True
 
-    # Move the axes to sum over to the end of "a"
-    # and to the front of "b"
-    notin = [k for k in range(ndim_a) if k not in axes_a]
-    newaxes_a = notin + axes_a
-    N2 = 1
-    for axis in axes_a:
-        N2 *= runtime_shape_a[axis]
-    newshape_a = (-1, N2)
-    olda = [runtime_shape_a[axis] for axis in notin]
-
-    notin = [k for k in range(ndim_b) if k not in axes_b]
-    newaxes_b = axes_b + notin
-    N2 = 1
-    for axis in axes_b:
-        N2 *= runtime_shape_b[axis]
-    newshape_b = (N2, -1)
-    oldb = [runtime_shape_b[axis] for axis in notin]
-
-    at = a.transpose(newaxes_a).reshape(newshape_a)
-    bt = b.transpose(newaxes_b).reshape(newshape_b)
-    res = _dot(at, bt)
-    return res.reshape(olda + oldb)
+    # Convert tensordot into a stacked dot product.
+    # We stack the summed axes and the non-summed axes of each tensor separately,
+    # and place the summed axes at the end of a and the beginning of b
+    non_summed_axes_a = [k for k in range(ndim_a) if k not in axes_a]
+    non_summed_dims_a = [runtime_shape_a[axis] for axis in non_summed_axes_a]
+    transpose_axes_a = non_summed_axes_a + axes_a
+    # We only need a reshape when we need to combine summed or non-summed dims
+    # or introduce a new dimension (expand_dims) when doing a non-scalar outer product (len(axes) = 0)
+    a_needs_reshape = (ndim_a != 0) and (
+        (len(non_summed_axes_a) > 1) or (len(axes_a) != 1)
+    )
+
+    non_summed_axes_b = [k for k in range(ndim_b) if k not in axes_b]
+    non_summed_dims_b = [runtime_shape_b[axis] for axis in non_summed_axes_b]
+    transpose_axes_b = axes_b + non_summed_axes_b
+    b_needs_reshape = (ndim_b != 0) and (
+        (len(non_summed_axes_b) > 1) or (len(axes_b) != 1)
+    )
+
+    # summed_size_a and summed_size_b must be the same,
+    # but to facilitate reasoning about useless reshapes we compute both from their shapes
+    at = a.transpose(transpose_axes_a)
+    if a_needs_reshape:
+        non_summed_size_a = variadic_mul(*non_summed_dims_a)
+        summed_size_a = variadic_mul(*[runtime_shape_a[axis] for axis in axes_a])
+        at = at.reshape((non_summed_size_a, summed_size_a))
+
+    bt = b.transpose(transpose_axes_b)
+    if b_needs_reshape:
+        non_summed_size_b = variadic_mul(*non_summed_dims_b)
+        summed_size_b = variadic_mul(*[runtime_shape_b[axis] for axis in axes_b])
+        bt = bt.reshape((summed_size_b, non_summed_size_b))
+
+    res = dot(at, bt)
+    if a_needs_reshape or b_needs_reshape:
+        res = res.reshape(non_summed_dims_a + non_summed_dims_b)
+
+    return res
 
 
 def outer(x, y):
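The runtime `Assert` above guards against the failure mode the new comment alludes to: once the summed axes are raveled into one dimension, mismatched blocks can collapse to the same length, and the inner dot succeeds silently. A small NumPy illustration of the masking:

    import numpy as np

    a = np.ones((5, 2, 3))  # summed block (2, 3)
    b = np.ones((3, 2, 4))  # summed block (3, 2): pairwise incompatible

    # Both blocks ravel to length 6, so the raw matrix product "works"
    # even though no valid tensordot pairing exists for these axes:
    _ = a.reshape(5, 6) @ b.reshape(6, 4)  # no error raised

    # tensordot itself rejects it, since axis 1 of a (2) != axis 0 of b (3):
    # np.tensordot(a, b, axes=[[1, 2], [0, 1]])  # ValueError

`variadic_mul` is a small helper used to build the raveled sizes; assuming the semantics its usage here implies (an empty product is the constant 1, a single factor passes through unchanged), it behaves roughly like this sketch:

    from pytensor.tensor import constant, mul

    def variadic_mul(*args):
        # no factors: the empty product, a constant 1
        if not args:
            return constant(1)
        # a single factor needs no mul node in the graph
        if len(args) == 1:
            return args[0]
        return mul(*args)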
@@ -19,7 +19,7 @@ from pytensor.compile.mode import get_default_mode
 from pytensor.compile.sharedvalue import shared
 from pytensor.configdefaults import config
 from pytensor.gradient import NullTypeGradError, grad, numeric_grad
-from pytensor.graph.basic import Variable, ancestors, applys_between
+from pytensor.graph.basic import Variable, ancestors, applys_between, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.replace import vectorize_node
 from pytensor.link.c.basic import DualLinker
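`equal_computations` structurally compares two lists of outputs and reports whether they define the same graph, which is what lets the new test assert on graph shape rather than on numerical results. A minimal usage sketch:

    import pytensor.tensor as pt
    from pytensor.graph.basic import equal_computations

    x = pt.vector("x")
    assert equal_computations([x + x], [x + x])
    assert not equal_computations([x + x], [x * x])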
@@ -2278,7 +2278,7 @@ class TestTensordot:
         with pytest.raises(
             ValueError,
-            match="Input arrays have inconsistent broadcastable pattern or type shape",
+            match="Input arrays have inconsistent type shape",
         ):
             tensordot(ones(shape=(7, 4)), ones(shape=(7, 4)), axes=1)
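Because both operands in this test have fully static shapes, the mismatch is caught eagerly while the graph is built, with no runtime Assert involved. For example (assuming the conventional `pytensor.tensor as pt` import):

    import pytensor.tensor as pt

    # Both operands have static shape (7, 4); axes=1 pairs a's last axis (4)
    # with b's first axis (7), so the ValueError is raised immediately:
    pt.tensordot(pt.ones((7, 4)), pt.ones((7, 4)), axes=1)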
@@ -2323,6 +2323,41 @@ class TestTensordot:
         else:
             assert np.allclose(np.tensordot(xv, yv, axes=axes), z.eval({x: xv, y: yv}))
 
+    def test_eager_simplification(self):
+        # Test that cases where tensordot isn't needed, it returns a simple graph
+        scl = tensor(shape=())
+        vec = tensor(shape=(None,))
+        mat = tensor(shape=(None, None))
+
+        # scalar product
+        out = tensordot(scl, scl, axes=[[], []])
+        assert equal_computations([out], [scl * scl])
+
+        # vector-vector product
+        out = tensordot(vec, vec, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(vec, vec)])
+
+        # matrix-vector product
+        out = tensordot(mat, vec, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(mat, vec)])
+
+        out = tensordot(mat, vec, axes=[[-2], [-1]])
+        assert equal_computations([out], [dot(mat.T, vec)])
+
+        # vector-matrix product
+        out = tensordot(vec, mat, axes=[[-1], [-2]])
+        assert equal_computations([out], [dot(vec, mat)])
+
+        out = tensordot(vec, mat, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(vec, mat.T)])
+
+        # matrix-matrix product
+        out = tensordot(mat, mat, axes=[[-1], [-2]])
+        assert equal_computations([out], [dot(mat, mat)])
+
+        out = tensordot(mat, mat, axes=[[-1], [-1]])
+        assert equal_computations([out], [dot(mat, mat.T)])
+
 
 def test_smallest():
     x = dvector()
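One quick way to see the effect of this commit is to print the graph for a case that previously went through transpose and reshape; per the new test, it should now build a bare `dot`. A sketch, assuming the standard `pytensor.dprint` debug printer:

    import pytensor
    import pytensor.tensor as pt

    mat = pt.tensor("mat", shape=(None, None))
    vec = pt.tensor("vec", shape=(None,))
    out = pt.tensordot(mat, vec, axes=[[-1], [-1]])
    pytensor.dprint(out)  # expected: a single dot node, no Reshape/DimShuffle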