提交 6e6e7b16 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Simplify as_tensor_variable imports in aesara.tensor.elemwise

上级 a41da3ee
......@@ -20,7 +20,7 @@ from aesara.scalar.basic import Scalar
from aesara.scalar.basic import bool as scalar_bool
from aesara.scalar.basic import identity as scalar_identity
from aesara.scalar.basic import transfer_type, upcast
from aesara.tensor import _get_vector_length
from aesara.tensor import _get_vector_length, as_tensor_variable
from aesara.tensor import elemwise_cgen as cgen
from aesara.tensor import get_vector_length
from aesara.tensor.type import (
......@@ -206,7 +206,7 @@ class DimShuffle(ExternalCOp):
super().__init__([self.c_func_file], self.c_func_name)
def make_node(self, _input):
input = aesara.tensor.basic.as_tensor_variable(_input)
input = as_tensor_variable(_input)
ib = tuple(input.type.broadcastable)
if not ib == self.input_broadcastable:
if len(ib) != len(self.input_broadcastable):
......@@ -279,7 +279,6 @@ class DimShuffle(ExternalCOp):
return self(*eval_points, return_list=True)
def grad(self, inp, grads):
from aesara.tensor.basic import as_tensor_variable
(x,) = inp
(gz,) = grads
......@@ -484,7 +483,7 @@ second dimension
is left-completed to the greatest number of dimensions with 1s
using DimShuffle.
"""
inputs = list(map(aesara.tensor.basic.as_tensor_variable, inputs))
inputs = [as_tensor_variable(i) for i in inputs]
out_dtypes, out_broadcastables, inputs = self.get_output_info(
DimShuffle, *inputs
)
......@@ -1315,8 +1314,6 @@ class CAReduce(COp):
return input_dtype
def make_node(self, input):
from aesara.tensor.basic import as_tensor_variable
input = as_tensor_variable(input)
inp_dims = input.type.ndim
inp_bdcast = input.type.broadcastable
......@@ -1760,7 +1757,7 @@ class CAReduceDtype(CAReduce):
# We need to redefine make_node so that, if self.dtype is None,
# we can infer what dtype should be, and create a node from an Op
# of the appropriate dtype.
input = aesara.tensor.basic.as_tensor_variable(input)
input = as_tensor_variable(input)
dtype = self._output_dtype(input.dtype)
acc_dtype = self._acc_dtype(input.dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论