提交 c4a3444b authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Use explicit imports in test_einsum

上级 884dee90
......@@ -5,13 +5,14 @@ import numpy as np
import pytest
import pytensor
import pytensor.tensor as pt
from pytensor import Mode, config, function
from pytensor.graph import FunctionGraph
from pytensor.graph.op import HasInnerGraph
from pytensor.tensor.basic import moveaxis
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum
from pytensor.tensor.shape import Reshape
from pytensor.tensor.type import tensor
# Fail for unexpected warnings in this file
......@@ -80,8 +81,8 @@ def test_general_dot():
# X has two batch dims
# Y has one batch dim
x = pt.tensor("x", shape=(5, 4, 2, 11, 13, 3))
y = pt.tensor("y", shape=(4, 13, 5, 7, 11))
x = tensor("x", shape=(5, 4, 2, 11, 13, 3))
y = tensor("y", shape=(4, 13, 5, 7, 11))
out = _general_dot((x, y), tensordot_axes, [(0, 1), (0,)])
fn = pytensor.function([x, y], out)
......@@ -135,10 +136,10 @@ def test_einsum_signatures(static_shape_known, signature):
static_shapes = [[None] * len(shape) for shape in shapes]
operands = [
pt.tensor(name, shape=static_shape)
tensor(name, shape=static_shape)
for name, static_shape in zip(ascii_lowercase, static_shapes, strict=False)
]
out = pt.einsum(signature, *operands)
out = einsum(signature, *operands)
assert out.owner.op.optimized == static_shape_known or len(operands) <= 2
rng = np.random.default_rng(37)
......@@ -160,8 +161,8 @@ def test_batch_dim():
"x": (7, 3, 5),
"y": (5, 2),
}
x, y = (pt.tensor(name, shape=shape) for name, shape in shapes.items())
out = pt.einsum("mij,jk->mik", x, y)
x, y = (tensor(name, shape=shape) for name, shape in shapes.items())
out = einsum("mij,jk->mik", x, y)
assert out.type.shape == (7, 3, 2)
......@@ -195,24 +196,24 @@ def test_einsum_conv():
def test_ellipsis():
rng = np.random.default_rng(159)
x = pt.tensor("x", shape=(3, 5, 7, 11))
y = pt.tensor("y", shape=(3, 5, 11, 13))
x = tensor("x", shape=(3, 5, 7, 11))
y = tensor("y", shape=(3, 5, 11, 13))
x_test = rng.normal(size=x.type.shape).astype(floatX)
y_test = rng.normal(size=y.type.shape).astype(floatX)
expected_out = np.matmul(x_test, y_test)
with pytest.raises(ValueError):
pt.einsum("mp,pn->mn", x, y)
einsum("mp,pn->mn", x, y)
out = pt.einsum("...mp,...pn->...mn", x, y)
out = einsum("...mp,...pn->...mn", x, y)
np.testing.assert_allclose(
out.eval({x: x_test, y: y_test}), expected_out, atol=ATOL, rtol=RTOL
)
# Put batch axes in the middle
new_x = pt.moveaxis(x, -2, 0)
new_y = pt.moveaxis(y, -2, 0)
out = pt.einsum("m...p,p...n->m...n", new_x, new_y)
new_x = moveaxis(x, -2, 0)
new_y = moveaxis(y, -2, 0)
out = einsum("m...p,p...n->m...n", new_x, new_y)
np.testing.assert_allclose(
out.eval({x: x_test, y: y_test}),
expected_out.transpose(-2, 0, 1, -1),
......@@ -220,7 +221,7 @@ def test_ellipsis():
rtol=RTOL,
)
out = pt.einsum("m...p,p...n->mn", new_x, new_y)
out = einsum("m...p,p...n->mn", new_x, new_y)
np.testing.assert_allclose(
out.eval({x: x_test, y: y_test}), expected_out.sum((0, 1)), atol=ATOL, rtol=RTOL
)
......@@ -236,9 +237,9 @@ def test_broadcastable_dims():
# can lead to suboptimal paths. We check we issue a warning for the following example:
# https://github.com/dgasmith/opt_einsum/issues/220
rng = np.random.default_rng(222)
a = pt.tensor("a", shape=(32, 32, 32))
b = pt.tensor("b", shape=(1000, 32))
c = pt.tensor("c", shape=(1, 32))
a = tensor("a", shape=(32, 32, 32))
b = tensor("b", shape=(1000, 32))
c = tensor("c", shape=(1, 32))
a_test = rng.normal(size=a.type.shape).astype(floatX)
b_test = rng.normal(size=b.type.shape).astype(floatX)
......@@ -248,11 +249,11 @@ def test_broadcastable_dims():
with pytest.warns(
UserWarning, match="This can result in a suboptimal contraction path"
):
suboptimal_out = pt.einsum("ijk,bj,bk->i", a, b, c)
suboptimal_out = einsum("ijk,bj,bk->i", a, b, c)
assert not [set(p) for p in suboptimal_out.owner.op.path] == [{0, 2}, {0, 1}]
# If we use a distinct letter we get the optimal path
optimal_out = pt.einsum("ijk,bj,ck->i", a, b, c)
optimal_out = einsum("ijk,bj,ck->i", a, b, c)
assert [set(p) for p in optimal_out.owner.op.path] == [{0, 2}, {0, 1}]
suboptimal_eval = suboptimal_out.eval({a: a_test, b: b_test, c: c_test})
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论