提交 1704f22a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Misc. refactoring in aesara.tensor.basic

上级 6cca25e3
......@@ -6,7 +6,6 @@ manipulation of tensors.
"""
import builtins
import logging
import warnings
from collections.abc import Sequence
from functools import partial
......@@ -65,11 +64,6 @@ from aesara.tensor.type import (
from aesara.tensor.var import TensorConstant, TensorVariable, get_unique_value
_logger = logging.getLogger("aesara.tensor.basic")
__docformat__ = "restructuredtext en"
def __oplist_tag(thing, tag):
tags = getattr(thing, "__oplist_tags", [])
tags.append(tag)
......@@ -900,24 +894,14 @@ def cast(x, dtype: Union[str, np.dtype]) -> TensorVariable:
return _cast_mapping[dtype_name](x)
##########################
# Condition
##########################
@scalar_elemwise
def switch(cond, ift, iff):
    """Element-wise conditional: where `cond` is true take `ift`, else `iff`.

    The body is a stub; the `scalar_elemwise` decorator builds the actual
    element-wise Op from this signature (NOTE(review): presumed from the
    decorator name — confirm against `scalar_elemwise`'s definition).
    """


# NumPy-style alias: `where(cond, ift, iff)` is the same Op as `switch`.
where = switch
##########################
# Misc
##########################
# fill, _fill_inplace = _elemwise(aes.second, 'fill',
# """fill WRITEME (elemwise)""")
@scalar_elemwise
def second(a, b):
    """Create a tensor with the shape of `a`, filled with the values of `b`.

    The body is a stub; the `scalar_elemwise` decorator builds the actual
    element-wise Op from this signature (NOTE(review): presumed from the
    decorator name — confirm against `scalar_elemwise`'s definition).
    """
......@@ -1899,11 +1883,6 @@ class Default(Op):
default = Default()
##########################
# View Operations
##########################
def extract_constant(x, elemwise=True, only_process_constants=False):
"""
This function is basically a call to tensor.get_scalar_constant_value.
......@@ -2783,16 +2762,11 @@ def stack(*tensors, **kwargs):
# And DebugMode can't detect error in this code as it is not in an
# optimization.
# See ticket #660
if np.all(
[ # in case there is direct int in tensors.
isinstance(t, (np.number, float, int, builtins.complex))
or (
isinstance(t, Variable)
and isinstance(t.type, TensorType)
and t.ndim == 0
)
for t in tensors
]
if all(
# In case there are explicit ints in tensors
isinstance(t, (np.number, float, int, builtins.complex))
or (isinstance(t, Variable) and isinstance(t.type, TensorType) and t.ndim == 0)
for t in tensors
):
# in case there is direct int
tensors = list(map(as_tensor_variable, tensors))
......@@ -4036,7 +4010,7 @@ def swapaxes(y, axis1, axis2):
return y.dimshuffle(li)
def choose(a, choices, out=None, mode="raise"):
def choose(a, choices, mode="raise"):
"""
Construct an array from an index array and a set of arrays to choose from.
......@@ -4080,9 +4054,6 @@ def choose(a, choices, out=None, mode="raise"):
the same shape. If choices is itself an array (not recommended),
then its outermost dimension (i.e., the one corresponding to
choices.shape[0]) is taken as defining the ``sequence``.
out : array, optional
If provided, the result will be inserted into this array.
It should be of the appropriate shape and dtype.
mode : {``raise`` (default), ``wrap``, ``clip``}, optional
Specifies how indices outside [0, n-1] will be treated:
``raise`` : an exception is raised
......@@ -4100,8 +4071,6 @@ def choose(a, choices, out=None, mode="raise"):
If a and each choice array are not all broadcastable to the same shape.
"""
# This is done to keep the same function signature then NumPy.
assert out is None
return Choose(mode)(a, choices)
......@@ -4365,9 +4334,9 @@ def expand_dims(
def _make_along_axis_idx(arr_shape, indices, axis):
"""Take from `numpy.lib.shape_base`."""
# compute dimensions to iterate over
if str(indices.dtype) not in int_dtypes:
raise IndexError("`indices` must be an integer array")
shape_ones = (1,) * indices.ndim
dest_dims = list(range(axis)) + [None] + list(range(axis + 1, indices.ndim))
......
......@@ -788,13 +788,13 @@ class _tensor_py_operators:
"""Fill inputted tensor with the assigned value."""
return at.basic.fill(self, value)
def choose(self, choices, out=None, mode="raise"):
def choose(self, choices, mode="raise"):
"""
Construct an array from an index array and a set of arrays to choose
from.
"""
return at.basic.choose(self, choices, out=None, mode="raise")
return at.basic.choose(self, choices, mode="raise")
def squeeze(self):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论