Commit 90f300ae authored by Brandon T. Willard, committed by Brandon T. Willard

Fix uses of as_tensor_variable in aesara.tensor.subtensor

Parent 51de50be
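The change applies a single pattern throughout aesara.tensor.subtensor: rather than resolving as_tensor_variable through the aesara.tensor namespace at every call site, the function is imported once at module level and called directly. The sketch below is illustrative only, assuming a hypothetical convert helper rather than the module's real make_node methods; it just shows the import pattern being swapped:

# Illustrative sketch only: "convert" is a hypothetical helper, not part of
# the aesara API; it exists just to show the import pattern being changed.

# Before: resolve the function through the package namespace at call time.
import aesara

def convert_old(x):
    return aesara.tensor.as_tensor_variable(x)

# After: import once at module level and call directly.
from aesara.tensor import as_tensor_variable

def convert_new(x):
    return as_tensor_variable(x)

Either form produces the same tensor variable; the direct import simply removes the repeated attribute lookups through the package namespace at call time.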
@@ -18,7 +18,7 @@ from aesara.graph.utils import MethodNotDefined
 from aesara.misc.safe_asarray import _asarray
 from aesara.printing import Printer, pprint, set_precedence
 from aesara.scalar.basic import ScalarConstant
-from aesara.tensor import _get_vector_length, get_vector_length
+from aesara.tensor import _get_vector_length, as_tensor_variable, get_vector_length
 from aesara.tensor.basic import addbroadcast, alloc, get_scalar_constant_value
 from aesara.tensor.elemwise import DimShuffle
 from aesara.tensor.exceptions import (
@@ -129,7 +129,7 @@ def as_index_constant(
     elif isinstance(a, (int, np.integer)):
         return aes.ScalarConstant(aes.int64, a)
     elif not isinstance(a, Variable):
-        return aesara.tensor.as_tensor(a)
+        return as_tensor_variable(a)
     else:
         return a
@@ -687,7 +687,7 @@ class Subtensor(COp):
             A list of aesara Scalars.
         """
-        x = aesara.tensor.as_tensor_variable(x)
+        x = as_tensor_variable(x)
         inputs = tuple(as_nontensor_scalar(a) for a in inputs)
         idx_list = list(self.idx_list)
@@ -1291,8 +1291,8 @@ def inc_subtensor(
     # First of all, y cannot have a higher dimension than x,
     # nor have non-broadcastable dimensions where x is broadcastable.
-    x = aesara.tensor.as_tensor_variable(x)
-    y = aesara.tensor.as_tensor_variable(y)
+    x = as_tensor_variable(x)
+    y = as_tensor_variable(y)
     if y.ndim > x.ndim:
         raise TypeError(
@@ -1477,7 +1477,7 @@ class IncSubtensor(COp):
         inputs: TODO WRITEME
         """
-        x, y = map(aesara.tensor.as_tensor_variable, [x, y])
+        x, y = map(as_tensor_variable, [x, y])
         if y.ndim > x.ndim:
             raise ValueError(
                 f"Trying to increment a {int(x.ndim)}-dimensional "
@@ -1897,8 +1897,8 @@ class AdvancedSubtensor1(COp):
         self.sparse_grad = sparse_grad
     def make_node(self, x, ilist):
-        x_ = aesara.tensor.as_tensor_variable(x)
-        ilist_ = aesara.tensor.as_tensor_variable(ilist)
+        x_ = as_tensor_variable(x)
+        ilist_ = as_tensor_variable(ilist)
         if ilist_.type.dtype not in integer_dtypes:
             raise TypeError("index must be integers")
         if ilist_.type.ndim != 1:
@@ -2117,9 +2117,9 @@ class AdvancedIncSubtensor1(COp):
         return self.__class__.__name__ + "{%s}" % msg
     def make_node(self, x, y, ilist):
-        x_ = aesara.tensor.as_tensor_variable(x)
-        y_ = aesara.tensor.as_tensor_variable(y)
-        ilist_ = aesara.tensor.as_tensor_variable(ilist)
+        x_ = as_tensor_variable(x)
+        y_ = as_tensor_variable(y)
+        ilist_ = as_tensor_variable(ilist)
         if ilist_.type.dtype not in integer_dtypes:
             raise TypeError("index must be integers")
@@ -2470,7 +2470,7 @@ def as_index_variable(idx):
         return idx
     if isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT):
         return idx
-    idx = aesara.tensor.as_tensor_variable(idx)
+    idx = as_tensor_variable(idx)
     if idx.type.dtype not in discrete_dtypes:
         raise TypeError("index must be integers or a boolean mask")
     return idx
@@ -2509,7 +2509,7 @@ class AdvancedSubtensor(Op):
     __props__ = ()
     def make_node(self, x, *index):
-        x = aesara.tensor.as_tensor_variable(x)
+        x = as_tensor_variable(x)
         index = tuple(map(as_index_variable, index))
         # We only want the broadcast information, and we don't need recursive
@@ -2629,13 +2629,13 @@ class AdvancedIncSubtensor(Op):
         )
     def make_node(self, x, y, *inputs):
-        x = aesara.tensor.as_tensor_variable(x)
-        y = aesara.tensor.as_tensor_variable(y)
+        x = as_tensor_variable(x)
+        y = as_tensor_variable(y)
         new_inputs = []
         for inp in inputs:
             if isinstance(inp, (list, tuple)):
-                inp = aesara.tensor.as_tensor_variable(inp)
+                inp = as_tensor_variable(inp)
             new_inputs.append(inp)
         return Apply(
             self,
@@ -2726,8 +2726,8 @@ def take(a, indices, axis=None, mode="raise"):
         input array is used.
     """
-    a = aesara.tensor.as_tensor_variable(a)
-    indices = aesara.tensor.as_tensor_variable(indices)
+    a = as_tensor_variable(a)
+    indices = as_tensor_variable(indices)
     if not isinstance(axis, (int, type(None))):
         raise TypeError("`axis` must be an integer or None")