Commit d607e23c authored by Ricardo Vieira, committed by Ricardo Vieira

Do not generate C code for BatchedDot when BLAS flags are missing

Parent 3f960ded
......@@ -1795,6 +1795,10 @@ class BatchedDot(COp):
return ldflags(libs=False, include_dir=True)
def c_code(self, node, name, inp, out, sub):
# Can only compile if linked to blas libraries
if len(self.c_libraries()) <= 0:
raise NotImplementedError()
_x, _y = inp
(_z,) = out
fail = sub["fail"]
......
......@@ -23,6 +23,7 @@ from pytensor.misc.safe_asarray import _asarray
from pytensor.tensor import inplace
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.blas import (
BatchedDot,
Dot22,
Dot22Scalar,
Gemm,
......@@ -2700,6 +2701,30 @@ def test_batched_dot_not_contiguous():
check_first_dim(inverted)
def test_batched_dot_blas_flags():
    """Test that BatchedDot works regardless of presence of BLAS flags."""
    compile_mode = "FAST_RUN"
    gen = np.random.default_rng(2708)

    lhs = tensor("x", shape=(2, 5, 3))
    rhs = tensor("y", shape=(2, 3, 1))
    result = batched_dot(lhs, rhs)
    assert isinstance(result.owner.op, BatchedDot)

    lhs_val = gen.normal(size=lhs.type.shape).astype(lhs.type.dtype)
    rhs_val = gen.normal(size=rhs.type.shape).astype(rhs.type.dtype)
    expected = lhs_val @ rhs_val

    def compile_and_check(expect_c_thunk):
        # Compile the graph, check whether the single BatchedDot thunk is
        # a C thunk, and verify the numerical result against plain matmul.
        compiled = function([lhs, rhs], result, mode=compile_mode)
        (thunk,) = compiled.vm.thunks
        assert hasattr(thunk, "cthunk") == expect_c_thunk
        np.testing.assert_allclose(compiled(lhs_val, rhs_val), expected)

    # With BLAS flags present, the C implementation should be used...
    compile_and_check(expect_c_thunk=True)
    # ...and without them, no C thunk may be generated, but the Op must
    # still produce correct results.
    with config.change_flags(blas__ldflags=""):
        compile_and_check(expect_c_thunk=False)
def test_batched_tensordot():
rng = np.random.default_rng(unittest_tools.fetch_seed())
first = tensor4("first")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论