Commit f10a6036 authored by ricardoV94, committed by Ricardo Vieira

Use direct imports in blas.py

Parent: efd9f491
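The change is mechanical: drop the `ptb` module alias and import the two names `blas.py` actually uses. A minimal before/after sketch of the call-site style, with names taken directly from the diff below:

```python
# Before: import the module under an alias, access names as attributes
from pytensor.tensor import basic as ptb
y = ptb.as_tensor_variable(y)

# After: import only the names that are used, call them directly
from pytensor.tensor.basic import as_tensor_variable, cast
y = as_tensor_variable(y)
```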
@@ -103,7 +103,7 @@ from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
 from pytensor.printing import FunctionPrinter, pprint
 from pytensor.scalar import bool as bool_t
-from pytensor.tensor import basic as ptb
+from pytensor.tensor.basic import as_tensor_variable, cast
 from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
 from pytensor.tensor.math import dot, tensordot
 from pytensor.tensor.shape import specify_broadcastable
@@ -157,11 +157,11 @@ class Gemv(Op):
         return f"{self.__class__.__name__}{{no_inplace}}"
 
     def make_node(self, y, alpha, A, x, beta):
-        y = ptb.as_tensor_variable(y)
-        x = ptb.as_tensor_variable(x)
-        A = ptb.as_tensor_variable(A)
-        alpha = ptb.as_tensor_variable(alpha)
-        beta = ptb.as_tensor_variable(beta)
+        y = as_tensor_variable(y)
+        x = as_tensor_variable(x)
+        A = as_tensor_variable(A)
+        alpha = as_tensor_variable(alpha)
+        beta = as_tensor_variable(beta)
         if y.dtype != A.dtype or y.dtype != x.dtype:
             raise TypeError(
                 "Gemv requires matching dtypes", (y.dtype, A.dtype, x.dtype)
@@ -257,10 +257,10 @@ class Ger(Op):
         return f"{self.__class__.__name__}{{non-destructive}}"
 
     def make_node(self, A, alpha, x, y):
-        A = ptb.as_tensor_variable(A)
-        y = ptb.as_tensor_variable(y)
-        x = ptb.as_tensor_variable(x)
-        alpha = ptb.as_tensor_variable(alpha)
+        A = as_tensor_variable(A)
+        y = as_tensor_variable(y)
+        x = as_tensor_variable(x)
+        alpha = as_tensor_variable(alpha)
         if not (A.dtype == x.dtype == y.dtype == alpha.dtype):
             raise TypeError(
                 "ger requires matching dtypes", (A.dtype, alpha.dtype, x.dtype, y.dtype)
@@ -859,7 +859,7 @@ class Gemm(GemmRelated):
         return rval
 
     def make_node(self, *inputs):
-        inputs = list(map(ptb.as_tensor_variable, inputs))
+        inputs = list(map(as_tensor_variable, inputs))
         if any(not isinstance(i.type, DenseTensorType) for i in inputs):
             raise NotImplementedError("Only dense tensor types are supported")
@@ -1129,8 +1129,8 @@ class Dot22(GemmRelated):
     check_input = False
 
     def make_node(self, x, y):
-        x = ptb.as_tensor_variable(x)
-        y = ptb.as_tensor_variable(y)
+        x = as_tensor_variable(x)
+        y = as_tensor_variable(y)
         if any(not isinstance(i.type, DenseTensorType) for i in (x, y)):
             raise NotImplementedError("Only dense tensor types are supported")
@@ -1322,8 +1322,8 @@ class BatchedDot(COp):
     gufunc_signature = "(b,m,k),(b,k,n)->(b,m,n)"
 
     def make_node(self, x, y):
-        x = ptb.as_tensor_variable(x)
-        y = ptb.as_tensor_variable(y)
+        x = as_tensor_variable(x)
+        y = as_tensor_variable(y)
         if not (
             isinstance(x.type, DenseTensorType) and isinstance(y.type, DenseTensorType)
         ):
@@ -1357,7 +1357,7 @@ class BatchedDot(COp):
         # Change dtype if needed
         dtype = pytensor.scalar.upcast(x.type.dtype, y.type.dtype)
-        x, y = ptb.cast(x, dtype), ptb.cast(y, dtype)
+        x, y = cast(x, dtype), cast(y, dtype)
         out = tensor(dtype=dtype, shape=out_shape)
         return Apply(self, [x, y], [out])
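The `upcast`/`cast` pair in this hunk promotes mixed-precision inputs to a common dtype before the `Apply` node is built. A minimal sketch of that mechanism (variable names are illustrative, not from the diff):

```python
import pytensor.tensor as pt
from pytensor.scalar import upcast
from pytensor.tensor.basic import cast

x = pt.matrix("x", dtype="float32")
y = pt.matrix("y", dtype="float64")

dtype = upcast(x.type.dtype, y.type.dtype)  # "float64"
x, y = cast(x, dtype), cast(y, dtype)       # both operands now share a dtype
```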
@@ -1738,7 +1738,7 @@ def batched_dot(a, b):
         "Use `dot` in conjunction with `tensor.vectorize` or `graph.replace.vectorize_graph`",
         FutureWarning,
     )
-    a, b = ptb.as_tensor_variable(a), ptb.as_tensor_variable(b)
+    a, b = as_tensor_variable(a), as_tensor_variable(b)
     if a.ndim == 0:
         raise TypeError("a must have at least one (batch) axis")