提交 90f300ae authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix uses of as_tensor_variable in aesara.tensor.subtensor

上级 51de50be
...@@ -18,7 +18,7 @@ from aesara.graph.utils import MethodNotDefined ...@@ -18,7 +18,7 @@ from aesara.graph.utils import MethodNotDefined
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.printing import Printer, pprint, set_precedence from aesara.printing import Printer, pprint, set_precedence
from aesara.scalar.basic import ScalarConstant from aesara.scalar.basic import ScalarConstant
from aesara.tensor import _get_vector_length, get_vector_length from aesara.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from aesara.tensor.basic import addbroadcast, alloc, get_scalar_constant_value from aesara.tensor.basic import addbroadcast, alloc, get_scalar_constant_value
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.exceptions import ( from aesara.tensor.exceptions import (
...@@ -129,7 +129,7 @@ def as_index_constant( ...@@ -129,7 +129,7 @@ def as_index_constant(
elif isinstance(a, (int, np.integer)): elif isinstance(a, (int, np.integer)):
return aes.ScalarConstant(aes.int64, a) return aes.ScalarConstant(aes.int64, a)
elif not isinstance(a, Variable): elif not isinstance(a, Variable):
return aesara.tensor.as_tensor(a) return as_tensor_variable(a)
else: else:
return a return a
...@@ -687,7 +687,7 @@ class Subtensor(COp): ...@@ -687,7 +687,7 @@ class Subtensor(COp):
A list of aesara Scalars. A list of aesara Scalars.
""" """
x = aesara.tensor.as_tensor_variable(x) x = as_tensor_variable(x)
inputs = tuple(as_nontensor_scalar(a) for a in inputs) inputs = tuple(as_nontensor_scalar(a) for a in inputs)
idx_list = list(self.idx_list) idx_list = list(self.idx_list)
...@@ -1291,8 +1291,8 @@ def inc_subtensor( ...@@ -1291,8 +1291,8 @@ def inc_subtensor(
# First of all, y cannot have a higher dimension than x, # First of all, y cannot have a higher dimension than x,
# nor have non-broadcastable dimensions where x is broadcastable. # nor have non-broadcastable dimensions where x is broadcastable.
x = aesara.tensor.as_tensor_variable(x) x = as_tensor_variable(x)
y = aesara.tensor.as_tensor_variable(y) y = as_tensor_variable(y)
if y.ndim > x.ndim: if y.ndim > x.ndim:
raise TypeError( raise TypeError(
...@@ -1477,7 +1477,7 @@ class IncSubtensor(COp): ...@@ -1477,7 +1477,7 @@ class IncSubtensor(COp):
inputs: TODO WRITEME inputs: TODO WRITEME
""" """
x, y = map(aesara.tensor.as_tensor_variable, [x, y]) x, y = map(as_tensor_variable, [x, y])
if y.ndim > x.ndim: if y.ndim > x.ndim:
raise ValueError( raise ValueError(
f"Trying to increment a {int(x.ndim)}-dimensional " f"Trying to increment a {int(x.ndim)}-dimensional "
...@@ -1897,8 +1897,8 @@ class AdvancedSubtensor1(COp): ...@@ -1897,8 +1897,8 @@ class AdvancedSubtensor1(COp):
self.sparse_grad = sparse_grad self.sparse_grad = sparse_grad
def make_node(self, x, ilist): def make_node(self, x, ilist):
x_ = aesara.tensor.as_tensor_variable(x) x_ = as_tensor_variable(x)
ilist_ = aesara.tensor.as_tensor_variable(ilist) ilist_ = as_tensor_variable(ilist)
if ilist_.type.dtype not in integer_dtypes: if ilist_.type.dtype not in integer_dtypes:
raise TypeError("index must be integers") raise TypeError("index must be integers")
if ilist_.type.ndim != 1: if ilist_.type.ndim != 1:
...@@ -2117,9 +2117,9 @@ class AdvancedIncSubtensor1(COp): ...@@ -2117,9 +2117,9 @@ class AdvancedIncSubtensor1(COp):
return self.__class__.__name__ + "{%s}" % msg return self.__class__.__name__ + "{%s}" % msg
def make_node(self, x, y, ilist): def make_node(self, x, y, ilist):
x_ = aesara.tensor.as_tensor_variable(x) x_ = as_tensor_variable(x)
y_ = aesara.tensor.as_tensor_variable(y) y_ = as_tensor_variable(y)
ilist_ = aesara.tensor.as_tensor_variable(ilist) ilist_ = as_tensor_variable(ilist)
if ilist_.type.dtype not in integer_dtypes: if ilist_.type.dtype not in integer_dtypes:
raise TypeError("index must be integers") raise TypeError("index must be integers")
...@@ -2470,7 +2470,7 @@ def as_index_variable(idx): ...@@ -2470,7 +2470,7 @@ def as_index_variable(idx):
return idx return idx
if isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT): if isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT):
return idx return idx
idx = aesara.tensor.as_tensor_variable(idx) idx = as_tensor_variable(idx)
if idx.type.dtype not in discrete_dtypes: if idx.type.dtype not in discrete_dtypes:
raise TypeError("index must be integers or a boolean mask") raise TypeError("index must be integers or a boolean mask")
return idx return idx
...@@ -2509,7 +2509,7 @@ class AdvancedSubtensor(Op): ...@@ -2509,7 +2509,7 @@ class AdvancedSubtensor(Op):
__props__ = () __props__ = ()
def make_node(self, x, *index): def make_node(self, x, *index):
x = aesara.tensor.as_tensor_variable(x) x = as_tensor_variable(x)
index = tuple(map(as_index_variable, index)) index = tuple(map(as_index_variable, index))
# We only want the broadcast information, and we don't need recursive # We only want the broadcast information, and we don't need recursive
...@@ -2629,13 +2629,13 @@ class AdvancedIncSubtensor(Op): ...@@ -2629,13 +2629,13 @@ class AdvancedIncSubtensor(Op):
) )
def make_node(self, x, y, *inputs): def make_node(self, x, y, *inputs):
x = aesara.tensor.as_tensor_variable(x) x = as_tensor_variable(x)
y = aesara.tensor.as_tensor_variable(y) y = as_tensor_variable(y)
new_inputs = [] new_inputs = []
for inp in inputs: for inp in inputs:
if isinstance(inp, (list, tuple)): if isinstance(inp, (list, tuple)):
inp = aesara.tensor.as_tensor_variable(inp) inp = as_tensor_variable(inp)
new_inputs.append(inp) new_inputs.append(inp)
return Apply( return Apply(
self, self,
...@@ -2726,8 +2726,8 @@ def take(a, indices, axis=None, mode="raise"): ...@@ -2726,8 +2726,8 @@ def take(a, indices, axis=None, mode="raise"):
input array is used. input array is used.
""" """
a = aesara.tensor.as_tensor_variable(a) a = as_tensor_variable(a)
indices = aesara.tensor.as_tensor_variable(indices) indices = as_tensor_variable(indices)
if not isinstance(axis, (int, type(None))): if not isinstance(axis, (int, type(None))):
raise TypeError("`axis` must be an integer or None") raise TypeError("`axis` must be an integer or None")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论