提交 1704f22a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Misc. refactoring in aesara.tensor.basic

上级 6cca25e3
...@@ -6,7 +6,6 @@ manipulation of tensors. ...@@ -6,7 +6,6 @@ manipulation of tensors.
""" """
import builtins import builtins
import logging
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from functools import partial from functools import partial
...@@ -65,11 +64,6 @@ from aesara.tensor.type import ( ...@@ -65,11 +64,6 @@ from aesara.tensor.type import (
from aesara.tensor.var import TensorConstant, TensorVariable, get_unique_value from aesara.tensor.var import TensorConstant, TensorVariable, get_unique_value
_logger = logging.getLogger("aesara.tensor.basic")
__docformat__ = "restructuredtext en"
def __oplist_tag(thing, tag): def __oplist_tag(thing, tag):
tags = getattr(thing, "__oplist_tags", []) tags = getattr(thing, "__oplist_tags", [])
tags.append(tag) tags.append(tag)
...@@ -900,24 +894,14 @@ def cast(x, dtype: Union[str, np.dtype]) -> TensorVariable: ...@@ -900,24 +894,14 @@ def cast(x, dtype: Union[str, np.dtype]) -> TensorVariable:
return _cast_mapping[dtype_name](x) return _cast_mapping[dtype_name](x)
##########################
# Condition
##########################
@scalar_elemwise @scalar_elemwise
def switch(cond, ift, iff): def switch(cond, ift, iff):
"""if cond then ift else iff""" """if cond then ift else iff"""
where = switch where = switch
##########################
# Misc
##########################
# fill, _fill_inplace = _elemwise(aes.second, 'fill',
# """fill WRITEME (elemwise)""")
@scalar_elemwise @scalar_elemwise
def second(a, b): def second(a, b):
"""Create a matrix by filling the shape of a with b""" """Create a matrix by filling the shape of a with b"""
...@@ -1899,11 +1883,6 @@ class Default(Op): ...@@ -1899,11 +1883,6 @@ class Default(Op):
default = Default() default = Default()
##########################
# View Operations
##########################
def extract_constant(x, elemwise=True, only_process_constants=False): def extract_constant(x, elemwise=True, only_process_constants=False):
""" """
This function is basically a call to tensor.get_scalar_constant_value. This function is basically a call to tensor.get_scalar_constant_value.
...@@ -2783,16 +2762,11 @@ def stack(*tensors, **kwargs): ...@@ -2783,16 +2762,11 @@ def stack(*tensors, **kwargs):
# And DebugMode can't detect error in this code as it is not in an # And DebugMode can't detect error in this code as it is not in an
# optimization. # optimization.
# See ticket #660 # See ticket #660
if np.all( if all(
[ # in case there is direct int in tensors. # In case there are explicit ints in tensors
isinstance(t, (np.number, float, int, builtins.complex)) isinstance(t, (np.number, float, int, builtins.complex))
or ( or (isinstance(t, Variable) and isinstance(t.type, TensorType) and t.ndim == 0)
isinstance(t, Variable) for t in tensors
and isinstance(t.type, TensorType)
and t.ndim == 0
)
for t in tensors
]
): ):
# in case there is direct int # in case there is direct int
tensors = list(map(as_tensor_variable, tensors)) tensors = list(map(as_tensor_variable, tensors))
...@@ -4036,7 +4010,7 @@ def swapaxes(y, axis1, axis2): ...@@ -4036,7 +4010,7 @@ def swapaxes(y, axis1, axis2):
return y.dimshuffle(li) return y.dimshuffle(li)
def choose(a, choices, out=None, mode="raise"): def choose(a, choices, mode="raise"):
""" """
Construct an array from an index array and a set of arrays to choose from. Construct an array from an index array and a set of arrays to choose from.
...@@ -4080,9 +4054,6 @@ def choose(a, choices, out=None, mode="raise"): ...@@ -4080,9 +4054,6 @@ def choose(a, choices, out=None, mode="raise"):
the same shape. If choices is itself an array (not recommended), the same shape. If choices is itself an array (not recommended),
then its outermost dimension (i.e., the one corresponding to then its outermost dimension (i.e., the one corresponding to
choices.shape[0]) is taken as defining the ``sequence``. choices.shape[0]) is taken as defining the ``sequence``.
out : array, optional
If provided, the result will be inserted into this array.
It should be of the appropriate shape and dtype.
mode : {``raise`` (default), ``wrap``, ``clip``}, optional mode : {``raise`` (default), ``wrap``, ``clip``}, optional
Specifies how indices outside [0, n-1] will be treated: Specifies how indices outside [0, n-1] will be treated:
``raise`` : an exception is raised ``raise`` : an exception is raised
...@@ -4100,8 +4071,6 @@ def choose(a, choices, out=None, mode="raise"): ...@@ -4100,8 +4071,6 @@ def choose(a, choices, out=None, mode="raise"):
If a and each choice array are not all broadcastable to the same shape. If a and each choice array are not all broadcastable to the same shape.
""" """
# This is done to keep the same function signature then NumPy.
assert out is None
return Choose(mode)(a, choices) return Choose(mode)(a, choices)
...@@ -4365,9 +4334,9 @@ def expand_dims( ...@@ -4365,9 +4334,9 @@ def expand_dims(
def _make_along_axis_idx(arr_shape, indices, axis): def _make_along_axis_idx(arr_shape, indices, axis):
"""Take from `numpy.lib.shape_base`.""" """Take from `numpy.lib.shape_base`."""
# compute dimensions to iterate over
if str(indices.dtype) not in int_dtypes: if str(indices.dtype) not in int_dtypes:
raise IndexError("`indices` must be an integer array") raise IndexError("`indices` must be an integer array")
shape_ones = (1,) * indices.ndim shape_ones = (1,) * indices.ndim
dest_dims = list(range(axis)) + [None] + list(range(axis + 1, indices.ndim)) dest_dims = list(range(axis)) + [None] + list(range(axis + 1, indices.ndim))
......
...@@ -788,13 +788,13 @@ class _tensor_py_operators: ...@@ -788,13 +788,13 @@ class _tensor_py_operators:
"""Fill inputted tensor with the assigned value.""" """Fill inputted tensor with the assigned value."""
return at.basic.fill(self, value) return at.basic.fill(self, value)
def choose(self, choices, out=None, mode="raise"): def choose(self, choices, mode="raise"):
""" """
Construct an array from an index array and a set of arrays to choose Construct an array from an index array and a set of arrays to choose
from. from.
""" """
return at.basic.choose(self, choices, out=None, mode="raise") return at.basic.choose(self, choices, mode="raise")
def squeeze(self): def squeeze(self):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论