提交 6e6e7b16 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Simplify as_tensor_variable imports in aesara.tensor.elemwise

上级 a41da3ee
...@@ -20,7 +20,7 @@ from aesara.scalar.basic import Scalar ...@@ -20,7 +20,7 @@ from aesara.scalar.basic import Scalar
from aesara.scalar.basic import bool as scalar_bool from aesara.scalar.basic import bool as scalar_bool
from aesara.scalar.basic import identity as scalar_identity from aesara.scalar.basic import identity as scalar_identity
from aesara.scalar.basic import transfer_type, upcast from aesara.scalar.basic import transfer_type, upcast
from aesara.tensor import _get_vector_length from aesara.tensor import _get_vector_length, as_tensor_variable
from aesara.tensor import elemwise_cgen as cgen from aesara.tensor import elemwise_cgen as cgen
from aesara.tensor import get_vector_length from aesara.tensor import get_vector_length
from aesara.tensor.type import ( from aesara.tensor.type import (
...@@ -206,7 +206,7 @@ class DimShuffle(ExternalCOp): ...@@ -206,7 +206,7 @@ class DimShuffle(ExternalCOp):
super().__init__([self.c_func_file], self.c_func_name) super().__init__([self.c_func_file], self.c_func_name)
def make_node(self, _input): def make_node(self, _input):
input = aesara.tensor.basic.as_tensor_variable(_input) input = as_tensor_variable(_input)
ib = tuple(input.type.broadcastable) ib = tuple(input.type.broadcastable)
if not ib == self.input_broadcastable: if not ib == self.input_broadcastable:
if len(ib) != len(self.input_broadcastable): if len(ib) != len(self.input_broadcastable):
...@@ -279,7 +279,6 @@ class DimShuffle(ExternalCOp): ...@@ -279,7 +279,6 @@ class DimShuffle(ExternalCOp):
return self(*eval_points, return_list=True) return self(*eval_points, return_list=True)
def grad(self, inp, grads): def grad(self, inp, grads):
from aesara.tensor.basic import as_tensor_variable
(x,) = inp (x,) = inp
(gz,) = grads (gz,) = grads
...@@ -484,7 +483,7 @@ second dimension ...@@ -484,7 +483,7 @@ second dimension
is left-completed to the greatest number of dimensions with 1s is left-completed to the greatest number of dimensions with 1s
using DimShuffle. using DimShuffle.
""" """
inputs = list(map(aesara.tensor.basic.as_tensor_variable, inputs)) inputs = [as_tensor_variable(i) for i in inputs]
out_dtypes, out_broadcastables, inputs = self.get_output_info( out_dtypes, out_broadcastables, inputs = self.get_output_info(
DimShuffle, *inputs DimShuffle, *inputs
) )
...@@ -1315,8 +1314,6 @@ class CAReduce(COp): ...@@ -1315,8 +1314,6 @@ class CAReduce(COp):
return input_dtype return input_dtype
def make_node(self, input): def make_node(self, input):
from aesara.tensor.basic import as_tensor_variable
input = as_tensor_variable(input) input = as_tensor_variable(input)
inp_dims = input.type.ndim inp_dims = input.type.ndim
inp_bdcast = input.type.broadcastable inp_bdcast = input.type.broadcastable
...@@ -1760,7 +1757,7 @@ class CAReduceDtype(CAReduce): ...@@ -1760,7 +1757,7 @@ class CAReduceDtype(CAReduce):
# We need to redefine make_node so that, if self.dtype is None, # We need to redefine make_node so that, if self.dtype is None,
# we can infer what dtype should be, and create a node from an Op # we can infer what dtype should be, and create a node from an Op
# of the appropriate dtype. # of the appropriate dtype.
input = aesara.tensor.basic.as_tensor_variable(input) input = as_tensor_variable(input)
dtype = self._output_dtype(input.dtype) dtype = self._output_dtype(input.dtype)
acc_dtype = self._acc_dtype(input.dtype) acc_dtype = self._acc_dtype(input.dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论