提交 458312ee authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix docstrings in tests.tensor.utils

上级 9089d1df
......@@ -110,11 +110,12 @@ def eval_outputs(outputs, ops=(), mode=None):
def get_numeric_subclasses(cls=np.number, ignore=None):
# Return subclasses of `cls` in the numpy scalar hierarchy.
#
# We only return subclasses that correspond to unique data types.
# The hierarchy can be seen here:
# http://docs.scipy.org/doc/numpy/reference/arrays.scalars.html
"""Return subclasses of `cls` in the numpy scalar hierarchy.
We only return subclasses that correspond to unique data types. The
hierarchy can be seen here:
http://docs.scipy.org/doc/numpy/reference/arrays.scalars.html
"""
if ignore is None:
ignore = []
rval = []
......@@ -133,26 +134,32 @@ def get_numeric_subclasses(cls=np.number, ignore=None):
def get_numeric_types(
with_int=True, with_float=True, with_complex=False, only_aesara_types=True
):
# Return numpy numeric data types.
#
# :param with_int: Whether to include integer types.
#
# :param with_float: Whether to include floating point types.
#
# :param with_complex: Whether to include complex types.
#
# :param only_aesara_types: If True, then numpy numeric data types that are
# not supported by Aesara are ignored (i.e. those that are not declared in
# scalar/basic.py).
#
# :returns: A list of unique data type objects. Note that multiple data types
# may share the same string representation, but can be differentiated through
# their `num` attribute.
#
# Note that when `only_aesara_types` is True we could simply return the list
# of types defined in the `scalar` module. However with this function we can
# test more unique dtype objects, and in the future we may use it to
# automatically detect new data types introduced in numpy.
"""Return NumPy numeric data types.
Parameters
----------
with_int
Whether to include integer types.
with_float
Whether to include floating point types.
with_complex
Whether to include complex types.
only_aesara_types
If ``True``, then numpy numeric data types that are not supported by
Aesara are ignored (i.e. those that are not declared in
``scalar/basic.py``).
Returns
-------
A list of unique data type objects. Note that multiple data types may share
the same string representation, but can be differentiated through their
`num` attribute.
Note that when `only_aesara_types` is True we could simply return the list
of types defined in the `scalar` module. However with this function we can
test more unique dtype objects, and in the future we may use it to
automatically detect new data types introduced in numpy.
"""
if only_aesara_types:
aesara_types = [d.dtype for d in aesara.scalar.all_types]
rval = []
......@@ -186,17 +193,17 @@ def get_numeric_types(
def _numpy_checker(x, y):
# Checks if x.data and y.data have the same contents.
# Used in DualLinker to compare C version with Python version.
"""Checks if `x.data` and `y.data` have the same contents.
Used in `DualLinker` to compare C version with Python version.
"""
x, y = x[0], y[0]
if x.dtype != y.dtype or x.shape != y.shape or np.any(np.abs(x - y) > 1e-10):
raise Exception("Output mismatch.", {"performlinker": x, "clinker": y})
def safe_make_node(op, *inputs):
# Emulate the behaviour of make_node when op is a function.
#
# Normally op in an instead of the Op class.
"""Emulate the behaviour of `Op.make_node` when `op` is a function."""
node = op(*inputs)
if isinstance(node, list):
return node[0].owner
......@@ -205,15 +212,23 @@ def safe_make_node(op, *inputs):
def upcast_float16_ufunc(fn):
# Decorator that enforces computation is not done in float16 by NumPy.
#
# Some ufuncs in NumPy will compute float values on int8 and uint8
# in half-precision (float16), which is not enough, and not compatible
# with the C code.
#
# :param fn: numpy ufunc
# :returns: function similar to fn.__call__, computing the same
# value with a minimum floating-point precision of float32
"""Decorator that enforces computation is not done in float16 by NumPy.
Some ufuncs in NumPy will compute float values on int8 and uint8
in half-precision (float16), which is not enough, and not compatible
with the C code.
Parameters
----------
fn
A NumPy ufunc.
Returns
-------
A function similar to `fn.__call__`, computing the same value with a minimum
floating-point precision of float32
"""
def ret(*args, **kwargs):
out_dtype = np.find_common_type([a.dtype for a in args], [np.float16])
if out_dtype == "float16":
......@@ -226,14 +241,22 @@ def upcast_float16_ufunc(fn):
def upcast_int8_nfunc(fn):
# Decorator that upcasts input of dtype int8 to float32.
#
# This is so that floating-point computation is not carried using
# half-precision (float16), as some NumPy functions do.
#
# :param fn: function computing a floating-point value from inputs
# :returns: function similar to fn, but upcasting its uint8 and int8
# inputs before carrying out the computation.
"""Decorator that upcasts input of dtype int8 to float32.
This is so that floating-point computation is not carried using
half-precision (float16), as some NumPy functions do.
Parameters
----------
fn
A function computing a floating-point value from inputs.
Returns
-------
A function similar to fn, but upcasting its uint8 and int8 inputs before
carrying out the computation.
"""
def ret(*args, **kwargs):
args = list(args)
for i, a in enumerate(args):
......@@ -332,15 +355,21 @@ def random_of_dtype(shape, dtype, rng=None):
def check_floatX(inputs, rval):
# :param inputs: Inputs to a function that returned `rval` with these inputs.
#
# :param rval: Value returned by a function with inputs set to `inputs`.
#
# :returns: Either `rval` unchanged, or `rval` cast in float32. The idea is
# that when a numpy function would have returned a float64, Aesara may prefer
# to return a float32 instead when `config.cast_policy` is set to
# 'numpy+floatX' and config.floatX to 'float32', and there was no float64
# input.
"""
Parameters
----------
inputs
Inputs to a function that returned `rval` with these inputs.
rval
Value returned by a function with inputs set to `inputs`.
Returns
-------
Either `rval` unchanged, or `rval` cast in float32. The idea is that when a
numpy function would have returned a float64, Aesara may prefer to return a
float32 instead when `config.cast_policy` is set to ``'numpy+floatX'`` and
`config.floatX` to ``'float32'``, and there was no float64 input.
"""
if (
isinstance(rval, np.ndarray)
and rval.dtype == "float64"
......@@ -355,10 +384,11 @@ def check_floatX(inputs, rval):
def _numpy_true_div(x, y):
# Performs true division, and cast the result in the type we expect.
#
# We define that function so we can use it in TrueDivTester.expected,
# because simply calling np.true_divide could cause a dtype mismatch.
"""Performs true division, and cast the result in the type we expect.
We define that function so we can use it in `TrueDivTester.expected`,
because simply calling np.true_divide could cause a dtype mismatch.
"""
out = np.true_divide(x, y)
# Use floatX as the result of int / int
if x.dtype in discrete_dtypes and y.dtype in discrete_dtypes:
......@@ -367,8 +397,7 @@ def _numpy_true_div(x, y):
def copymod(dct, without=None, **kwargs):
# Return dct but with the keys named by args removed, and with
# kwargs added.
"""Return `dct` but with the keys named by `without` removed, and with `kwargs` added."""
if without is None:
without = []
rval = copy(dct)
......@@ -397,8 +426,12 @@ def makeTester(
check_name=False,
grad_eps=None,
):
# :param check_name:
# Use only for tester that aren't in Aesara.
"""
Parameters
----------
check_name
Use only for testers that aren't in Aesara.
"""
if checks is None:
checks = {}
if good is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论