Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
c736927b
提交
c736927b
authored
7月 14, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
8月 17, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Clean up deprecations
This commit introduces the use of module-level `__getattr__` overrides to emit deprecation warnings for renamed objects. It also adds some missing `pytest.deprecated_call` checks.
上级
e0d91807
隐藏空白字符变更
内嵌
并排
正在显示
18 个修改的文件
包含
287 行增加
和
140 行删除
+287
-140
__init__.py
aesara/__init__.py
+24
-6
configparser.py
aesara/configparser.py
+37
-14
gradient.py
aesara/gradient.py
+30
-6
optdb.py
aesara/graph/optdb.py
+30
-8
scalar.py
aesara/link/numba/dispatch/scalar.py
+7
-5
multinomial.py
aesara/sandbox/multinomial.py
+0
-12
rng_mrg.py
aesara/sandbox/rng_mrg.py
+3
-3
basic.py
aesara/scalar/basic.py
+23
-8
basic.py
aesara/tensor/basic.py
+1
-1
math.py
aesara/tensor/math.py
+25
-10
conv.py
aesara/tensor/nnet/conv.py
+8
-8
slinalg.py
aesara/tensor/slinalg.py
+40
-15
utils.py
aesara/utils.py
+7
-1
test_numba.py
tests/link/test_numba.py
+2
-2
test_multinomial_wo_replacement.py
tests/sandbox/test_multinomial_wo_replacement.py
+2
-2
test_rng_mrg.py
tests/sandbox/test_rng_mrg.py
+17
-7
test_basic.py
tests/tensor/test_basic.py
+1
-8
test_gradient.py
tests/test_gradient.py
+30
-24
没有找到文件。
aesara/__init__.py
浏览文件 @
c736927b
...
@@ -62,12 +62,6 @@ for p in sys.path:
...
@@ -62,12 +62,6 @@ for p in sys.path:
raise
RuntimeError
(
"You have the aesara directory in your Python path."
)
raise
RuntimeError
(
"You have the aesara directory in your Python path."
)
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.utils
import
deprecated
change_flags
=
deprecated
(
"Use aesara.config.change_flags instead!"
)(
config
.
change_flags
)
# This is the api version for ops that generate C code. External ops
# This is the api version for ops that generate C code. External ops
...
@@ -178,3 +172,27 @@ from aesara.scan.views import foldl, foldr, map, reduce
...
@@ -178,3 +172,27 @@ from aesara.scan.views import foldl, foldr, map, reduce
# imports were executed, we can warn about remaining flags provided by the user
# imports were executed, we can warn about remaining flags provided by the user
# through AESARA_FLAGS.
# through AESARA_FLAGS.
config
.
warn_unused_flags
()
config
.
warn_unused_flags
()
DEPRECATED_NAMES
=
[
(
"change_flags"
,
"`aesara.change_flags` is deprecated: use `aesara.config.change_flags` instead."
,
config
.
change_flags
,
),
]
def
__getattr__
(
name
):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from
warnings
import
warn
for
old_name
,
msg
,
old_object
in
DEPRECATED_NAMES
:
if
name
==
old_name
:
warn
(
msg
,
DeprecationWarning
,
stacklevel
=
2
)
return
old_object
raise
AttributeError
(
f
"module {__name__} has no attribute {name}"
)
aesara/configparser.py
浏览文件 @
c736927b
...
@@ -14,7 +14,7 @@ from functools import wraps
...
@@ -14,7 +14,7 @@ from functools import wraps
from
io
import
StringIO
from
io
import
StringIO
from
typing
import
Callable
,
Dict
,
Optional
,
Sequence
,
Union
from
typing
import
Callable
,
Dict
,
Optional
,
Sequence
,
Union
from
aesara.utils
import
deprecated
,
hash_from_code
from
aesara.utils
import
hash_from_code
_logger
=
logging
.
getLogger
(
"aesara.configparser"
)
_logger
=
logging
.
getLogger
(
"aesara.configparser"
)
...
@@ -582,8 +582,7 @@ class _ConfigProxy:
...
@@ -582,8 +582,7 @@ class _ConfigProxy:
if
attr
==
"_actual"
:
if
attr
==
"_actual"
:
return
_ConfigProxy
.
_actual
return
_ConfigProxy
.
_actual
warnings
.
warn
(
warnings
.
warn
(
"Accessing config through `aesara.configparser.config` is deprecated. "
"`aesara.configparser.config` is deprecated; use `aesara.config` instead."
,
"Use `aesara.config` instead."
,
DeprecationWarning
,
DeprecationWarning
,
stacklevel
=
2
,
stacklevel
=
2
,
)
)
...
@@ -593,8 +592,7 @@ class _ConfigProxy:
...
@@ -593,8 +592,7 @@ class _ConfigProxy:
if
attr
==
"_actual"
:
if
attr
==
"_actual"
:
return
setattr
(
_ConfigProxy
.
_actual
,
attr
,
value
)
return
setattr
(
_ConfigProxy
.
_actual
,
attr
,
value
)
warnings
.
warn
(
warnings
.
warn
(
"Accessing config through `aesara.configparser.config` is deprecated. "
"`aesara.configparser.config` is deprecated; use `aesara.config` instead."
,
"Use `aesara.config` instead."
,
DeprecationWarning
,
DeprecationWarning
,
stacklevel
=
2
,
stacklevel
=
2
,
)
)
...
@@ -609,12 +607,37 @@ _config = _create_default_config()
...
@@ -609,12 +607,37 @@ _config = _create_default_config()
# These imports/accesses should be replaced with `aesara.config`, so this wraps
# These imports/accesses should be replaced with `aesara.config`, so this wraps
# it with warnings:
# it with warnings:
config
=
_ConfigProxy
(
_config
)
config
=
_ConfigProxy
(
_config
)
# We can't alias the methods of the `config` variable above without already
# triggering the warning. Instead, we wrap the methods of the actual instance
DEPRECATED_NAMES
=
[
# with warnings:
(
change_flags
=
deprecated
(
"Use aesara.config.change_flags instead!"
)(
"change_flags"
,
_config
.
change_flags
"`change_flags` is deprecated; use `aesara.config.change_flags` instead."
,
)
_config
.
change_flags
,
_config_print
=
deprecated
(
"Use aesara.config.config_print instead!"
)(
),
_config
.
config_print
(
)
"_change_flags"
,
"`_change_flags` is deprecated; use `aesara.config.change_flags` instead."
,
_config
.
change_flags
,
),
(
"_config_print"
,
"`_config_print` is deprecated; use `aesara.config.config_print` instead."
,
_config
.
config_print
,
),
]
def
__getattr__
(
name
):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from
warnings
import
warn
for
old_name
,
msg
,
old_object
in
DEPRECATED_NAMES
:
if
name
==
old_name
:
warn
(
msg
,
DeprecationWarning
,
stacklevel
=
2
)
return
old_object
raise
AttributeError
(
f
"module {__name__} has no attribute {name}"
)
aesara/gradient.py
浏览文件 @
c736927b
...
@@ -2129,10 +2129,9 @@ consider_constant_ = ConsiderConstant()
...
@@ -2129,10 +2129,9 @@ consider_constant_ = ConsiderConstant()
def
consider_constant
(
x
):
def
consider_constant
(
x
):
"""
"""Consider an expression constant when computing gradients.
DEPRECATED: use zero_grad() or disconnected_grad() instead.
Consider an expression constant when computing gradients
.
DEPRECATED: use `zero_grad` or `disconnected_grad` instead
.
The expression itself is unaffected, but when its gradient is
The expression itself is unaffected, but when its gradient is
computed, or the gradient of another expression that this
computed, or the gradient of another expression that this
...
@@ -2149,14 +2148,14 @@ def consider_constant(x):
...
@@ -2149,14 +2148,14 @@ def consider_constant(x):
"""
"""
warnings
.
warn
(
warnings
.
warn
(
(
(
"
consider_constant() is deprecated, use zero_grad()
or "
"
`ConsiderConstant` is deprecated; use `zero_grad`
or "
"
disconnected_grad()
instead."
"
`disconnected_grad`
instead."
),
),
category
=
DeprecationWarning
,
category
=
DeprecationWarning
,
stacklevel
=
3
,
stacklevel
=
3
,
)
)
return
consider_constant_
(
x
)
return
ConsiderConstant
()
(
x
)
class
ZeroGrad
(
ViewOp
):
class
ZeroGrad
(
ViewOp
):
...
@@ -2365,3 +2364,28 @@ def grad_scale(x, multiplier):
...
@@ -2365,3 +2364,28 @@ def grad_scale(x, multiplier):
0.416...
0.416...
"""
"""
return
GradScale
(
multiplier
)(
x
)
return
GradScale
(
multiplier
)(
x
)
DEPRECATED_NAMES
=
[
(
"consider_constant_"
,
"`consider_constant_` is deprecated; use `zero_grad` or `disconnected_grad` instead."
,
ConsiderConstant
(),
),
]
def
__getattr__
(
name
):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from
warnings
import
warn
for
old_name
,
msg
,
old_object
in
DEPRECATED_NAMES
:
if
name
==
old_name
:
warn
(
msg
,
DeprecationWarning
,
stacklevel
=
2
)
return
old_object
raise
AttributeError
(
f
"module {__name__} has no attribute {name}"
)
aesara/graph/optdb.py
浏览文件 @
c736927b
...
@@ -177,10 +177,6 @@ class OptimizationDatabase:
...
@@ -177,10 +177,6 @@ class OptimizationDatabase:
print
(
" db"
,
self
.
__db__
,
file
=
stream
)
print
(
" db"
,
self
.
__db__
,
file
=
stream
)
# This is deprecated and will be removed.
DB
=
OptimizationDatabase
class
OptimizationQuery
:
class
OptimizationQuery
:
"""An object that specifies a set of optimizations by tag/name."""
"""An object that specifies a set of optimizations by tag/name."""
...
@@ -293,10 +289,6 @@ class OptimizationQuery:
...
@@ -293,10 +289,6 @@ class OptimizationQuery:
)
)
# This is deprecated and will be removed.
Query
=
OptimizationQuery
class
EquilibriumDB
(
OptimizationDatabase
):
class
EquilibriumDB
(
OptimizationDatabase
):
"""
"""
A set of potential optimizations which should be applied in an arbitrary
A set of potential optimizations which should be applied in an arbitrary
...
@@ -550,3 +542,33 @@ class ProxyDB(OptimizationDatabase):
...
@@ -550,3 +542,33 @@ class ProxyDB(OptimizationDatabase):
def
query
(
self
,
*
tags
,
**
kwtags
):
def
query
(
self
,
*
tags
,
**
kwtags
):
return
self
.
db
.
query
(
*
tags
,
**
kwtags
)
return
self
.
db
.
query
(
*
tags
,
**
kwtags
)
DEPRECATED_NAMES
=
[
(
"DB"
,
"`DB` is deprecated; use `OptimizationDatabase` instead."
,
OptimizationDatabase
,
),
(
"Query"
,
"`Query` is deprecated; use `OptimizationQuery` instead."
,
OptimizationQuery
,
),
]
def
__getattr__
(
name
):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from
warnings
import
warn
for
old_name
,
msg
,
old_object
in
DEPRECATED_NAMES
:
if
name
==
old_name
:
warn
(
msg
,
DeprecationWarning
,
stacklevel
=
2
)
return
old_object
raise
AttributeError
(
f
"module {__name__} has no attribute {name}"
)
aesara/link/numba/dispatch/scalar.py
浏览文件 @
c736927b
...
@@ -22,8 +22,8 @@ from aesara.scalar.basic import (
...
@@ -22,8 +22,8 @@ from aesara.scalar.basic import (
Clip
,
Clip
,
Composite
,
Composite
,
Identity
,
Identity
,
Inv
,
Mul
,
Mul
,
Reciprocal
,
ScalarOp
,
ScalarOp
,
Second
,
Second
,
Switch
,
Switch
,
...
@@ -236,13 +236,15 @@ def numba_funcify_Second(op, node, **kwargs):
...
@@ -236,13 +236,15 @@ def numba_funcify_Second(op, node, **kwargs):
return
second
return
second
@numba_funcify.register
(
Inv
)
@numba_funcify.register
(
Reciprocal
)
def
numba_funcify_
Inv
(
op
,
node
,
**
kwargs
):
def
numba_funcify_
Reciprocal
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
(
inline
=
"always"
)
def
inv
(
x
):
def
reciprocal
(
x
):
# TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
# `x` is an `int`
return
1
/
x
return
1
/
x
return
inv
return
reciprocal
@numba_funcify.register
(
Sigmoid
)
@numba_funcify.register
(
Sigmoid
)
...
...
aesara/sandbox/multinomial.py
浏览文件 @
c736927b
import
copy
import
copy
import
warnings
from
typing
import
Tuple
,
Union
from
typing
import
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -435,14 +434,3 @@ class ChoiceFromUniform(MultinomialFromUniform):
...
@@ -435,14 +434,3 @@ class ChoiceFromUniform(MultinomialFromUniform):
pvals
[
n
,
m
]
=
0.0
pvals
[
n
,
m
]
=
0.0
pvals
[
n
]
/=
pvals
[
n
]
.
sum
()
pvals
[
n
]
/=
pvals
[
n
]
.
sum
()
break
break
class
MultinomialWOReplacementFromUniform
(
ChoiceFromUniform
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
warnings
.
warn
(
"MultinomialWOReplacementFromUniform is deprecated, "
"use ChoiceFromUniform instead."
,
DeprecationWarning
,
stacklevel
=
2
,
)
super
()
.
__init__
(
*
args
,
**
kwargs
)
aesara/sandbox/rng_mrg.py
浏览文件 @
c736927b
...
@@ -1107,10 +1107,10 @@ class MRG_RandomStream:
...
@@ -1107,10 +1107,10 @@ class MRG_RandomStream:
**
kwargs
,
**
kwargs
,
):
):
warnings
.
warn
(
warnings
.
warn
(
"MRG_RandomStream.multinomial_wo_replacement is "
"`MRG_RandomStream.multinomial_wo_replacement` is "
"deprecated and will be removed in the next release of "
"deprecated; use `MRG_RandomStream.choice` instead."
,
"Aesara. Please use MRG_RandomStream.choice instead."
,
DeprecationWarning
,
DeprecationWarning
,
stacklevel
=
2
,
)
)
assert
size
is
None
assert
size
is
None
return
self
.
choice
(
return
self
.
choice
(
...
...
aesara/scalar/basic.py
浏览文件 @
c736927b
...
@@ -670,10 +670,6 @@ class ScalarType(CType, HasDataType, HasShape):
...
@@ -670,10 +670,6 @@ class ScalarType(CType, HasDataType, HasShape):
return
shape_info
return
shape_info
# Deprecated alias for backward compatibility
Scalar
=
ScalarType
def
get_scalar_type
(
dtype
,
cache
:
Dict
[
str
,
ScalarType
]
=
{})
->
ScalarType
:
def
get_scalar_type
(
dtype
,
cache
:
Dict
[
str
,
ScalarType
]
=
{})
->
ScalarType
:
"""
"""
Return a ScalarType(dtype) object.
Return a ScalarType(dtype) object.
...
@@ -2903,10 +2899,6 @@ class Reciprocal(UnaryScalarOp):
...
@@ -2903,10 +2899,6 @@ class Reciprocal(UnaryScalarOp):
reciprocal
=
Reciprocal
(
upgrade_to_float
,
name
=
"reciprocal"
)
reciprocal
=
Reciprocal
(
upgrade_to_float
,
name
=
"reciprocal"
)
# These are deprecated and will be removed
Inv
=
Reciprocal
inv
=
reciprocal
class
Log
(
UnaryScalarOp
):
class
Log
(
UnaryScalarOp
):
"""
"""
...
@@ -4455,3 +4447,26 @@ def handle_composite(node, mapping):
...
@@ -4455,3 +4447,26 @@ def handle_composite(node, mapping):
Compositef32
.
special
[
Composite
]
=
handle_composite
Compositef32
.
special
[
Composite
]
=
handle_composite
DEPRECATED_NAMES
=
[
(
"Inv"
,
"`Inv` is deprecated; use `Reciprocal` instead."
,
Reciprocal
),
(
"inv"
,
"`inv` is deprecated; use `reciprocal` instead."
,
reciprocal
),
(
"Scalar"
,
"`Scalar` is deprecated; use `ScalarType` instead."
,
ScalarType
),
]
def
__getattr__
(
name
):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from
warnings
import
warn
for
old_name
,
msg
,
old_object
in
DEPRECATED_NAMES
:
if
name
==
old_name
:
warn
(
msg
,
DeprecationWarning
,
stacklevel
=
2
)
return
old_object
raise
AttributeError
(
f
"module {__name__} has no attribute {name}"
)
aesara/tensor/basic.py
浏览文件 @
c736927b
...
@@ -2667,7 +2667,7 @@ def is_flat(var, ndim=None, outdim=None):
...
@@ -2667,7 +2667,7 @@ def is_flat(var, ndim=None, outdim=None):
elif
outdim
is
not
None
and
ndim
is
not
None
:
elif
outdim
is
not
None
and
ndim
is
not
None
:
raise
ValueError
(
"You should only specify ndim"
)
raise
ValueError
(
"You should only specify ndim"
)
elif
outdim
is
not
None
:
elif
outdim
is
not
None
:
warnings
.
warn
(
"
flatten outdim parameter is deprecated, use ndim
instead."
)
warnings
.
warn
(
"
outdim` is deprecated; use `ndim`
instead."
)
ndim
=
outdim
ndim
=
outdim
return
var
.
ndim
==
ndim
return
var
.
ndim
==
ndim
...
...
aesara/tensor/math.py
浏览文件 @
c736927b
...
@@ -1048,10 +1048,6 @@ def abs(a):
...
@@ -1048,10 +1048,6 @@ def abs(a):
"""|`a`|"""
"""|`a`|"""
# These are deprecated and will be removed
abs_
=
abs
pprint
.
assign
(
abs
,
printing
.
PatternPrinter
((
"|
%(0)
s|"
,
-
1000
)))
pprint
.
assign
(
abs
,
printing
.
PatternPrinter
((
"|
%(0)
s|"
,
-
1000
)))
...
@@ -1080,10 +1076,6 @@ def reciprocal(a):
...
@@ -1080,10 +1076,6 @@ def reciprocal(a):
"""1.0/a"""
"""1.0/a"""
# This is deprecated and will be removed
inv
=
reciprocal
@scalar_elemwise
@scalar_elemwise
def
log
(
a
):
def
log
(
a
):
"""base e logarithm of a"""
"""base e logarithm of a"""
...
@@ -3024,13 +3016,11 @@ __all__ = [
...
@@ -3024,13 +3016,11 @@ __all__ = [
"invert"
,
"invert"
,
"bitwise_not"
,
"bitwise_not"
,
"abs"
,
"abs"
,
"abs_"
,
"exp"
,
"exp"
,
"exp2"
,
"exp2"
,
"expm1"
,
"expm1"
,
"neg"
,
"neg"
,
"reciprocal"
,
"reciprocal"
,
"inv"
,
"log"
,
"log"
,
"log2"
,
"log2"
,
"log10"
,
"log10"
,
...
@@ -3127,3 +3117,28 @@ __all__ = [
...
@@ -3127,3 +3117,28 @@ __all__ = [
"logaddexp"
,
"logaddexp"
,
"logsumexp"
,
"logsumexp"
,
]
]
DEPRECATED_NAMES
=
[
(
"abs_"
,
"`abs_` is deprecated; use `abs` instead."
,
abs
),
(
"inv"
,
"`inv` is deprecated; use `reciprocal` instead."
,
reciprocal
),
]
def
__getattr__
(
name
):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from
warnings
import
warn
for
old_name
,
msg
,
old_object
in
DEPRECATED_NAMES
:
if
name
==
old_name
:
warn
(
msg
,
DeprecationWarning
,
stacklevel
=
2
)
return
old_object
raise
AttributeError
(
f
"module {__name__} has no attribute {name}"
)
def
__dir__
():
return
sorted
(
__all__
+
[
names
[
0
]
for
names
in
DEPRECATED_NAMES
])
aesara/tensor/nnet/conv.py
浏览文件 @
c736927b
...
@@ -46,12 +46,13 @@ def conv2d(
...
@@ -46,12 +46,13 @@ def conv2d(
subsample
=
(
1
,
1
),
subsample
=
(
1
,
1
),
**
kargs
,
**
kargs
,
):
):
"""
"""Build the symbolic graph for convolving a stack of input images with a set of filters.
Deprecated, old conv2d interface.
This function will build the symbolic graph for convolving a stack of
The implementation is modelled after Convolutional Neural Networks
input images with a set of filters. The implementation is modelled after
(CNN). It is simply a wrapper to the `ConvOp` but provides a much cleaner
Convolutional Neural Networks (CNN). It is simply a wrapper to the ConvOp
interface.
but provides a much cleaner interface.
This is deprecated.
Parameters
Parameters
----------
----------
...
@@ -402,8 +403,7 @@ class ConvOp(OpenMPOp):
...
@@ -402,8 +403,7 @@ class ConvOp(OpenMPOp):
# with s=1 for mode=='full' and s=-1 for mode=='valid'.
# with s=1 for mode=='full' and s=-1 for mode=='valid'.
# To support symbolic shapes, we express this with integer arithmetic.
# To support symbolic shapes, we express this with integer arithmetic.
warnings
.
warn
(
warnings
.
warn
(
"The method `getOutputShape` is deprecated use"
"`getOutputShape` is deprecated; use `get_conv_output_shape` instead."
,
"`get_conv_output_shape` instead."
,
DeprecationWarning
,
DeprecationWarning
,
stacklevel
=
2
,
stacklevel
=
2
,
)
)
...
...
aesara/tensor/slinalg.py
浏览文件 @
c736927b
...
@@ -101,9 +101,8 @@ class Cholesky(Op):
...
@@ -101,9 +101,8 @@ class Cholesky(Op):
def
conjugate_solve_triangular
(
outer
,
inner
):
def
conjugate_solve_triangular
(
outer
,
inner
):
"""Computes L^{-T} P L^{-1} for lower-triangular L."""
"""Computes L^{-T} P L^{-1} for lower-triangular L."""
return
solve_upper_triangular
(
solve_upper
=
SolveTriangular
(
lower
=
False
)
outer
.
T
,
solve_upper_triangular
(
outer
.
T
,
inner
.
T
)
.
T
return
solve_upper
(
outer
.
T
,
solve_upper
(
outer
.
T
,
inner
.
T
)
.
T
)
)
s
=
conjugate_solve_triangular
(
s
=
conjugate_solve_triangular
(
chol_x
,
tril_and_halve_diagonal
(
chol_x
.
T
.
dot
(
dz
))
chol_x
,
tril_and_halve_diagonal
(
chol_x
.
T
.
dot
(
dz
))
...
@@ -507,15 +506,6 @@ def solve(a, b, assume_a="gen", lower=False, check_finite=True):
...
@@ -507,15 +506,6 @@ def solve(a, b, assume_a="gen", lower=False, check_finite=True):
)(
a
,
b
)
)(
a
,
b
)
# TODO: These are deprecated; emit a warning
solve_lower_triangular
=
SolveTriangular
(
lower
=
True
)
solve_upper_triangular
=
SolveTriangular
(
lower
=
False
)
solve_symmetric
=
Solve
(
assume_a
=
"sym"
)
# TODO: Optimizations to replace multiplication by matrix inverse
# with solve() Op (still unwritten)
class
Eigvalsh
(
Op
):
class
Eigvalsh
(
Op
):
"""
"""
Generalized eigenvalues of a Hermitian positive definite eigensystem.
Generalized eigenvalues of a Hermitian positive definite eigensystem.
...
@@ -748,10 +738,45 @@ expm = Expm()
...
@@ -748,10 +738,45 @@ expm = Expm()
__all__
=
[
__all__
=
[
"cholesky"
,
"cholesky"
,
"solve"
,
"solve"
,
"solve_lower_triangular"
,
"solve_upper_triangular"
,
"solve_symmetric"
,
"eigvalsh"
,
"eigvalsh"
,
"kron"
,
"kron"
,
"expm"
,
"expm"
,
]
]
DEPRECATED_NAMES
=
[
(
"solve_lower_triangular"
,
"`solve_lower_triangular` is deprecated; use `solve` instead."
,
SolveTriangular
(
lower
=
True
),
),
(
"solve_upper_triangular"
,
"`solve_upper_triangular` is deprecated; use `solve` instead."
,
SolveTriangular
(
lower
=
False
),
),
(
"solve_symmetric"
,
"`solve_symmetric` is deprecated; use `solve` instead."
,
Solve
(
assume_a
=
"sym"
),
),
]
def
__getattr__
(
name
):
"""Intercept module-level attribute access of deprecated symbols.
Adapted from https://stackoverflow.com/a/55139609/3006474.
"""
from
warnings
import
warn
for
old_name
,
msg
,
old_object
in
DEPRECATED_NAMES
:
if
name
==
old_name
:
warn
(
msg
,
DeprecationWarning
,
stacklevel
=
2
)
return
old_object
raise
AttributeError
(
f
"module {__name__} has no attribute {name}"
)
def
__dir__
():
return
sorted
(
__all__
+
[
names
[
0
]
for
names
in
DEPRECATED_NAMES
])
aesara/utils.py
浏览文件 @
c736927b
...
@@ -158,12 +158,18 @@ def deprecated(message: str = ""):
...
@@ -158,12 +158,18 @@ def deprecated(message: str = ""):
def
decorator_wrapper
(
func
):
def
decorator_wrapper
(
func
):
@wraps
(
func
)
@wraps
(
func
)
def
function_wrapper
(
*
args
,
**
kwargs
):
def
function_wrapper
(
*
args
,
**
kwargs
):
nonlocal
message
current_call_source
=
"|"
.
join
(
current_call_source
=
"|"
.
join
(
traceback
.
format_stack
(
inspect
.
currentframe
())
traceback
.
format_stack
(
inspect
.
currentframe
())
)
)
if
current_call_source
not
in
function_wrapper
.
last_call_source
:
if
current_call_source
not
in
function_wrapper
.
last_call_source
:
if
not
message
:
message
=
f
"Function {func.__name__} is deprecated."
warnings
.
warn
(
warnings
.
warn
(
"Function {} is now deprecated! {}"
.
format
(
func
.
__name__
,
message
)
,
message
,
category
=
DeprecationWarning
,
category
=
DeprecationWarning
,
stacklevel
=
2
,
stacklevel
=
2
,
)
)
...
...
tests/link/test_numba.py
浏览文件 @
c736927b
...
@@ -827,8 +827,8 @@ def test_Cast(v, dtype):
...
@@ -827,8 +827,8 @@ def test_Cast(v, dtype):
(
set_test_value
(
at
.
iscalar
(),
np
.
array
(
10
,
dtype
=
"int32"
)),
aesb
.
float64
),
(
set_test_value
(
at
.
iscalar
(),
np
.
array
(
10
,
dtype
=
"int32"
)),
aesb
.
float64
),
],
],
)
)
def
test_
Inv
(
v
,
dtype
):
def
test_
reciprocal
(
v
,
dtype
):
g
=
aesb
.
inv
(
v
)
g
=
aesb
.
reciprocal
(
v
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
g_fg
,
...
...
tests/sandbox/test_multinomial_wo_replacement.py
浏览文件 @
c736927b
...
@@ -157,7 +157,7 @@ class TestFunction:
...
@@ -157,7 +157,7 @@ class TestFunction:
p
=
fmatrix
()
p
=
fmatrix
()
n
=
iscalar
()
n
=
iscalar
()
with
pytest
.
warns
(
DeprecationWarning
):
with
pytest
.
deprecated_call
(
):
m
=
th_rng
.
multinomial_wo_replacement
(
pvals
=
p
,
n
=
n
)
m
=
th_rng
.
multinomial_wo_replacement
(
pvals
=
p
,
n
=
n
)
f
=
function
([
p
,
n
],
m
,
allow_input_downcast
=
True
)
f
=
function
([
p
,
n
],
m
,
allow_input_downcast
=
True
)
...
@@ -181,7 +181,7 @@ class TestFunction:
...
@@ -181,7 +181,7 @@ class TestFunction:
p
=
fmatrix
()
p
=
fmatrix
()
n
=
iscalar
()
n
=
iscalar
()
with
pytest
.
warns
(
DeprecationWarning
):
with
pytest
.
deprecated_call
(
):
m
=
th_rng
.
multinomial_wo_replacement
(
pvals
=
p
,
n
=
n
)
m
=
th_rng
.
multinomial_wo_replacement
(
pvals
=
p
,
n
=
n
)
f
=
function
([
p
,
n
],
m
,
allow_input_downcast
=
True
)
f
=
function
([
p
,
n
],
m
,
allow_input_downcast
=
True
)
...
...
tests/sandbox/test_rng_mrg.py
浏览文件 @
c736927b
import
contextlib
import
os
import
os
import
sys
import
sys
import
time
import
time
...
@@ -332,12 +333,20 @@ def test_broadcastable():
...
@@ -332,12 +333,20 @@ def test_broadcastable():
# the sizes of them are implicitly defined with "pvals" argument.
# the sizes of them are implicitly defined with "pvals" argument.
if
distribution
in
[
R
.
multinomial
,
R
.
multinomial_wo_replacement
]:
if
distribution
in
[
R
.
multinomial
,
R
.
multinomial_wo_replacement
]:
# check when all dimensions are constant
# check when all dimensions are constant
uu
=
distribution
(
pvals
=
pvals_1
)
context_mgr
=
(
assert
uu
.
broadcastable
==
(
False
,
True
)
pytest
.
deprecated_call
()
if
distribution
==
R
.
multinomial_wo_replacement
else
contextlib
.
suppress
()
)
with
context_mgr
:
uu
=
distribution
(
pvals
=
pvals_1
)
assert
uu
.
broadcastable
==
(
False
,
True
)
# check when some dimensions are aesara variables
# check when some dimensions are aesara variables
uu
=
distribution
(
pvals
=
pvals_2
)
with
context_mgr
:
assert
uu
.
broadcastable
==
(
False
,
True
)
uu
=
distribution
(
pvals
=
pvals_2
)
assert
uu
.
broadcastable
==
(
False
,
True
)
else
:
else
:
# check when all dimensions are constant
# check when all dimensions are constant
uu
=
distribution
(
size
=
size1
)
uu
=
distribution
(
size
=
size1
)
...
@@ -1109,9 +1118,10 @@ def test_target_parameter():
...
@@ -1109,9 +1118,10 @@ def test_target_parameter():
basic_target_parameter_test
(
basic_target_parameter_test
(
srng
.
choice
(
p
=
pvals
.
astype
(
"float32"
),
replace
=
False
,
target
=
"cpu"
)
srng
.
choice
(
p
=
pvals
.
astype
(
"float32"
),
replace
=
False
,
target
=
"cpu"
)
)
)
basic_target_parameter_test
(
with
pytest
.
deprecated_call
():
srng
.
multinomial_wo_replacement
(
pvals
=
pvals
.
astype
(
"float32"
),
target
=
"cpu"
)
basic_target_parameter_test
(
)
srng
.
multinomial_wo_replacement
(
pvals
=
pvals
.
astype
(
"float32"
),
target
=
"cpu"
)
)
@config.change_flags
(
compute_test_value
=
"off"
)
@config.change_flags
(
compute_test_value
=
"off"
)
...
...
tests/tensor/test_basic.py
浏览文件 @
c736927b
...
@@ -1321,16 +1321,9 @@ class TestJoinAndSplit:
...
@@ -1321,16 +1321,9 @@ class TestJoinAndSplit:
def
test_stack_new_interface
(
self
):
def
test_stack_new_interface
(
self
):
# Test the new numpy-like interface: stack(tensors, axis=0).
# Test the new numpy-like interface: stack(tensors, axis=0).
# Testing against old interface
warnings
.
simplefilter
(
"always"
,
DeprecationWarning
)
a
=
imatrix
(
"a"
)
a
=
imatrix
(
"a"
)
b
=
imatrix
(
"b"
)
b
=
imatrix
(
"b"
)
s1
=
stack
(
a
,
b
)
s2
=
stack
([
a
,
b
])
f
=
function
([
a
,
b
],
[
s1
,
s2
],
mode
=
self
.
mode
)
v1
,
v2
=
f
([[
1
,
2
]],
[[
3
,
4
]])
assert
v1
.
shape
==
v2
.
shape
assert
np
.
all
(
v1
==
v2
)
# Testing axis parameter
# Testing axis parameter
s3
=
stack
([
a
,
b
],
1
)
s3
=
stack
([
a
,
b
],
1
)
f
=
function
([
a
,
b
],
s3
,
mode
=
self
.
mode
)
f
=
function
([
a
,
b
],
s3
,
mode
=
self
.
mode
)
...
...
tests/test_gradient.py
浏览文件 @
c736927b
...
@@ -14,8 +14,6 @@ from aesara.gradient import (
...
@@ -14,8 +14,6 @@ from aesara.gradient import (
NullTypeGradError
,
NullTypeGradError
,
Rop
,
Rop
,
UndefinedGrad
,
UndefinedGrad
,
consider_constant
,
consider_constant_
,
disconnected_grad
,
disconnected_grad
,
disconnected_grad_
,
disconnected_grad_
,
grad
,
grad
,
...
@@ -769,37 +767,45 @@ def test_subgraph_grad():
...
@@ -769,37 +767,45 @@ def test_subgraph_grad():
class
TestConsiderConstant
:
class
TestConsiderConstant
:
def
setup_method
(
self
):
self
.
rng
=
np
.
random
.
default_rng
(
seed
=
utt
.
fetch_seed
())
def
test_op_removed
(
self
):
def
test_op_removed
(
self
):
from
aesara.gradient
import
ConsiderConstant
,
consider_constant
x
=
matrix
(
"x"
)
x
=
matrix
(
"x"
)
y
=
x
*
consider_constant
(
x
)
with
pytest
.
deprecated_call
():
y
=
x
*
consider_constant
(
x
)
f
=
aesara
.
function
([
x
],
y
)
f
=
aesara
.
function
([
x
],
y
)
# need to refer to aesara.consider_constant_ here,
# aesara.consider_constant is a wrapper function!
assert
ConsiderConstant
not
in
[
assert
consider_constant_
not
in
[
node
.
op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
type
(
node
.
op
)
for
node
in
f
.
maker
.
fgraph
.
toposort
()
]
def
test_grad
(
self
):
def
test_grad
(
self
):
a
=
np
.
asarray
(
self
.
rng
.
standard_normal
((
5
,
5
)),
dtype
=
config
.
floatX
)
from
aesara.gradient
import
consider_constant
x
=
matrix
(
"x"
)
rng
=
np
.
random
.
default_rng
(
seed
=
utt
.
fetch_seed
()
)
expressions_gradients
=
[
a
=
np
.
asarray
(
rng
.
standard_normal
((
5
,
5
)),
dtype
=
config
.
floatX
)
(
x
*
consider_constant
(
x
),
x
),
(
x
*
consider_constant
(
exp
(
x
)),
exp
(
x
)),
(
consider_constant
(
x
),
at
.
constant
(
0.0
)),
(
x
**
2
*
consider_constant
(
x
),
2
*
x
**
2
),
]
for
expr
,
expr_grad
in
expressions_gradients
:
x
=
matrix
(
"x"
)
g
=
grad
(
expr
.
sum
(),
x
)
# gradient according to aesara
f
=
aesara
.
function
([
x
],
g
,
on_unused_input
=
"ignore"
)
# desired gradient
f2
=
aesara
.
function
([
x
],
expr_grad
,
on_unused_input
=
"ignore"
)
assert
np
.
allclose
(
f
(
a
),
f2
(
a
))
with
pytest
.
deprecated_call
():
expressions_gradients
=
[
(
x
*
consider_constant
(
x
),
x
),
(
x
*
consider_constant
(
exp
(
x
)),
exp
(
x
)),
(
consider_constant
(
x
),
at
.
constant
(
0.0
)),
(
x
**
2
*
consider_constant
(
x
),
2
*
x
**
2
),
]
for
expr
,
expr_grad
in
expressions_gradients
:
g
=
grad
(
expr
.
sum
(),
x
)
# gradient according to aesara
f
=
aesara
.
function
([
x
],
g
,
on_unused_input
=
"ignore"
)
# desired gradient
f2
=
aesara
.
function
([
x
],
expr_grad
,
on_unused_input
=
"ignore"
)
assert
np
.
allclose
(
f
(
a
),
f2
(
a
))
class
TestZeroGrad
:
class
TestZeroGrad
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论