Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2066065d
提交
2066065d
authored
5月 09, 2022
作者:
Ricardo
提交者:
Brandon T. Willard
7月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Deprecate addbroadcast and patternbroadcast in favor of specify_broadcastable
上级
eb2b9afb
隐藏空白字符变更
内嵌
并排
正在显示
16 个修改的文件
包含
142 行增加
和
187 行删除
+142
-187
op.py
aesara/scan/op.py
+1
-1
basic.py
aesara/sparse/basic.py
+6
-4
basic.py
aesara/tensor/basic.py
+14
-74
blas.py
aesara/tensor/blas.py
+7
-2
math.py
aesara/tensor/math.py
+12
-5
batchnorm.py
aesara/tensor/nnet/batchnorm.py
+17
-14
conv.py
aesara/tensor/nnet/conv.py
+10
-7
shape.py
aesara/tensor/shape.py
+36
-1
subtensor.py
aesara/tensor/subtensor.py
+4
-4
test_batchnorm.py
tests/tensor/nnet/test_batchnorm.py
+4
-3
test_basic.py
tests/tensor/test_basic.py
+1
-51
test_basic_opt.py
tests/tensor/test_basic_opt.py
+0
-12
test_math.py
tests/tensor/test_math.py
+5
-2
test_opt_uncanonicalize.py
tests/tensor/test_opt_uncanonicalize.py
+3
-3
test_shape.py
tests/tensor/test_shape.py
+19
-1
test_subtensor.py
tests/tensor/test_subtensor.py
+3
-3
没有找到文件。
aesara/scan/op.py
浏览文件 @
2066065d
...
...
@@ -171,7 +171,7 @@ def check_broadcast(v1, v2):
"dimension is fixed to 1 in the input, while it is still "
"variable in the output, or vice-verca. You have to make "
"them consistent, e.g. using aesara.tensor."
"{
patternbroadcast,unbroadcast,addbroadcast
}."
"{
unbroadcast, specify_broadcastable
}."
)
size
=
min
(
len
(
v1
.
broadcastable
),
len
(
v2
.
broadcastable
))
for
n
,
(
b1
,
b2
)
in
enumerate
(
...
...
aesara/sparse/basic.py
浏览文件 @
2066065d
...
...
@@ -45,7 +45,7 @@ from aesara.tensor.math import (
tanh
,
trunc
,
)
from
aesara.tensor.shape
import
shape
from
aesara.tensor.shape
import
shape
,
specify_broadcastable
from
aesara.tensor.type
import
TensorType
from
aesara.tensor.type
import
continuous_dtypes
as
tensor_continuous_dtypes
from
aesara.tensor.type
import
discrete_dtypes
as
tensor_discrete_dtypes
...
...
@@ -1136,7 +1136,9 @@ class SparseFromDense(Op):
(
x
,)
=
inputs
(
gz
,)
=
gout
gx
=
dense_from_sparse
(
gz
)
gx
=
at
.
patternbroadcast
(
gx
,
x
.
broadcastable
)
gx
=
specify_broadcastable
(
gx
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
x
.
type
.
broadcastable
)
if
b
)
)
return
(
gx
,)
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
...
...
@@ -1900,9 +1902,9 @@ class SpSum(Op):
else
:
ones
=
at
.
ones_like
(
x
)
if
self
.
axis
==
0
:
r
=
at
.
addbroadcast
(
gz
.
dimshuffle
(
"x"
,
0
),
0
)
*
ones
r
=
specify_broadcastable
(
gz
.
dimshuffle
(
"x"
,
0
),
0
)
*
ones
elif
self
.
axis
==
1
:
r
=
at
.
addbroadcast
(
gz
.
dimshuffle
(
0
,
"x"
),
1
)
*
ones
r
=
specify_broadcastable
(
gz
.
dimshuffle
(
0
,
"x"
),
1
)
*
ones
else
:
raise
ValueError
(
"Illegal value for self.axis."
)
r
=
SparseFromDense
(
o_format
)(
r
)
...
...
aesara/tensor/basic.py
浏览文件 @
2066065d
...
...
@@ -10,7 +10,7 @@ import warnings
from
collections.abc
import
Sequence
from
functools
import
partial
from
numbers
import
Number
from
typing
import
Dict
,
Iterable
,
Optional
,
Tuple
,
Union
from
typing
import
Dict
,
Optional
,
Tuple
,
Union
from
typing
import
cast
as
type_cast
import
numpy
as
np
...
...
@@ -49,6 +49,7 @@ from aesara.tensor.shape import (
shape_padleft
,
shape_padright
,
shape_tuple
,
specify_broadcastable
,
)
from
aesara.tensor.type
import
(
TensorType
,
...
...
@@ -622,8 +623,6 @@ class Rebroadcast(COp):
See Also
--------
unbroadcast <aesara.tensor.unbroadcast>
addbroadcast <aesara.tensor.addbroadcast>
patternbroadcast <aesara.tensor.patternbroadcast>
Notes
-----
...
...
@@ -2255,48 +2254,12 @@ class Split(COp):
)
def
addbroadcast
(
x
,
*
axes
):
"""
Make the input broadcastable in the specified axes.
For example, addbroadcast(x, 0) will make the first dimension of
x broadcastable. When performing the function, if the length of
x along that dimension is not 1, a ValueError will be raised.
We apply the opt here not to pollute the graph
Parameters
----------
x : tensor_like
Input aesara tensor.
axis : an int or an iterable object such as list or tuple of int values
The dimension along which the tensor x should be broadcastable.
If the length of x along these dimensions is not 1, a ValueError will
be raised.
Returns
-------
tensor
A aesara tensor, which is broadcastable along the specified dimensions.
"""
x
=
as_tensor_variable
(
x
)
if
isinstance
(
x
.
type
,
TensorType
)
and
not
any
(
s
is
None
for
s
in
x
.
type
.
shape
):
if
not
set
(
i
for
i
,
b
in
enumerate
(
x
.
broadcastable
)
if
b
)
.
issuperset
(
axes
):
raise
ValueError
(
f
"{x}'s fixed broadcast pattern does not match {axes}"
)
return
x
rval
=
Rebroadcast
(
*
[(
axis
,
True
)
for
axis
in
axes
])(
x
)
return
aesara
.
tensor
.
basic_opt
.
apply_rebroadcast_opt
(
rval
)
def
unbroadcast
(
x
,
*
axes
):
"""
Make the input impossible to broadcast in the specified axes.
For example,
add
broadcast(x, 0) will make the first dimension
of x broadcastable. When performing the function, if the length
For example,
un
broadcast(x, 0) will make the first dimension
of x
not
broadcastable. When performing the function, if the length
of x along that dimension is not 1, a ValueError will be raised.
We apply the opt here not to pollute the graph
...
...
@@ -2321,34 +2284,6 @@ def unbroadcast(x, *axes):
return
aesara
.
tensor
.
basic_opt
.
apply_rebroadcast_opt
(
rval
)
def
patternbroadcast
(
x
:
TensorVariable
,
broadcastable
:
Iterable
[
Union
[
bool
,
int
]]
)
->
TensorVariable
:
"""Make the input adopt a specific broadcasting pattern.
For example, ``patternbroadcast(x, (True, False))`` will make the first
dimension of `x` broadcastable and the second dimension not broadcastable,
so `x` will now be a row.
Parameters
----------
x
Input to re-broadcast.
broadcastable
Truthy values indicating whether or not a dimension should be
broadcastable or not. If the length of `x` along these dimensions is
not ``1``, a `ValueError` will be raised.
"""
x
=
as_tensor_variable
(
x
)
if
x
.
broadcastable
==
broadcastable
:
return
x
rval
=
Rebroadcast
(
*
[(
i
,
broadcastable
[
i
])
for
i
in
range
(
len
(
broadcastable
))])(
x
)
return
aesara
.
tensor
.
basic_opt
.
apply_rebroadcast_opt
(
rval
)
class
Join
(
COp
):
r"""
Concatenate multiple `TensorVariable`\s along some axis.
...
...
@@ -2599,7 +2534,12 @@ class Join(COp):
# broadcast. As the grad needs to keep the information,
# re-add it if needed.
split_gz
=
[
patternbroadcast
(
g
,
t
.
broadcastable
)
for
t
,
g
in
zip
(
tens
,
split_gz
)
g
if
g
.
type
.
broadcastable
==
t
.
type
.
broadcastable
else
specify_broadcastable
(
g
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
t
.
type
.
broadcastable
)
if
b
)
)
for
t
,
g
in
zip
(
tens
,
split_gz
)
]
rval
=
rval
+
split_gz
else
:
...
...
@@ -2822,7 +2762,7 @@ def stack(*tensors, **kwargs):
raise
ValueError
(
"No tensor arguments provided"
)
# If all tensors are scalars of the same type, call make_vector.
# It makes the graph simpler, by not adding DimShuffles and
Rebroadcast
s
# It makes the graph simpler, by not adding DimShuffles and
SpecifyShape
s
# This should be an optimization!
# Doing it here makes the graph less canonicalized
...
...
@@ -2979,7 +2919,9 @@ def flatten(x, ndim=1):
bcast_kept_dims
=
_x
.
broadcastable
[:
ndim
-
1
]
bcast_new_dim
=
builtins
.
all
(
_x
.
broadcastable
[
ndim
-
1
:])
broadcastable
=
bcast_kept_dims
+
(
bcast_new_dim
,)
x_reshaped
=
addbroadcast
(
x_reshaped
,
*
[
i
for
i
in
range
(
ndim
)
if
broadcastable
[
i
]])
x_reshaped
=
specify_broadcastable
(
x_reshaped
,
*
[
i
for
i
in
range
(
ndim
)
if
broadcastable
[
i
]]
)
return
x_reshaped
...
...
@@ -4253,9 +4195,7 @@ __all__ = [
"stack"
,
"roll"
,
"join"
,
"patternbroadcast"
,
"unbroadcast"
,
"addbroadcast"
,
"split"
,
"transpose"
,
"extract_constant"
,
...
...
aesara/tensor/blas.py
浏览文件 @
2066065d
...
...
@@ -165,6 +165,7 @@ from aesara.tensor.blas_headers import blas_header_text, blas_header_version
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.math
import
Dot
,
add
,
mul
,
neg
,
sub
from
aesara.tensor.shape
import
specify_broadcastable
from
aesara.tensor.type
import
(
DenseTensorType
,
integer_dtypes
,
...
...
@@ -2552,9 +2553,13 @@ class BatchedDot(COp):
# The code above doesn't always return the right broadcast pattern.
# This causes problems down the road. See gh-1461.
if
xgrad
.
broadcastable
!=
x
.
broadcastable
:
xgrad
=
at
.
patternbroadcast
(
xgrad
,
x
.
broadcastable
)
xgrad
=
specify_broadcastable
(
xgrad
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
x
.
type
.
broadcastable
)
if
b
)
)
if
ygrad
.
broadcastable
!=
y
.
broadcastable
:
ygrad
=
at
.
patternbroadcast
(
ygrad
,
y
.
broadcastable
)
ygrad
=
specify_broadcastable
(
ygrad
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
y
.
type
.
broadcastable
)
if
b
)
)
return
xgrad
,
ygrad
...
...
aesara/tensor/math.py
浏览文件 @
2066065d
...
...
@@ -21,7 +21,6 @@ from aesara.tensor.basic import (
cast
,
concatenate
,
constant
,
patternbroadcast
,
stack
,
switch
,
)
...
...
@@ -32,7 +31,7 @@ from aesara.tensor.elemwise import (
Elemwise
,
scalar_elemwise
,
)
from
aesara.tensor.shape
import
shape
from
aesara.tensor.shape
import
shape
,
specify_broadcastable
from
aesara.tensor.type
import
(
DenseTensorType
,
complex_dtypes
,
...
...
@@ -1961,9 +1960,13 @@ class Dot(Op):
# The code above doesn't always return the right broadcast pattern.
# This causes problems down the road. See gh-1461.
if
xgrad
.
broadcastable
!=
x
.
broadcastable
:
xgrad
=
patternbroadcast
(
xgrad
,
x
.
broadcastable
)
xgrad
=
specify_broadcastable
(
xgrad
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
x
.
type
.
broadcastable
)
if
b
)
)
if
ygrad
.
broadcastable
!=
y
.
broadcastable
:
ygrad
=
patternbroadcast
(
ygrad
,
y
.
broadcastable
)
ygrad
=
specify_broadcastable
(
ygrad
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
y
.
type
.
broadcastable
)
if
b
)
)
rval
=
xgrad
,
ygrad
...
...
@@ -2178,7 +2181,11 @@ def _tensordot_as_dot(a, b, axes, dot, batched):
out
=
out_reshaped
.
reshape
(
outshape
,
outndim
)
# Make sure the broadcastable pattern of the result is correct,
# since some shape information can be lost in the reshapes.
return
patternbroadcast
(
out
,
outbcast
)
if
out
.
type
.
broadcastable
!=
outbcast
:
out
=
specify_broadcastable
(
out
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
outbcast
)
if
b
)
)
return
out
# if 'axes' is a list, transpose a and b such that the summed axes of a
# are last and the summed axes of b are first.
...
...
aesara/tensor/nnet/batchnorm.py
浏览文件 @
2066065d
...
...
@@ -12,6 +12,7 @@ from aesara.tensor.basic_opt import register_specialize_device
from
aesara.tensor.elemwise
import
Elemwise
from
aesara.tensor.math
import
mean
,
prod
,
reciprocal
,
sqrt
from
aesara.tensor.math
import
sum
as
at_sum
from
aesara.tensor.shape
import
specify_broadcastable
from
aesara.tensor.type
import
TensorType
...
...
@@ -241,8 +242,8 @@ def batch_normalization_train(
gamma
=
gamma
.
dimshuffle
(
params_dimshuffle_pattern
)
beta
=
beta
.
dimshuffle
(
params_dimshuffle_pattern
)
else
:
gamma
=
at
.
addbroadcast
(
gamma
,
*
axes
)
beta
=
at
.
addbroadcast
(
beta
,
*
axes
)
gamma
=
specify_broadcastable
(
gamma
,
*
axes
)
beta
=
specify_broadcastable
(
beta
,
*
axes
)
batchnorm_op
=
AbstractBatchNormTrain
(
axes
=
axes
)
...
...
@@ -253,8 +254,8 @@ def batch_normalization_train(
running_mean
=
running_mean
.
dimshuffle
(
params_dimshuffle_pattern
)
running_var
=
running_var
.
dimshuffle
(
params_dimshuffle_pattern
)
else
:
running_mean
=
at
.
addbroadcast
(
running_mean
,
*
axes
)
running_var
=
at
.
addbroadcast
(
running_var
,
*
axes
)
running_mean
=
specify_broadcastable
(
running_mean
,
*
axes
)
running_var
=
specify_broadcastable
(
running_var
,
*
axes
)
out
,
mean
,
invstd
,
new_running_mean
,
new_running_var
=
batchnorm_op
(
inputs
,
gamma
,
...
...
@@ -265,12 +266,14 @@ def batch_normalization_train(
running_var
=
running_var
,
)
if
new_running_mean
.
broadcastable
!=
running_mean
.
broadcastable
:
new_running_mean
=
at
.
patternbroadcast
(
new_running_mean
,
running_mean
.
broadcastable
new_running_mean
=
specify_broadcastable
(
new_running_mean
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
running_mean
.
type
.
broadcastable
)
if
b
),
)
if
new_running_var
.
broadcastable
!=
running_var
.
broadcastable
:
new_running_var
=
at
.
patternbroadcast
(
new_running_var
,
running_var
.
broadcastable
new_running_var
=
specify_broadcastable
(
new_running_var
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
running_var
.
type
.
broadcastable
)
if
b
),
)
results
=
(
out
,
mean
,
invstd
,
new_running_mean
,
new_running_var
)
else
:
...
...
@@ -331,7 +334,7 @@ def batch_normalization_test(
axes = (0,)
# for spatial normalization
axes = (0,) + tuple(range(2, inputs.ndim))
gamma, beta, mean, var = (at.
addbroadcast
(t, *axes)
gamma, beta, mean, var = (at.
specify_broadcastable
(t, *axes)
for t in (gamma, beta, mean, var))
out = (inputs - mean) * gamma / at.sqrt(var + epsilon) + beta
"""
...
...
@@ -377,10 +380,10 @@ def batch_normalization_test(
mean
=
mean
.
dimshuffle
(
params_dimshuffle_pattern
)
var
=
var
.
dimshuffle
(
params_dimshuffle_pattern
)
else
:
gamma
=
at
.
addbroadcast
(
gamma
,
*
axes
)
beta
=
at
.
addbroadcast
(
beta
,
*
axes
)
mean
=
at
.
addbroadcast
(
mean
,
*
axes
)
var
=
at
.
addbroadcast
(
var
,
*
axes
)
gamma
=
specify_broadcastable
(
gamma
,
*
axes
)
beta
=
specify_broadcastable
(
beta
,
*
axes
)
mean
=
specify_broadcastable
(
mean
,
*
axes
)
var
=
specify_broadcastable
(
var
,
*
axes
)
batchnorm_op
=
AbstractBatchNormInference
(
axes
=
axes
)
return
batchnorm_op
(
inputs
,
gamma
,
beta
,
mean
,
var
,
epsilon
=
epsilon
)
...
...
@@ -609,7 +612,7 @@ class AbstractBatchNormInference(Op):
)
scale
,
bias
,
est_mean
,
est_var
=
(
at
.
addbroadcast
(
t
,
*
axes
)
for
t
in
(
scale
,
bias
,
est_mean
,
est_var
)
specify_broadcastable
(
t
,
*
axes
)
for
t
in
(
scale
,
bias
,
est_mean
,
est_var
)
)
# define helper expressions
...
...
aesara/tensor/nnet/conv.py
浏览文件 @
2066065d
...
...
@@ -26,13 +26,10 @@ import aesara
from
aesara.graph.basic
import
Apply
from
aesara.link.c.op
import
OpenMPOp
from
aesara.tensor
import
blas
from
aesara.tensor.basic
import
(
as_tensor_variable
,
get_scalar_constant_value
,
patternbroadcast
,
)
from
aesara.tensor.basic
import
as_tensor_variable
,
get_scalar_constant_value
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.nnet.abstract_conv
import
get_conv_output_shape
,
get_conv_shape_1axis
from
aesara.tensor.shape
import
specify_broadcastable
from
aesara.tensor.type
import
discrete_dtypes
,
tensor
...
...
@@ -1103,8 +1100,14 @@ class ConvOp(OpenMPOp):
# din and dw should have the same broadcasting pattern as the
# parameters they are the gradient of (resp. inputs and kerns).
din
=
patternbroadcast
(
din
,
inputs
.
broadcastable
)
dw
=
patternbroadcast
(
dw
,
kerns
.
broadcastable
)
if
din
.
type
.
broadcastable
!=
inputs
.
type
.
broadcastable
:
din
=
specify_broadcastable
(
din
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
inputs
.
type
.
broadcastable
)
if
b
)
)
if
dw
.
type
.
broadcastable
!=
kerns
.
type
.
broadcastable
:
dw
=
specify_broadcastable
(
dw
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
kerns
.
type
.
broadcastable
)
if
b
)
)
return
[
din
,
dw
]
def
c_headers
(
self
,
**
kwargs
):
...
...
aesara/tensor/shape.py
浏览文件 @
2066065d
...
...
@@ -12,7 +12,7 @@ from aesara.link.c.op import COp
from
aesara.link.c.params_type
import
ParamsType
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.scalar
import
int32
from
aesara.tensor
import
_get_vector_length
from
aesara.tensor
import
_get_vector_length
,
as_tensor_variable
from
aesara.tensor
import
basic
as
at
from
aesara.tensor
import
get_vector_length
from
aesara.tensor.exceptions
import
NotScalarConstantError
...
...
@@ -891,3 +891,38 @@ register_shape_i_c_code(
"""
,
version
=
3
,
)
def specify_broadcastable(x, *axes):
    """Specify the input as being broadcastable in the specified axes.

    For example, ``specify_broadcastable(x, 0)`` gives the first dimension of
    `x` a static length of 1, which makes it broadcastable.  Negative axes
    count from the last dimension.

    Parameters
    ----------
    x : tensor_like
        Input Aesara tensor.
    *axes : int
        The dimensions along which `x` should be broadcastable, passed as
        separate positional arguments.  When no axis is given, `x` is
        returned right after conversion to a tensor variable.

    Returns
    -------
    tensor
        An Aesara tensor that is broadcastable along the specified dimensions.

    Raises
    ------
    ValueError
        If any axis refers to a dimension that `x` does not have.

    Notes
    -----
    The length-1 requirement is handed to `specify_shape`; this function does
    not check run-time lengths itself.  NOTE(review): this docstring used to
    promise a `ValueError` when the actual length is not 1 -- confirm which
    exception `specify_shape` raises in that case.

    """
    x = as_tensor_variable(x)

    if not axes:
        return x

    ndim = x.type.ndim
    # Reject out-of-range axes on both sides.  Checking only ``max(axes)``
    # let negative axes through, and they were then silently ignored because
    # they never matched a non-negative dimension index below.
    if any(axis >= ndim or axis < -ndim for axis in axes):
        raise ValueError("Trying to specify broadcastable of non-existent dimension")

    # Normalize negative axes so that e.g. -1 targets the last dimension.
    # ``ndim`` cannot be 0 here: any axis would have failed the check above.
    target_axes = {axis % ndim for axis in axes}

    # Only the requested dimensions are constrained; ``None`` leaves the
    # others unspecified.
    shape_info = [1 if i in target_axes else None for i in range(len(x.type.shape))]
    return specify_shape(x, shape_info)
aesara/tensor/subtensor.py
浏览文件 @
2066065d
...
...
@@ -20,7 +20,7 @@ from aesara.misc.safe_asarray import _asarray
from
aesara.printing
import
Printer
,
pprint
,
set_precedence
from
aesara.scalar.basic
import
ScalarConstant
from
aesara.tensor
import
_get_vector_length
,
as_tensor_variable
,
get_vector_length
from
aesara.tensor.basic
import
a
ddbroadcast
,
a
lloc
,
get_scalar_constant_value
from
aesara.tensor.basic
import
alloc
,
get_scalar_constant_value
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.exceptions
import
(
AdvancedIndexingError
,
...
...
@@ -28,7 +28,7 @@ from aesara.tensor.exceptions import (
ShapeError
,
)
from
aesara.tensor.math
import
clip
from
aesara.tensor.shape
import
Reshape
from
aesara.tensor.shape
import
Reshape
,
specify_broadcastable
from
aesara.tensor.type
import
(
TensorType
,
bscalar
,
...
...
@@ -1322,8 +1322,8 @@ def inc_subtensor(
# It is acceptable to try to increment a subtensor with a
# broadcastable dim with a tensor that is not broadcastable
# on that dimension. However, its length must then be 1.
# We insert a
Rebroadcast
Op to make sure it is the case.
y
=
addbroadcast
(
y
,
dim
)
# We insert a
SpecifyShape
Op to make sure it is the case.
y
=
specify_broadcastable
(
y
,
dim
)
if
not
x
.
owner
:
raise
TypeError
(
"x must be the result of a subtensor operation"
)
...
...
tests/tensor/nnet/test_batchnorm.py
浏览文件 @
2066065d
...
...
@@ -8,6 +8,7 @@ import aesara.tensor as at
from
aesara.configdefaults
import
config
from
aesara.tensor.math
import
sum
as
at_sum
from
aesara.tensor.nnet
import
batchnorm
from
aesara.tensor.shape
import
specify_broadcastable
from
aesara.tensor.type
import
(
TensorType
,
matrix
,
...
...
@@ -219,8 +220,8 @@ def test_batch_normalization_train():
x_mean2
=
x
.
mean
(
axis
=
axes2
,
keepdims
=
True
)
x_var2
=
x
.
var
(
axis
=
axes2
,
keepdims
=
True
)
x_invstd2
=
at
.
reciprocal
(
at
.
sqrt
(
x_var2
+
eps
))
scale2
=
at
.
addbroadcast
(
scale
,
*
axes2
)
bias2
=
at
.
addbroadcast
(
bias
,
*
axes2
)
scale2
=
specify_broadcastable
(
scale
,
*
axes2
)
bias2
=
specify_broadcastable
(
bias
,
*
axes2
)
out2
=
(
x
-
x_mean2
)
*
(
scale2
*
x_invstd2
)
+
bias2
m
=
at
.
cast
(
at
.
prod
(
x
.
shape
)
/
at
.
prod
(
scale
.
shape
),
aesara
.
config
.
floatX
)
out_running_mean2
=
(
...
...
@@ -597,7 +598,7 @@ def test_batch_normalization_test():
else
:
axes2
=
axes
scale2
,
bias2
,
mean2
,
var2
=
(
at
.
addbroadcast
(
t
,
*
axes2
)
for
t
in
(
scale
,
bias
,
mean
,
var
)
specify_broadcastable
(
t
,
*
axes2
)
for
t
in
(
scale
,
bias
,
mean
,
var
)
)
out2
=
(
x
-
mean2
)
*
(
scale2
/
at
.
sqrt
(
var2
+
eps
))
+
bias2
# backward pass
...
...
tests/tensor/test_basic.py
浏览文件 @
2066065d
...
...
@@ -39,7 +39,6 @@ from aesara.tensor.basic import (
Split
,
TensorFromScalar
,
Tri
,
addbroadcast
,
alloc
,
arange
,
as_tensor_variable
,
...
...
@@ -69,7 +68,6 @@ from aesara.tensor.basic import (
nonzero_values
,
ogrid
,
ones_like
,
patternbroadcast
,
permute_row_elements
,
roll
,
scalar_from_tensor
,
...
...
@@ -3226,20 +3224,7 @@ class TestLongTensor:
class
TestBroadcast
:
def
test_addbroadcast_validation
(
self
):
x
=
as_tensor_variable
(
np
.
zeros
((
2
,
3
)))
with
pytest
.
raises
(
ValueError
,
match
=
".*pattern does not.*"
):
addbroadcast
(
x
,
4
)
def
test_broadcast_bigdim
(
self
):
def
f
():
x
=
matrix
()
addbroadcast
(
x
,
2
)
with
pytest
.
raises
(
ValueError
):
f
()
def
test_unbroadcast_addbroadcast
(
self
):
def
test_unbroadcast
(
self
):
# Test that the unbroadcast function doesn't insert unneeded broadcasts
# and fuses consecutive Rebroadcast ops
...
...
@@ -3249,26 +3234,12 @@ class TestBroadcast:
assert
unbroadcast
(
x
,
1
,
0
)
is
x
assert
unbroadcast
(
x
,
0
,
1
)
is
x
assert
addbroadcast
(
x
,
0
)
is
not
x
assert
addbroadcast
(
x
,
1
)
is
not
x
assert
addbroadcast
(
x
,
1
,
0
)
.
owner
.
inputs
[
0
]
is
x
assert
unbroadcast
(
addbroadcast
(
x
,
0
),
0
)
is
x
assert
addbroadcast
(
unbroadcast
(
x
,
0
),
0
)
is
not
x
x
=
row
()
assert
unbroadcast
(
x
,
0
)
is
not
x
assert
unbroadcast
(
x
,
1
)
is
x
assert
unbroadcast
(
x
,
1
,
0
)
is
not
x
assert
unbroadcast
(
x
,
0
,
1
)
is
not
x
assert
addbroadcast
(
x
,
0
)
is
x
assert
addbroadcast
(
x
,
1
)
.
owner
.
inputs
[
0
]
is
x
assert
addbroadcast
(
x
,
1
,
0
)
.
owner
.
inputs
[
0
]
is
x
assert
addbroadcast
(
x
,
0
,
1
)
.
owner
.
inputs
[
0
]
is
x
assert
unbroadcast
(
addbroadcast
(
x
,
1
),
1
)
is
x
assert
addbroadcast
(
unbroadcast
(
x
,
1
),
1
)
is
not
x
# The first unbroadcast removes the broadcast, so the second
# should not add another one
assert
unbroadcast
(
unbroadcast
(
x
,
0
),
0
)
.
owner
.
inputs
[
0
]
is
x
...
...
@@ -3276,29 +3247,8 @@ class TestBroadcast:
# Test that consecutive Rebroadcast op are fused
x
=
TensorType
(
dtype
=
"float64"
,
shape
=
(
True
,
True
))()
assert
unbroadcast
(
unbroadcast
(
x
,
1
),
0
)
.
owner
.
inputs
[
0
]
is
x
assert
addbroadcast
(
unbroadcast
(
x
,
1
),
0
)
.
owner
.
inputs
[
0
]
is
x
assert
addbroadcast
(
unbroadcast
(
x
,
0
),
0
)
is
x
def
test_patternbroadcast
(
self
):
# Test that patternbroadcast with an empty broadcasting pattern works
x
=
scalar
(
"x"
)
m
=
matrix
(
"m"
)
s
=
patternbroadcast
(
m
,
x
.
broadcastable
)
assert
s
is
m
x2
=
patternbroadcast
(
x
,
x
.
broadcastable
)
assert
x2
is
x
def
test_infer_shape
(
self
):
x
=
matrix
()
y
=
addbroadcast
(
x
,
0
)
f
=
aesara
.
function
([
x
],
y
.
shape
)
assert
(
f
(
np
.
zeros
((
1
,
5
),
dtype
=
config
.
floatX
))
==
[
1
,
5
])
.
all
()
topo
=
f
.
maker
.
fgraph
.
toposort
()
if
config
.
mode
!=
"FAST_COMPILE"
:
assert
len
(
topo
)
==
2
assert
isinstance
(
topo
[
0
]
.
op
,
Shape_i
)
assert
isinstance
(
topo
[
1
]
.
op
,
MakeVector
)
x
=
matrix
()
y
=
unbroadcast
(
x
,
0
)
f
=
aesara
.
function
([
x
],
y
.
shape
)
...
...
tests/tensor/test_basic_opt.py
浏览文件 @
2066065d
...
...
@@ -1911,18 +1911,6 @@ class TestRebroadcast:
assert
check_stack_trace
(
f
,
ops_to_check
=
"all"
)
def
test_rebroadcast_rebroadcast
(
self
):
mode
=
get_default_mode
()
.
including
(
"canonicalize"
)
m
=
matrix
()
s
=
at
.
addbroadcast
(
m
,
0
,
1
)
v
=
at
.
unbroadcast
(
s
,
1
)
f
=
function
([
m
],
v
,
mode
=
mode
)
f
([[
76
]])
e
=
f
.
maker
.
fgraph
.
toposort
()
rebroadcast_nodes
=
[
n
for
n
in
e
if
isinstance
(
n
.
op
,
Rebroadcast
)]
assert
len
(
rebroadcast_nodes
)
==
1
assert
rebroadcast_nodes
[
0
]
.
op
.
axis
==
{
0
:
True
}
class
TestUselessElemwise
:
def
setup_method
(
self
):
...
...
tests/tensor/test_math.py
浏览文件 @
2066065d
...
...
@@ -1918,6 +1918,9 @@ class TestDot:
# These examples should all work. All dimensions of all results have
# size 1.
#
def
is_super_shape
(
var1
,
var2
):
# Check that var1.type is a superset of var2.type, ignoring dtype
return
var1
.
type
.
is_super
(
var2
.
type
.
clone
(
dtype
=
var1
.
type
.
dtype
))
for
dtype0
in
(
"float32"
,
"float64"
,
"complex64"
):
for
dtype1
in
(
"float32"
,
"complex64"
,
"complex128"
):
...
...
@@ -1944,9 +1947,9 @@ class TestDot:
if
dtype0
.
startswith
(
"float"
)
and
dtype1
.
startswith
(
"float"
):
g
=
grad
(
z
.
sum
(),
x
)
assert
g
.
broadcastable
==
x
.
broadcastable
assert
is_super_shape
(
x
,
g
)
g
=
grad
(
z
.
sum
(),
y
)
assert
g
.
broadcastable
==
y
.
broadcastable
assert
is_super_shape
(
y
,
g
)
class
TestTensordot
:
...
...
tests/tensor/test_opt_uncanonicalize.py
浏览文件 @
2066065d
...
...
@@ -19,7 +19,7 @@ from aesara.tensor.opt_uncanonicalize import (
local_dimshuffle_subtensor
,
local_reshape_dimshuffle
,
)
from
aesara.tensor.shape
import
reshape
from
aesara.tensor.shape
import
reshape
,
specify_shape
from
aesara.tensor.type
import
dtensor4
,
iscalar
,
matrix
,
tensor
,
vector
from
tests.link.test_link
import
make_function
...
...
@@ -179,7 +179,7 @@ def test_local_dimshuffle_subtensor():
dimshuffle_subtensor
=
out2in
(
local_dimshuffle_subtensor
)
x
=
dtensor4
(
"x"
)
x
=
at
.
patternbroadcast
(
x
,
(
False
,
True
,
False
,
Fals
e
))
x
=
specify_shape
(
x
,
(
None
,
1
,
None
,
Non
e
))
i
=
iscalar
(
"i"
)
out
=
x
[:,
:,
10
:
30
,
::
i
]
.
dimshuffle
(
0
,
2
,
3
)
...
...
@@ -213,7 +213,7 @@ def test_local_dimshuffle_subtensor():
# Test a corner case that used to trigger a bug in Aesara.
x
=
dtensor4
(
"x"
)
x
=
at
.
patternbroadcast
(
x
,
(
False
,
True
,
False
,
Fals
e
))
x
=
specify_shape
(
x
,
(
None
,
1
,
None
,
Non
e
))
assert
x
[:,
:,
0
:
3
,
::
-
1
]
.
dimshuffle
(
0
,
2
,
3
)
.
eval
(
{
x
:
np
.
ones
((
5
,
1
,
6
,
7
))}
...
...
tests/tensor/test_shape.py
浏览文件 @
2066065d
...
...
@@ -9,7 +9,7 @@ from aesara.graph.basic import Variable
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.type
import
Type
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.tensor
import
as_tensor_variable
,
get_vector_length
from
aesara.tensor
import
as_tensor_variable
,
get_vector_length
,
row
from
aesara.tensor.basic
import
MakeVector
,
constant
from
aesara.tensor.basic_opt
import
ShapeFeature
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
...
...
@@ -21,6 +21,7 @@ from aesara.tensor.shape import (
reshape
,
shape
,
shape_i
,
specify_broadcastable
,
specify_shape
,
)
from
aesara.tensor.subtensor
import
Subtensor
...
...
@@ -518,6 +519,23 @@ class TestSpecifyShape(utt.InferShapeTester):
assert
isinstance
(
z_grad
.
owner
.
op
,
SpecifyShape
)
class TestSpecifyBroadcastable:
    """Tests for `specify_broadcastable`: resulting static shapes and axis validation."""

    def test_basic(self):
        """Requested axes get a static length of 1 in the output type."""
        x = matrix()
        assert specify_broadcastable(x, 0).type.shape == (1, None)
        assert specify_broadcastable(x, 1).type.shape == (None, 1)
        assert specify_broadcastable(x, 0, 1).type.shape == (1, 1)

        x = row()
        # For a `row` input, specifying axis 0 must hand back the very same
        # variable, while specifying axis 1 must produce a new one.
        assert specify_broadcastable(x, 0) is x
        assert specify_broadcastable(x, 1) is not x

    def test_validation(self):
        """An axis beyond the input's dimensions is rejected with a `ValueError`."""
        x = matrix()
        # `match` is a regular expression: the former trailing "of*" meant
        # "o followed by any number of f", not a literal prefix.  Anchor on
        # the literal start of the message instead.
        with pytest.raises(ValueError, match="^Trying to specify broadcastable of"):
            specify_broadcastable(x, 2)
class
TestRopLop
(
RopLopChecker
):
def
test_shape
(
self
):
self
.
check_nondiff_rop
(
self
.
x
.
shape
[
0
])
...
...
tests/tensor/test_subtensor.py
浏览文件 @
2066065d
...
...
@@ -1501,11 +1501,11 @@ class TestIncSubtensor:
# This one should work
f
(
rng_randX
(
3
,
1
),
rng_randX
(
1
))
# These ones should not
with
pytest
.
raises
(
Value
Error
):
with
pytest
.
raises
(
Assertion
Error
):
f
(
rng_randX
(
3
,
1
),
rng_randX
(
2
))
with
pytest
.
raises
(
Value
Error
):
with
pytest
.
raises
(
Assertion
Error
):
f
(
rng_randX
(
3
,
1
),
rng_randX
(
3
))
with
pytest
.
raises
(
Value
Error
):
with
pytest
.
raises
(
Assertion
Error
):
f
(
rng_randX
(
3
,
1
),
rng_randX
(
0
))
def
test_simple_3d
(
self
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论