Commit c4a3444b authored by ricardoV94, committed by Ricardo Vieira

Use explicit imports in test_einsum

Parent 884dee90
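In short: the test module stops reaching through the `pt` module alias and imports the names it uses (`tensor`, `einsum`, `moveaxis`) directly. A minimal before/after sketch of the style change, using the import paths from the diff below (the toy shapes are illustrative, not from the tests):

# Before: access through the module alias
import pytensor.tensor as pt
x = pt.tensor("x", shape=(3, 5))
out = pt.einsum("ij,jk->ik", x, x.T)

# After: explicit imports make the test's dependencies visible
from pytensor.tensor.einsum import einsum
from pytensor.tensor.type import tensor
x = tensor("x", shape=(3, 5))
out = einsum("ij,jk->ik", x, x.T)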
@@ -5,13 +5,14 @@ import numpy as np
 import pytest
 
 import pytensor
-import pytensor.tensor as pt
 from pytensor import Mode, config, function
 from pytensor.graph import FunctionGraph
 from pytensor.graph.op import HasInnerGraph
+from pytensor.tensor.basic import moveaxis
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.einsum import _delta, _general_dot, _iota, einsum
 from pytensor.tensor.shape import Reshape
+from pytensor.tensor.type import tensor
 
 
 # Fail for unexpected warnings in this file
@@ -80,8 +81,8 @@ def test_general_dot():
 
     # X has two batch dims
     # Y has one batch dim
-    x = pt.tensor("x", shape=(5, 4, 2, 11, 13, 3))
-    y = pt.tensor("y", shape=(4, 13, 5, 7, 11))
+    x = tensor("x", shape=(5, 4, 2, 11, 13, 3))
+    y = tensor("y", shape=(4, 13, 5, 7, 11))
     out = _general_dot((x, y), tensordot_axes, [(0, 1), (0,)])
     fn = pytensor.function([x, y], out)
@@ -135,10 +136,10 @@ def test_einsum_signatures(static_shape_known, signature):
         static_shapes = [[None] * len(shape) for shape in shapes]
 
     operands = [
-        pt.tensor(name, shape=static_shape)
+        tensor(name, shape=static_shape)
         for name, static_shape in zip(ascii_lowercase, static_shapes, strict=False)
     ]
-    out = pt.einsum(signature, *operands)
+    out = einsum(signature, *operands)
     assert out.owner.op.optimized == static_shape_known or len(operands) <= 2
 
     rng = np.random.default_rng(37)
@@ -160,8 +161,8 @@ def test_batch_dim():
         "x": (7, 3, 5),
         "y": (5, 2),
     }
-    x, y = (pt.tensor(name, shape=shape) for name, shape in shapes.items())
-    out = pt.einsum("mij,jk->mik", x, y)
+    x, y = (tensor(name, shape=shape) for name, shape in shapes.items())
+    out = einsum("mij,jk->mik", x, y)
     assert out.type.shape == (7, 3, 2)
@@ -195,24 +196,24 @@ def test_einsum_conv():
 
 def test_ellipsis():
     rng = np.random.default_rng(159)
-    x = pt.tensor("x", shape=(3, 5, 7, 11))
-    y = pt.tensor("y", shape=(3, 5, 11, 13))
+    x = tensor("x", shape=(3, 5, 7, 11))
+    y = tensor("y", shape=(3, 5, 11, 13))
     x_test = rng.normal(size=x.type.shape).astype(floatX)
     y_test = rng.normal(size=y.type.shape).astype(floatX)
     expected_out = np.matmul(x_test, y_test)
 
     with pytest.raises(ValueError):
-        pt.einsum("mp,pn->mn", x, y)
+        einsum("mp,pn->mn", x, y)
 
-    out = pt.einsum("...mp,...pn->...mn", x, y)
+    out = einsum("...mp,...pn->...mn", x, y)
     np.testing.assert_allclose(
         out.eval({x: x_test, y: y_test}), expected_out, atol=ATOL, rtol=RTOL
     )
 
     # Put batch axes in the middle
-    new_x = pt.moveaxis(x, -2, 0)
-    new_y = pt.moveaxis(y, -2, 0)
-    out = pt.einsum("m...p,p...n->m...n", new_x, new_y)
+    new_x = moveaxis(x, -2, 0)
+    new_y = moveaxis(y, -2, 0)
+    out = einsum("m...p,p...n->m...n", new_x, new_y)
     np.testing.assert_allclose(
         out.eval({x: x_test, y: y_test}),
         expected_out.transpose(-2, 0, 1, -1),
@@ -220,7 +221,7 @@ def test_ellipsis():
         rtol=RTOL,
     )
 
-    out = pt.einsum("m...p,p...n->mn", new_x, new_y)
+    out = einsum("m...p,p...n->mn", new_x, new_y)
     np.testing.assert_allclose(
         out.eval({x: x_test, y: y_test}), expected_out.sum((0, 1)), atol=ATOL, rtol=RTOL
     )
@@ -236,9 +237,9 @@ def test_broadcastable_dims():
     # can lead to suboptimal paths. We check we issue a warning for the following example:
    # https://github.com/dgasmith/opt_einsum/issues/220
     rng = np.random.default_rng(222)
-    a = pt.tensor("a", shape=(32, 32, 32))
-    b = pt.tensor("b", shape=(1000, 32))
-    c = pt.tensor("c", shape=(1, 32))
+    a = tensor("a", shape=(32, 32, 32))
+    b = tensor("b", shape=(1000, 32))
+    c = tensor("c", shape=(1, 32))
 
     a_test = rng.normal(size=a.type.shape).astype(floatX)
     b_test = rng.normal(size=b.type.shape).astype(floatX)
@@ -248,11 +249,11 @@ def test_broadcastable_dims():
     with pytest.warns(
         UserWarning, match="This can result in a suboptimal contraction path"
     ):
-        suboptimal_out = pt.einsum("ijk,bj,bk->i", a, b, c)
+        suboptimal_out = einsum("ijk,bj,bk->i", a, b, c)
     assert not [set(p) for p in suboptimal_out.owner.op.path] == [{0, 2}, {0, 1}]
 
     # If we use a distinct letter we get the optimal path
-    optimal_out = pt.einsum("ijk,bj,ck->i", a, b, c)
+    optimal_out = einsum("ijk,bj,ck->i", a, b, c)
     assert [set(p) for p in optimal_out.owner.op.path] == [{0, 2}, {0, 1}]
 
     suboptimal_eval = suboptimal_out.eval({a: a_test, b: b_test, c: c_test})
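For context on the last hunk: the test pins down the contraction order (`path`) chosen for the three-operand product, and `{0, 2}, {0, 1}` means "contract operands 0 and 2 first, then the result with operand 1". A quick way to inspect such orders outside PyTensor is NumPy's einsum_path; a small sketch with the same shapes and the distinct-letter signature (the printed path is indicative, not guaranteed):

import numpy as np

a = np.empty((32, 32, 32))
b = np.empty((1000, 32))
c = np.empty((1, 32))

# Ask NumPy's optimizer for a contraction order over the three operands.
path, report = np.einsum_path("ijk,bj,ck->i", a, b, c, optimize="optimal")
print(path)    # e.g. ['einsum_path', (0, 2), (0, 1)]: a with c first, then b
print(report)  # human-readable cost breakdown of the chosen path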