提交 f10a6036 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Use direct imports in blas.py

上级 efd9f491
...@@ -103,7 +103,7 @@ from pytensor.link.c.op import COp ...@@ -103,7 +103,7 @@ from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.printing import FunctionPrinter, pprint from pytensor.printing import FunctionPrinter, pprint
from pytensor.scalar import bool as bool_t from pytensor.scalar import bool as bool_t
from pytensor.tensor import basic as ptb from pytensor.tensor.basic import as_tensor_variable, cast
from pytensor.tensor.blas_headers import blas_header_text, blas_header_version from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
from pytensor.tensor.math import dot, tensordot from pytensor.tensor.math import dot, tensordot
from pytensor.tensor.shape import specify_broadcastable from pytensor.tensor.shape import specify_broadcastable
...@@ -157,11 +157,11 @@ class Gemv(Op): ...@@ -157,11 +157,11 @@ class Gemv(Op):
return f"{self.__class__.__name__}{{no_inplace}}" return f"{self.__class__.__name__}{{no_inplace}}"
def make_node(self, y, alpha, A, x, beta): def make_node(self, y, alpha, A, x, beta):
y = ptb.as_tensor_variable(y) y = as_tensor_variable(y)
x = ptb.as_tensor_variable(x) x = as_tensor_variable(x)
A = ptb.as_tensor_variable(A) A = as_tensor_variable(A)
alpha = ptb.as_tensor_variable(alpha) alpha = as_tensor_variable(alpha)
beta = ptb.as_tensor_variable(beta) beta = as_tensor_variable(beta)
if y.dtype != A.dtype or y.dtype != x.dtype: if y.dtype != A.dtype or y.dtype != x.dtype:
raise TypeError( raise TypeError(
"Gemv requires matching dtypes", (y.dtype, A.dtype, x.dtype) "Gemv requires matching dtypes", (y.dtype, A.dtype, x.dtype)
...@@ -257,10 +257,10 @@ class Ger(Op): ...@@ -257,10 +257,10 @@ class Ger(Op):
return f"{self.__class__.__name__}{{non-destructive}}" return f"{self.__class__.__name__}{{non-destructive}}"
def make_node(self, A, alpha, x, y): def make_node(self, A, alpha, x, y):
A = ptb.as_tensor_variable(A) A = as_tensor_variable(A)
y = ptb.as_tensor_variable(y) y = as_tensor_variable(y)
x = ptb.as_tensor_variable(x) x = as_tensor_variable(x)
alpha = ptb.as_tensor_variable(alpha) alpha = as_tensor_variable(alpha)
if not (A.dtype == x.dtype == y.dtype == alpha.dtype): if not (A.dtype == x.dtype == y.dtype == alpha.dtype):
raise TypeError( raise TypeError(
"ger requires matching dtypes", (A.dtype, alpha.dtype, x.dtype, y.dtype) "ger requires matching dtypes", (A.dtype, alpha.dtype, x.dtype, y.dtype)
...@@ -859,7 +859,7 @@ class Gemm(GemmRelated): ...@@ -859,7 +859,7 @@ class Gemm(GemmRelated):
return rval return rval
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = list(map(ptb.as_tensor_variable, inputs)) inputs = list(map(as_tensor_variable, inputs))
if any(not isinstance(i.type, DenseTensorType) for i in inputs): if any(not isinstance(i.type, DenseTensorType) for i in inputs):
raise NotImplementedError("Only dense tensor types are supported") raise NotImplementedError("Only dense tensor types are supported")
...@@ -1129,8 +1129,8 @@ class Dot22(GemmRelated): ...@@ -1129,8 +1129,8 @@ class Dot22(GemmRelated):
check_input = False check_input = False
def make_node(self, x, y): def make_node(self, x, y):
x = ptb.as_tensor_variable(x) x = as_tensor_variable(x)
y = ptb.as_tensor_variable(y) y = as_tensor_variable(y)
if any(not isinstance(i.type, DenseTensorType) for i in (x, y)): if any(not isinstance(i.type, DenseTensorType) for i in (x, y)):
raise NotImplementedError("Only dense tensor types are supported") raise NotImplementedError("Only dense tensor types are supported")
...@@ -1322,8 +1322,8 @@ class BatchedDot(COp): ...@@ -1322,8 +1322,8 @@ class BatchedDot(COp):
gufunc_signature = "(b,m,k),(b,k,n)->(b,m,n)" gufunc_signature = "(b,m,k),(b,k,n)->(b,m,n)"
def make_node(self, x, y): def make_node(self, x, y):
x = ptb.as_tensor_variable(x) x = as_tensor_variable(x)
y = ptb.as_tensor_variable(y) y = as_tensor_variable(y)
if not ( if not (
isinstance(x.type, DenseTensorType) and isinstance(y.type, DenseTensorType) isinstance(x.type, DenseTensorType) and isinstance(y.type, DenseTensorType)
...@@ -1357,7 +1357,7 @@ class BatchedDot(COp): ...@@ -1357,7 +1357,7 @@ class BatchedDot(COp):
# Change dtype if needed # Change dtype if needed
dtype = pytensor.scalar.upcast(x.type.dtype, y.type.dtype) dtype = pytensor.scalar.upcast(x.type.dtype, y.type.dtype)
x, y = ptb.cast(x, dtype), ptb.cast(y, dtype) x, y = cast(x, dtype), cast(y, dtype)
out = tensor(dtype=dtype, shape=out_shape) out = tensor(dtype=dtype, shape=out_shape)
return Apply(self, [x, y], [out]) return Apply(self, [x, y], [out])
...@@ -1738,7 +1738,7 @@ def batched_dot(a, b): ...@@ -1738,7 +1738,7 @@ def batched_dot(a, b):
"Use `dot` in conjution with `tensor.vectorize` or `graph.replace.vectorize_graph`", "Use `dot` in conjution with `tensor.vectorize` or `graph.replace.vectorize_graph`",
FutureWarning, FutureWarning,
) )
a, b = ptb.as_tensor_variable(a), ptb.as_tensor_variable(b) a, b = as_tensor_variable(a), as_tensor_variable(b)
if a.ndim == 0: if a.ndim == 0:
raise TypeError("a must have at least one (batch) axis") raise TypeError("a must have at least one (batch) axis")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论