Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
458312ee
提交
458312ee
authored
2月 18, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
2月 22, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix docstrings in tests.tensor.utils
上级
9089d1df
显示空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
97 行增加
和
64 行删除
+97
-64
utils.py
tests/tensor/utils.py
+97
-64
没有找到文件。
tests/tensor/utils.py
浏览文件 @
458312ee
...
@@ -110,11 +110,12 @@ def eval_outputs(outputs, ops=(), mode=None):
...
@@ -110,11 +110,12 @@ def eval_outputs(outputs, ops=(), mode=None):
def
get_numeric_subclasses
(
cls
=
np
.
number
,
ignore
=
None
):
def
get_numeric_subclasses
(
cls
=
np
.
number
,
ignore
=
None
):
# Return subclasses of `cls` in the numpy scalar hierarchy.
"""Return subclasses of `cls` in the numpy scalar hierarchy.
#
# We only return subclasses that correspond to unique data types.
We only return subclasses that correspond to unique data types. The
# The hierarchy can be seen here:
hierarchy can be seen here:
# http://docs.scipy.org/doc/numpy/reference/arrays.scalars.html
http://docs.scipy.org/doc/numpy/reference/arrays.scalars.html
"""
if
ignore
is
None
:
if
ignore
is
None
:
ignore
=
[]
ignore
=
[]
rval
=
[]
rval
=
[]
...
@@ -133,26 +134,32 @@ def get_numeric_subclasses(cls=np.number, ignore=None):
...
@@ -133,26 +134,32 @@ def get_numeric_subclasses(cls=np.number, ignore=None):
def
get_numeric_types
(
def
get_numeric_types
(
with_int
=
True
,
with_float
=
True
,
with_complex
=
False
,
only_aesara_types
=
True
with_int
=
True
,
with_float
=
True
,
with_complex
=
False
,
only_aesara_types
=
True
):
):
# Return numpy numeric data types.
"""Return NumPy numeric data types.
#
# :param with_int: Whether to include integer types.
Parameters
#
----------
# :param with_float: Whether to include floating point types.
with_int
#
Whether to include integer types.
# :param with_complex: Whether to include complex types.
with_float
#
Whether to include floating point types.
# :param only_aesara_types: If True, then numpy numeric data types that are
with_complex
# not supported by Aesara are ignored (i.e. those that are not declared in
Whether to include complex types.
# scalar/basic.py).
only_aesara_types
#
If ``True``, then numpy numeric data types that are not supported by
# :returns: A list of unique data type objects. Note that multiple data types
Aesara are ignored (i.e. those that are not declared in
# may share the same string representation, but can be differentiated through
``scalar/basic.py``).
# their `num` attribute.
#
Returns
# Note that when `only_aesara_types` is True we could simply return the list
-------
# of types defined in the `scalar` module. However with this function we can
A list of unique data type objects. Note that multiple data types may share
# test more unique dtype objects, and in the future we may use it to
the same string representation, but can be differentiated through their
# automatically detect new data types introduced in numpy.
`num` attribute.
Note that when `only_aesara_types` is True we could simply return the list
of types defined in the `scalar` module. However with this function we can
test more unique dtype objects, and in the future we may use it to
automatically detect new data types introduced in numpy.
"""
if
only_aesara_types
:
if
only_aesara_types
:
aesara_types
=
[
d
.
dtype
for
d
in
aesara
.
scalar
.
all_types
]
aesara_types
=
[
d
.
dtype
for
d
in
aesara
.
scalar
.
all_types
]
rval
=
[]
rval
=
[]
...
@@ -186,17 +193,17 @@ def get_numeric_types(
...
@@ -186,17 +193,17 @@ def get_numeric_types(
def
_numpy_checker
(
x
,
y
):
def
_numpy_checker
(
x
,
y
):
# Checks if x.data and y.data have the same contents.
"""Checks if `x.data` and `y.data` have the same contents.
# Used in DualLinker to compare C version with Python version.
Used in `DualLinker` to compare C version with Python version.
"""
x
,
y
=
x
[
0
],
y
[
0
]
x
,
y
=
x
[
0
],
y
[
0
]
if
x
.
dtype
!=
y
.
dtype
or
x
.
shape
!=
y
.
shape
or
np
.
any
(
np
.
abs
(
x
-
y
)
>
1e-10
):
if
x
.
dtype
!=
y
.
dtype
or
x
.
shape
!=
y
.
shape
or
np
.
any
(
np
.
abs
(
x
-
y
)
>
1e-10
):
raise
Exception
(
"Output mismatch."
,
{
"performlinker"
:
x
,
"clinker"
:
y
})
raise
Exception
(
"Output mismatch."
,
{
"performlinker"
:
x
,
"clinker"
:
y
})
def
safe_make_node
(
op
,
*
inputs
):
def
safe_make_node
(
op
,
*
inputs
):
# Emulate the behaviour of make_node when op is a function.
"""Emulate the behaviour of `Op.make_node` when `op` is a function."""
#
# Normally op in an instead of the Op class.
node
=
op
(
*
inputs
)
node
=
op
(
*
inputs
)
if
isinstance
(
node
,
list
):
if
isinstance
(
node
,
list
):
return
node
[
0
]
.
owner
return
node
[
0
]
.
owner
...
@@ -205,15 +212,23 @@ def safe_make_node(op, *inputs):
...
@@ -205,15 +212,23 @@ def safe_make_node(op, *inputs):
def
upcast_float16_ufunc
(
fn
):
def
upcast_float16_ufunc
(
fn
):
# Decorator that enforces computation is not done in float16 by NumPy.
"""Decorator that enforces computation is not done in float16 by NumPy.
#
# Some ufuncs in NumPy will compute float values on int8 and uint8
Some ufuncs in NumPy will compute float values on int8 and uint8
# in half-precision (float16), which is not enough, and not compatible
in half-precision (float16), which is not enough, and not compatible
# with the C code.
with the C code.
#
# :param fn: numpy ufunc
Parameters
# :returns: function similar to fn.__call__, computing the same
----------
# value with a minimum floating-point precision of float32
fn
A NumPy ufunc.
Returns
-------
A function similar to `fn.__call__`, computing the same value with a minimum
floating-point precision of float32
"""
def
ret
(
*
args
,
**
kwargs
):
def
ret
(
*
args
,
**
kwargs
):
out_dtype
=
np
.
find_common_type
([
a
.
dtype
for
a
in
args
],
[
np
.
float16
])
out_dtype
=
np
.
find_common_type
([
a
.
dtype
for
a
in
args
],
[
np
.
float16
])
if
out_dtype
==
"float16"
:
if
out_dtype
==
"float16"
:
...
@@ -226,14 +241,22 @@ def upcast_float16_ufunc(fn):
...
@@ -226,14 +241,22 @@ def upcast_float16_ufunc(fn):
def
upcast_int8_nfunc
(
fn
):
def
upcast_int8_nfunc
(
fn
):
# Decorator that upcasts input of dtype int8 to float32.
"""Decorator that upcasts input of dtype int8 to float32.
#
# This is so that floating-point computation is not carried using
This is so that floating-point computation is not carried using
# half-precision (float16), as some NumPy functions do.
half-precision (float16), as some NumPy functions do.
#
# :param fn: function computing a floating-point value from inputs
Parameters
# :returns: function similar to fn, but upcasting its uint8 and int8
----------
# inputs before carrying out the computation.
fn
A function computing a floating-point value from inputs.
Returns
-------
A function similar to fn, but upcasting its uint8 and int8 inputs before
carrying out the computation.
"""
def
ret
(
*
args
,
**
kwargs
):
def
ret
(
*
args
,
**
kwargs
):
args
=
list
(
args
)
args
=
list
(
args
)
for
i
,
a
in
enumerate
(
args
):
for
i
,
a
in
enumerate
(
args
):
...
@@ -332,15 +355,21 @@ def random_of_dtype(shape, dtype, rng=None):
...
@@ -332,15 +355,21 @@ def random_of_dtype(shape, dtype, rng=None):
def
check_floatX
(
inputs
,
rval
):
def
check_floatX
(
inputs
,
rval
):
# :param inputs: Inputs to a function that returned `rval` with these inputs.
"""
#
Parameters
# :param rval: Value returned by a function with inputs set to `inputs`.
----------
#
inputs
# :returns: Either `rval` unchanged, or `rval` cast in float32. The idea is
Inputs to a function that returned `rval` with these inputs.
# that when a numpy function would have returned a float64, Aesara may prefer
rval
# to return a float32 instead when `config.cast_policy` is set to
Value returned by a function with inputs set to `inputs`.
# 'numpy+floatX' and config.floatX to 'float32', and there was no float64
# input.
Returns
-------
Either `rval` unchanged, or `rval` cast in float32. The idea is that when a
numpy function would have returned a float64, Aesara may prefer to return a
float32 instead when `config.cast_policy` is set to ``'numpy+floatX'`` and
`config.floatX` to ``'float32'``, and there was no float64 input.
"""
if
(
if
(
isinstance
(
rval
,
np
.
ndarray
)
isinstance
(
rval
,
np
.
ndarray
)
and
rval
.
dtype
==
"float64"
and
rval
.
dtype
==
"float64"
...
@@ -355,10 +384,11 @@ def check_floatX(inputs, rval):
...
@@ -355,10 +384,11 @@ def check_floatX(inputs, rval):
def
_numpy_true_div
(
x
,
y
):
def
_numpy_true_div
(
x
,
y
):
# Performs true division, and cast the result in the type we expect.
"""Performs true division, and cast the result in the type we expect.
#
# We define that function so we can use it in TrueDivTester.expected,
We define that function so we can use it in `TrueDivTester.expected`,
# because simply calling np.true_divide could cause a dtype mismatch.
because simply calling np.true_divide could cause a dtype mismatch.
"""
out
=
np
.
true_divide
(
x
,
y
)
out
=
np
.
true_divide
(
x
,
y
)
# Use floatX as the result of int / int
# Use floatX as the result of int / int
if
x
.
dtype
in
discrete_dtypes
and
y
.
dtype
in
discrete_dtypes
:
if
x
.
dtype
in
discrete_dtypes
and
y
.
dtype
in
discrete_dtypes
:
...
@@ -367,8 +397,7 @@ def _numpy_true_div(x, y):
...
@@ -367,8 +397,7 @@ def _numpy_true_div(x, y):
def
copymod
(
dct
,
without
=
None
,
**
kwargs
):
def
copymod
(
dct
,
without
=
None
,
**
kwargs
):
# Return dct but with the keys named by args removed, and with
"""Return `dct` but with the keys named by `without` removed, and with `kwargs` added."""
# kwargs added.
if
without
is
None
:
if
without
is
None
:
without
=
[]
without
=
[]
rval
=
copy
(
dct
)
rval
=
copy
(
dct
)
...
@@ -397,8 +426,12 @@ def makeTester(
...
@@ -397,8 +426,12 @@ def makeTester(
check_name
=
False
,
check_name
=
False
,
grad_eps
=
None
,
grad_eps
=
None
,
):
):
# :param check_name:
"""
# Use only for tester that aren't in Aesara.
Parameters
----------
check_name
Use only for testers that aren't in Aesara.
"""
if
checks
is
None
:
if
checks
is
None
:
checks
=
{}
checks
=
{}
if
good
is
None
:
if
good
is
None
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论