提交 458312ee authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix docstrings in tests.tensor.utils

上级 9089d1df
...@@ -110,11 +110,12 @@ def eval_outputs(outputs, ops=(), mode=None): ...@@ -110,11 +110,12 @@ def eval_outputs(outputs, ops=(), mode=None):
def get_numeric_subclasses(cls=np.number, ignore=None): def get_numeric_subclasses(cls=np.number, ignore=None):
# Return subclasses of `cls` in the numpy scalar hierarchy. """Return subclasses of `cls` in the numpy scalar hierarchy.
#
# We only return subclasses that correspond to unique data types. We only return subclasses that correspond to unique data types. The
# The hierarchy can be seen here: hierarchy can be seen here:
# http://docs.scipy.org/doc/numpy/reference/arrays.scalars.html http://docs.scipy.org/doc/numpy/reference/arrays.scalars.html
"""
if ignore is None: if ignore is None:
ignore = [] ignore = []
rval = [] rval = []
...@@ -133,26 +134,32 @@ def get_numeric_subclasses(cls=np.number, ignore=None): ...@@ -133,26 +134,32 @@ def get_numeric_subclasses(cls=np.number, ignore=None):
def get_numeric_types( def get_numeric_types(
with_int=True, with_float=True, with_complex=False, only_aesara_types=True with_int=True, with_float=True, with_complex=False, only_aesara_types=True
): ):
# Return numpy numeric data types. """Return NumPy numeric data types.
#
# :param with_int: Whether to include integer types. Parameters
# ----------
# :param with_float: Whether to include floating point types. with_int
# Whether to include integer types.
# :param with_complex: Whether to include complex types. with_float
# Whether to include floating point types.
# :param only_aesara_types: If True, then numpy numeric data types that are with_complex
# not supported by Aesara are ignored (i.e. those that are not declared in Whether to include complex types.
# scalar/basic.py). only_aesara_types
# If ``True``, then numpy numeric data types that are not supported by
# :returns: A list of unique data type objects. Note that multiple data types Aesara are ignored (i.e. those that are not declared in
# may share the same string representation, but can be differentiated through ``scalar/basic.py``).
# their `num` attribute.
# Returns
# Note that when `only_aesara_types` is True we could simply return the list -------
# of types defined in the `scalar` module. However with this function we can A list of unique data type objects. Note that multiple data types may share
# test more unique dtype objects, and in the future we may use it to the same string representation, but can be differentiated through their
# automatically detect new data types introduced in numpy. `num` attribute.
Note that when `only_aesara_types` is True we could simply return the list
of types defined in the `scalar` module. However with this function we can
test more unique dtype objects, and in the future we may use it to
automatically detect new data types introduced in numpy.
"""
if only_aesara_types: if only_aesara_types:
aesara_types = [d.dtype for d in aesara.scalar.all_types] aesara_types = [d.dtype for d in aesara.scalar.all_types]
rval = [] rval = []
...@@ -186,17 +193,17 @@ def get_numeric_types( ...@@ -186,17 +193,17 @@ def get_numeric_types(
def _numpy_checker(x, y): def _numpy_checker(x, y):
# Checks if x.data and y.data have the same contents. """Checks if `x.data` and `y.data` have the same contents.
# Used in DualLinker to compare C version with Python version.
Used in `DualLinker` to compare C version with Python version.
"""
x, y = x[0], y[0] x, y = x[0], y[0]
if x.dtype != y.dtype or x.shape != y.shape or np.any(np.abs(x - y) > 1e-10): if x.dtype != y.dtype or x.shape != y.shape or np.any(np.abs(x - y) > 1e-10):
raise Exception("Output mismatch.", {"performlinker": x, "clinker": y}) raise Exception("Output mismatch.", {"performlinker": x, "clinker": y})
def safe_make_node(op, *inputs): def safe_make_node(op, *inputs):
# Emulate the behaviour of make_node when op is a function. """Emulate the behaviour of `Op.make_node` when `op` is a function."""
#
# Normally op in an instead of the Op class.
node = op(*inputs) node = op(*inputs)
if isinstance(node, list): if isinstance(node, list):
return node[0].owner return node[0].owner
...@@ -205,15 +212,23 @@ def safe_make_node(op, *inputs): ...@@ -205,15 +212,23 @@ def safe_make_node(op, *inputs):
def upcast_float16_ufunc(fn): def upcast_float16_ufunc(fn):
# Decorator that enforces computation is not done in float16 by NumPy. """Decorator that enforces computation is not done in float16 by NumPy.
#
# Some ufuncs in NumPy will compute float values on int8 and uint8 Some ufuncs in NumPy will compute float values on int8 and uint8
# in half-precision (float16), which is not enough, and not compatible in half-precision (float16), which is not enough, and not compatible
# with the C code. with the C code.
#
# :param fn: numpy ufunc Parameters
# :returns: function similar to fn.__call__, computing the same ----------
# value with a minimum floating-point precision of float32 fn
A NumPy ufunc.
Returns
-------
A function similar to `fn.__call__`, computing the same value with a minimum
floating-point precision of float32
"""
def ret(*args, **kwargs): def ret(*args, **kwargs):
out_dtype = np.find_common_type([a.dtype for a in args], [np.float16]) out_dtype = np.find_common_type([a.dtype for a in args], [np.float16])
if out_dtype == "float16": if out_dtype == "float16":
...@@ -226,14 +241,22 @@ def upcast_float16_ufunc(fn): ...@@ -226,14 +241,22 @@ def upcast_float16_ufunc(fn):
def upcast_int8_nfunc(fn): def upcast_int8_nfunc(fn):
# Decorator that upcasts input of dtype int8 to float32. """Decorator that upcasts input of dtype int8 to float32.
#
# This is so that floating-point computation is not carried using This is so that floating-point computation is not carried using
# half-precision (float16), as some NumPy functions do. half-precision (float16), as some NumPy functions do.
#
# :param fn: function computing a floating-point value from inputs Parameters
# :returns: function similar to fn, but upcasting its uint8 and int8 ----------
# inputs before carrying out the computation. fn
A function computing a floating-point value from inputs.
Returns
-------
A function similar to fn, but upcasting its uint8 and int8 inputs before
carrying out the computation.
"""
def ret(*args, **kwargs): def ret(*args, **kwargs):
args = list(args) args = list(args)
for i, a in enumerate(args): for i, a in enumerate(args):
...@@ -332,15 +355,21 @@ def random_of_dtype(shape, dtype, rng=None): ...@@ -332,15 +355,21 @@ def random_of_dtype(shape, dtype, rng=None):
def check_floatX(inputs, rval): def check_floatX(inputs, rval):
# :param inputs: Inputs to a function that returned `rval` with these inputs. """
# Parameters
# :param rval: Value returned by a function with inputs set to `inputs`. ----------
# inputs
# :returns: Either `rval` unchanged, or `rval` cast in float32. The idea is Inputs to a function that returned `rval` with these inputs.
# that when a numpy function would have returned a float64, Aesara may prefer rval
# to return a float32 instead when `config.cast_policy` is set to Value returned by a function with inputs set to `inputs`.
# 'numpy+floatX' and config.floatX to 'float32', and there was no float64
# input. Returns
-------
Either `rval` unchanged, or `rval` cast in float32. The idea is that when a
numpy function would have returned a float64, Aesara may prefer to return a
float32 instead when `config.cast_policy` is set to ``'numpy+floatX'`` and
`config.floatX` to ``'float32'``, and there was no float64 input.
"""
if ( if (
isinstance(rval, np.ndarray) isinstance(rval, np.ndarray)
and rval.dtype == "float64" and rval.dtype == "float64"
...@@ -355,10 +384,11 @@ def check_floatX(inputs, rval): ...@@ -355,10 +384,11 @@ def check_floatX(inputs, rval):
def _numpy_true_div(x, y): def _numpy_true_div(x, y):
# Performs true division, and cast the result in the type we expect. """Performs true division, and cast the result in the type we expect.
#
# We define that function so we can use it in TrueDivTester.expected, We define that function so we can use it in `TrueDivTester.expected`,
# because simply calling np.true_divide could cause a dtype mismatch. because simply calling np.true_divide could cause a dtype mismatch.
"""
out = np.true_divide(x, y) out = np.true_divide(x, y)
# Use floatX as the result of int / int # Use floatX as the result of int / int
if x.dtype in discrete_dtypes and y.dtype in discrete_dtypes: if x.dtype in discrete_dtypes and y.dtype in discrete_dtypes:
...@@ -367,8 +397,7 @@ def _numpy_true_div(x, y): ...@@ -367,8 +397,7 @@ def _numpy_true_div(x, y):
def copymod(dct, without=None, **kwargs): def copymod(dct, without=None, **kwargs):
# Return dct but with the keys named by args removed, and with """Return `dct` but with the keys named by `without` removed, and with `kwargs` added."""
# kwargs added.
if without is None: if without is None:
without = [] without = []
rval = copy(dct) rval = copy(dct)
...@@ -397,8 +426,12 @@ def makeTester( ...@@ -397,8 +426,12 @@ def makeTester(
check_name=False, check_name=False,
grad_eps=None, grad_eps=None,
): ):
# :param check_name: """
# Use only for tester that aren't in Aesara. Parameters
----------
check_name
Use only for testers that aren't in Aesara.
"""
if checks is None: if checks is None:
checks = {} checks = {}
if good is None: if good is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论