Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2066065d
提交
2066065d
authored
5月 09, 2022
作者:
Ricardo
提交者:
Brandon T. Willard
7月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Deprecate addbroadcast and patternbroadcast in favor of specify_broadcastable
上级
eb2b9afb
隐藏空白字符变更
内嵌
并排
正在显示
16 个修改的文件
包含
142 行增加
和
187 行删除
+142
-187
op.py
aesara/scan/op.py
+1
-1
basic.py
aesara/sparse/basic.py
+6
-4
basic.py
aesara/tensor/basic.py
+14
-74
blas.py
aesara/tensor/blas.py
+7
-2
math.py
aesara/tensor/math.py
+12
-5
batchnorm.py
aesara/tensor/nnet/batchnorm.py
+17
-14
conv.py
aesara/tensor/nnet/conv.py
+10
-7
shape.py
aesara/tensor/shape.py
+36
-1
subtensor.py
aesara/tensor/subtensor.py
+4
-4
test_batchnorm.py
tests/tensor/nnet/test_batchnorm.py
+4
-3
test_basic.py
tests/tensor/test_basic.py
+1
-51
test_basic_opt.py
tests/tensor/test_basic_opt.py
+0
-12
test_math.py
tests/tensor/test_math.py
+5
-2
test_opt_uncanonicalize.py
tests/tensor/test_opt_uncanonicalize.py
+3
-3
test_shape.py
tests/tensor/test_shape.py
+19
-1
test_subtensor.py
tests/tensor/test_subtensor.py
+3
-3
没有找到文件。
aesara/scan/op.py
浏览文件 @
2066065d
...
@@ -171,7 +171,7 @@ def check_broadcast(v1, v2):
...
@@ -171,7 +171,7 @@ def check_broadcast(v1, v2):
"dimension is fixed to 1 in the input, while it is still "
"dimension is fixed to 1 in the input, while it is still "
"variable in the output, or vice-verca. You have to make "
"variable in the output, or vice-verca. You have to make "
"them consistent, e.g. using aesara.tensor."
"them consistent, e.g. using aesara.tensor."
"{
patternbroadcast,unbroadcast,addbroadcast
}."
"{
unbroadcast, specify_broadcastable
}."
)
)
size
=
min
(
len
(
v1
.
broadcastable
),
len
(
v2
.
broadcastable
))
size
=
min
(
len
(
v1
.
broadcastable
),
len
(
v2
.
broadcastable
))
for
n
,
(
b1
,
b2
)
in
enumerate
(
for
n
,
(
b1
,
b2
)
in
enumerate
(
...
...
aesara/sparse/basic.py
浏览文件 @
2066065d
...
@@ -45,7 +45,7 @@ from aesara.tensor.math import (
...
@@ -45,7 +45,7 @@ from aesara.tensor.math import (
tanh
,
tanh
,
trunc
,
trunc
,
)
)
from
aesara.tensor.shape
import
shape
from
aesara.tensor.shape
import
shape
,
specify_broadcastable
from
aesara.tensor.type
import
TensorType
from
aesara.tensor.type
import
TensorType
from
aesara.tensor.type
import
continuous_dtypes
as
tensor_continuous_dtypes
from
aesara.tensor.type
import
continuous_dtypes
as
tensor_continuous_dtypes
from
aesara.tensor.type
import
discrete_dtypes
as
tensor_discrete_dtypes
from
aesara.tensor.type
import
discrete_dtypes
as
tensor_discrete_dtypes
...
@@ -1136,7 +1136,9 @@ class SparseFromDense(Op):
...
@@ -1136,7 +1136,9 @@ class SparseFromDense(Op):
(
x
,)
=
inputs
(
x
,)
=
inputs
(
gz
,)
=
gout
(
gz
,)
=
gout
gx
=
dense_from_sparse
(
gz
)
gx
=
dense_from_sparse
(
gz
)
gx
=
at
.
patternbroadcast
(
gx
,
x
.
broadcastable
)
gx
=
specify_broadcastable
(
gx
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
x
.
type
.
broadcastable
)
if
b
)
)
return
(
gx
,)
return
(
gx
,)
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
...
@@ -1900,9 +1902,9 @@ class SpSum(Op):
...
@@ -1900,9 +1902,9 @@ class SpSum(Op):
else
:
else
:
ones
=
at
.
ones_like
(
x
)
ones
=
at
.
ones_like
(
x
)
if
self
.
axis
==
0
:
if
self
.
axis
==
0
:
r
=
at
.
addbroadcast
(
gz
.
dimshuffle
(
"x"
,
0
),
0
)
*
ones
r
=
specify_broadcastable
(
gz
.
dimshuffle
(
"x"
,
0
),
0
)
*
ones
elif
self
.
axis
==
1
:
elif
self
.
axis
==
1
:
r
=
at
.
addbroadcast
(
gz
.
dimshuffle
(
0
,
"x"
),
1
)
*
ones
r
=
specify_broadcastable
(
gz
.
dimshuffle
(
0
,
"x"
),
1
)
*
ones
else
:
else
:
raise
ValueError
(
"Illegal value for self.axis."
)
raise
ValueError
(
"Illegal value for self.axis."
)
r
=
SparseFromDense
(
o_format
)(
r
)
r
=
SparseFromDense
(
o_format
)(
r
)
...
...
aesara/tensor/basic.py
浏览文件 @
2066065d
...
@@ -10,7 +10,7 @@ import warnings
...
@@ -10,7 +10,7 @@ import warnings
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
functools
import
partial
from
functools
import
partial
from
numbers
import
Number
from
numbers
import
Number
from
typing
import
Dict
,
Iterable
,
Optional
,
Tuple
,
Union
from
typing
import
Dict
,
Optional
,
Tuple
,
Union
from
typing
import
cast
as
type_cast
from
typing
import
cast
as
type_cast
import
numpy
as
np
import
numpy
as
np
...
@@ -49,6 +49,7 @@ from aesara.tensor.shape import (
...
@@ -49,6 +49,7 @@ from aesara.tensor.shape import (
shape_padleft
,
shape_padleft
,
shape_padright
,
shape_padright
,
shape_tuple
,
shape_tuple
,
specify_broadcastable
,
)
)
from
aesara.tensor.type
import
(
from
aesara.tensor.type
import
(
TensorType
,
TensorType
,
...
@@ -622,8 +623,6 @@ class Rebroadcast(COp):
...
@@ -622,8 +623,6 @@ class Rebroadcast(COp):
See Also
See Also
--------
--------
unbroadcast <aesara.tensor.unbroadcast>
unbroadcast <aesara.tensor.unbroadcast>
addbroadcast <aesara.tensor.addbroadcast>
patternbroadcast <aesara.tensor.patternbroadcast>
Notes
Notes
-----
-----
...
@@ -2255,48 +2254,12 @@ class Split(COp):
...
@@ -2255,48 +2254,12 @@ class Split(COp):
)
)
def
addbroadcast
(
x
,
*
axes
):
"""
Make the input broadcastable in the specified axes.
For example, addbroadcast(x, 0) will make the first dimension of
x broadcastable. When performing the function, if the length of
x along that dimension is not 1, a ValueError will be raised.
We apply the opt here not to pollute the graph
Parameters
----------
x : tensor_like
Input aesara tensor.
axis : an int or an iterable object such as list or tuple of int values
The dimension along which the tensor x should be broadcastable.
If the length of x along these dimensions is not 1, a ValueError will
be raised.
Returns
-------
tensor
A aesara tensor, which is broadcastable along the specified dimensions.
"""
x
=
as_tensor_variable
(
x
)
if
isinstance
(
x
.
type
,
TensorType
)
and
not
any
(
s
is
None
for
s
in
x
.
type
.
shape
):
if
not
set
(
i
for
i
,
b
in
enumerate
(
x
.
broadcastable
)
if
b
)
.
issuperset
(
axes
):
raise
ValueError
(
f
"{x}'s fixed broadcast pattern does not match {axes}"
)
return
x
rval
=
Rebroadcast
(
*
[(
axis
,
True
)
for
axis
in
axes
])(
x
)
return
aesara
.
tensor
.
basic_opt
.
apply_rebroadcast_opt
(
rval
)
def
unbroadcast
(
x
,
*
axes
):
def
unbroadcast
(
x
,
*
axes
):
"""
"""
Make the input impossible to broadcast in the specified axes.
Make the input impossible to broadcast in the specified axes.
For example,
add
broadcast(x, 0) will make the first dimension
For example,
un
broadcast(x, 0) will make the first dimension
of x broadcastable. When performing the function, if the length
of x
not
broadcastable. When performing the function, if the length
of x along that dimension is not 1, a ValueError will be raised.
of x along that dimension is not 1, a ValueError will be raised.
We apply the opt here not to pollute the graph
We apply the opt here not to pollute the graph
...
@@ -2321,34 +2284,6 @@ def unbroadcast(x, *axes):
...
@@ -2321,34 +2284,6 @@ def unbroadcast(x, *axes):
return
aesara
.
tensor
.
basic_opt
.
apply_rebroadcast_opt
(
rval
)
return
aesara
.
tensor
.
basic_opt
.
apply_rebroadcast_opt
(
rval
)
def
patternbroadcast
(
x
:
TensorVariable
,
broadcastable
:
Iterable
[
Union
[
bool
,
int
]]
)
->
TensorVariable
:
"""Make the input adopt a specific broadcasting pattern.
For example, ``patternbroadcast(x, (True, False))`` will make the first
dimension of `x` broadcastable and the second dimension not broadcastable,
so `x` will now be a row.
Parameters
----------
x
Input to re-broadcast.
broadcastable
Truthy values indicating whether or not a dimension should be
broadcastable or not. If the length of `x` along these dimensions is
not ``1``, a `ValueError` will be raised.
"""
x
=
as_tensor_variable
(
x
)
if
x
.
broadcastable
==
broadcastable
:
return
x
rval
=
Rebroadcast
(
*
[(
i
,
broadcastable
[
i
])
for
i
in
range
(
len
(
broadcastable
))])(
x
)
return
aesara
.
tensor
.
basic_opt
.
apply_rebroadcast_opt
(
rval
)
class
Join
(
COp
):
class
Join
(
COp
):
r"""
r"""
Concatenate multiple `TensorVariable`\s along some axis.
Concatenate multiple `TensorVariable`\s along some axis.
...
@@ -2599,7 +2534,12 @@ class Join(COp):
...
@@ -2599,7 +2534,12 @@ class Join(COp):
# broadcast. As the grad need to keep the information,
# broadcast. As the grad need to keep the information,
# read it if needed.
# read it if needed.
split_gz
=
[
split_gz
=
[
patternbroadcast
(
g
,
t
.
broadcastable
)
for
t
,
g
in
zip
(
tens
,
split_gz
)
g
if
g
.
type
.
broadcastable
==
t
.
type
.
broadcastable
else
specify_broadcastable
(
g
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
t
.
type
.
broadcastable
)
if
b
)
)
for
t
,
g
in
zip
(
tens
,
split_gz
)
]
]
rval
=
rval
+
split_gz
rval
=
rval
+
split_gz
else
:
else
:
...
@@ -2822,7 +2762,7 @@ def stack(*tensors, **kwargs):
...
@@ -2822,7 +2762,7 @@ def stack(*tensors, **kwargs):
raise
ValueError
(
"No tensor arguments provided"
)
raise
ValueError
(
"No tensor arguments provided"
)
# If all tensors are scalars of the same type, call make_vector.
# If all tensors are scalars of the same type, call make_vector.
# It makes the graph simpler, by not adding DimShuffles and
Rebroadcast
s
# It makes the graph simpler, by not adding DimShuffles and
SpecifyShape
s
# This should be an optimization!
# This should be an optimization!
# Doing it here make the graph less canonicalized
# Doing it here make the graph less canonicalized
...
@@ -2979,7 +2919,9 @@ def flatten(x, ndim=1):
...
@@ -2979,7 +2919,9 @@ def flatten(x, ndim=1):
bcast_kept_dims
=
_x
.
broadcastable
[:
ndim
-
1
]
bcast_kept_dims
=
_x
.
broadcastable
[:
ndim
-
1
]
bcast_new_dim
=
builtins
.
all
(
_x
.
broadcastable
[
ndim
-
1
:])
bcast_new_dim
=
builtins
.
all
(
_x
.
broadcastable
[
ndim
-
1
:])
broadcastable
=
bcast_kept_dims
+
(
bcast_new_dim
,)
broadcastable
=
bcast_kept_dims
+
(
bcast_new_dim
,)
x_reshaped
=
addbroadcast
(
x_reshaped
,
*
[
i
for
i
in
range
(
ndim
)
if
broadcastable
[
i
]])
x_reshaped
=
specify_broadcastable
(
x_reshaped
,
*
[
i
for
i
in
range
(
ndim
)
if
broadcastable
[
i
]]
)
return
x_reshaped
return
x_reshaped
...
@@ -4253,9 +4195,7 @@ __all__ = [
...
@@ -4253,9 +4195,7 @@ __all__ = [
"stack"
,
"stack"
,
"roll"
,
"roll"
,
"join"
,
"join"
,
"patternbroadcast"
,
"unbroadcast"
,
"unbroadcast"
,
"addbroadcast"
,
"split"
,
"split"
,
"transpose"
,
"transpose"
,
"extract_constant"
,
"extract_constant"
,
...
...
aesara/tensor/blas.py
浏览文件 @
2066065d
...
@@ -165,6 +165,7 @@ from aesara.tensor.blas_headers import blas_header_text, blas_header_version
...
@@ -165,6 +165,7 @@ from aesara.tensor.blas_headers import blas_header_text, blas_header_version
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.math
import
Dot
,
add
,
mul
,
neg
,
sub
from
aesara.tensor.math
import
Dot
,
add
,
mul
,
neg
,
sub
from
aesara.tensor.shape
import
specify_broadcastable
from
aesara.tensor.type
import
(
from
aesara.tensor.type
import
(
DenseTensorType
,
DenseTensorType
,
integer_dtypes
,
integer_dtypes
,
...
@@ -2552,9 +2553,13 @@ class BatchedDot(COp):
...
@@ -2552,9 +2553,13 @@ class BatchedDot(COp):
# above code don't always return the right broadcast pattern.
# above code don't always return the right broadcast pattern.
# This cause problem down the road. See gh-1461.
# This cause problem down the road. See gh-1461.
if
xgrad
.
broadcastable
!=
x
.
broadcastable
:
if
xgrad
.
broadcastable
!=
x
.
broadcastable
:
xgrad
=
at
.
patternbroadcast
(
xgrad
,
x
.
broadcastable
)
xgrad
=
specify_broadcastable
(
xgrad
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
x
.
type
.
broadcastable
)
if
b
)
)
if
ygrad
.
broadcastable
!=
y
.
broadcastable
:
if
ygrad
.
broadcastable
!=
y
.
broadcastable
:
ygrad
=
at
.
patternbroadcast
(
ygrad
,
y
.
broadcastable
)
ygrad
=
specify_broadcastable
(
ygrad
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
y
.
type
.
broadcastable
)
if
b
)
)
return
xgrad
,
ygrad
return
xgrad
,
ygrad
...
...
aesara/tensor/math.py
浏览文件 @
2066065d
...
@@ -21,7 +21,6 @@ from aesara.tensor.basic import (
...
@@ -21,7 +21,6 @@ from aesara.tensor.basic import (
cast
,
cast
,
concatenate
,
concatenate
,
constant
,
constant
,
patternbroadcast
,
stack
,
stack
,
switch
,
switch
,
)
)
...
@@ -32,7 +31,7 @@ from aesara.tensor.elemwise import (
...
@@ -32,7 +31,7 @@ from aesara.tensor.elemwise import (
Elemwise
,
Elemwise
,
scalar_elemwise
,
scalar_elemwise
,
)
)
from
aesara.tensor.shape
import
shape
from
aesara.tensor.shape
import
shape
,
specify_broadcastable
from
aesara.tensor.type
import
(
from
aesara.tensor.type
import
(
DenseTensorType
,
DenseTensorType
,
complex_dtypes
,
complex_dtypes
,
...
@@ -1961,9 +1960,13 @@ class Dot(Op):
...
@@ -1961,9 +1960,13 @@ class Dot(Op):
# above code don't always return the right broadcast pattern.
# above code don't always return the right broadcast pattern.
# This cause problem down the road. See gh-1461.
# This cause problem down the road. See gh-1461.
if
xgrad
.
broadcastable
!=
x
.
broadcastable
:
if
xgrad
.
broadcastable
!=
x
.
broadcastable
:
xgrad
=
patternbroadcast
(
xgrad
,
x
.
broadcastable
)
xgrad
=
specify_broadcastable
(
xgrad
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
x
.
type
.
broadcastable
)
if
b
)
)
if
ygrad
.
broadcastable
!=
y
.
broadcastable
:
if
ygrad
.
broadcastable
!=
y
.
broadcastable
:
ygrad
=
patternbroadcast
(
ygrad
,
y
.
broadcastable
)
ygrad
=
specify_broadcastable
(
ygrad
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
y
.
type
.
broadcastable
)
if
b
)
)
rval
=
xgrad
,
ygrad
rval
=
xgrad
,
ygrad
...
@@ -2178,7 +2181,11 @@ def _tensordot_as_dot(a, b, axes, dot, batched):
...
@@ -2178,7 +2181,11 @@ def _tensordot_as_dot(a, b, axes, dot, batched):
out
=
out_reshaped
.
reshape
(
outshape
,
outndim
)
out
=
out_reshaped
.
reshape
(
outshape
,
outndim
)
# Make sure the broadcastable pattern of the result is correct,
# Make sure the broadcastable pattern of the result is correct,
# since some shape information can be lost in the reshapes.
# since some shape information can be lost in the reshapes.
return
patternbroadcast
(
out
,
outbcast
)
if
out
.
type
.
broadcastable
!=
outbcast
:
out
=
specify_broadcastable
(
out
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
outbcast
)
if
b
)
)
return
out
# if 'axes' is a list, transpose a and b such that the summed axes of a
# if 'axes' is a list, transpose a and b such that the summed axes of a
# are last and the summed axes of b are first.
# are last and the summed axes of b are first.
...
...
aesara/tensor/nnet/batchnorm.py
浏览文件 @
2066065d
...
@@ -12,6 +12,7 @@ from aesara.tensor.basic_opt import register_specialize_device
...
@@ -12,6 +12,7 @@ from aesara.tensor.basic_opt import register_specialize_device
from
aesara.tensor.elemwise
import
Elemwise
from
aesara.tensor.elemwise
import
Elemwise
from
aesara.tensor.math
import
mean
,
prod
,
reciprocal
,
sqrt
from
aesara.tensor.math
import
mean
,
prod
,
reciprocal
,
sqrt
from
aesara.tensor.math
import
sum
as
at_sum
from
aesara.tensor.math
import
sum
as
at_sum
from
aesara.tensor.shape
import
specify_broadcastable
from
aesara.tensor.type
import
TensorType
from
aesara.tensor.type
import
TensorType
...
@@ -241,8 +242,8 @@ def batch_normalization_train(
...
@@ -241,8 +242,8 @@ def batch_normalization_train(
gamma
=
gamma
.
dimshuffle
(
params_dimshuffle_pattern
)
gamma
=
gamma
.
dimshuffle
(
params_dimshuffle_pattern
)
beta
=
beta
.
dimshuffle
(
params_dimshuffle_pattern
)
beta
=
beta
.
dimshuffle
(
params_dimshuffle_pattern
)
else
:
else
:
gamma
=
at
.
addbroadcast
(
gamma
,
*
axes
)
gamma
=
specify_broadcastable
(
gamma
,
*
axes
)
beta
=
at
.
addbroadcast
(
beta
,
*
axes
)
beta
=
specify_broadcastable
(
beta
,
*
axes
)
batchnorm_op
=
AbstractBatchNormTrain
(
axes
=
axes
)
batchnorm_op
=
AbstractBatchNormTrain
(
axes
=
axes
)
...
@@ -253,8 +254,8 @@ def batch_normalization_train(
...
@@ -253,8 +254,8 @@ def batch_normalization_train(
running_mean
=
running_mean
.
dimshuffle
(
params_dimshuffle_pattern
)
running_mean
=
running_mean
.
dimshuffle
(
params_dimshuffle_pattern
)
running_var
=
running_var
.
dimshuffle
(
params_dimshuffle_pattern
)
running_var
=
running_var
.
dimshuffle
(
params_dimshuffle_pattern
)
else
:
else
:
running_mean
=
at
.
addbroadcast
(
running_mean
,
*
axes
)
running_mean
=
specify_broadcastable
(
running_mean
,
*
axes
)
running_var
=
at
.
addbroadcast
(
running_var
,
*
axes
)
running_var
=
specify_broadcastable
(
running_var
,
*
axes
)
out
,
mean
,
invstd
,
new_running_mean
,
new_running_var
=
batchnorm_op
(
out
,
mean
,
invstd
,
new_running_mean
,
new_running_var
=
batchnorm_op
(
inputs
,
inputs
,
gamma
,
gamma
,
...
@@ -265,12 +266,14 @@ def batch_normalization_train(
...
@@ -265,12 +266,14 @@ def batch_normalization_train(
running_var
=
running_var
,
running_var
=
running_var
,
)
)
if
new_running_mean
.
broadcastable
!=
running_mean
.
broadcastable
:
if
new_running_mean
.
broadcastable
!=
running_mean
.
broadcastable
:
new_running_mean
=
at
.
patternbroadcast
(
new_running_mean
=
specify_broadcastable
(
new_running_mean
,
running_mean
.
broadcastable
new_running_mean
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
running_mean
.
type
.
broadcastable
)
if
b
),
)
)
if
new_running_var
.
broadcastable
!=
running_var
.
broadcastable
:
if
new_running_var
.
broadcastable
!=
running_var
.
broadcastable
:
new_running_var
=
at
.
patternbroadcast
(
new_running_var
=
specify_broadcastable
(
new_running_var
,
running_var
.
broadcastable
new_running_var
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
running_var
.
type
.
broadcastable
)
if
b
),
)
)
results
=
(
out
,
mean
,
invstd
,
new_running_mean
,
new_running_var
)
results
=
(
out
,
mean
,
invstd
,
new_running_mean
,
new_running_var
)
else
:
else
:
...
@@ -331,7 +334,7 @@ def batch_normalization_test(
...
@@ -331,7 +334,7 @@ def batch_normalization_test(
axes = (0,)
axes = (0,)
# for spatial normalization
# for spatial normalization
axes = (0,) + tuple(range(2, inputs.ndim))
axes = (0,) + tuple(range(2, inputs.ndim))
gamma, beta, mean, var = (at.
addbroadcast
(t, *axes)
gamma, beta, mean, var = (at.
specify_broadcastable
(t, *axes)
for t in (gamma, beta, mean, var))
for t in (gamma, beta, mean, var))
out = (inputs - mean) * gamma / at.sqrt(var + epsilon) + beta
out = (inputs - mean) * gamma / at.sqrt(var + epsilon) + beta
"""
"""
...
@@ -377,10 +380,10 @@ def batch_normalization_test(
...
@@ -377,10 +380,10 @@ def batch_normalization_test(
mean
=
mean
.
dimshuffle
(
params_dimshuffle_pattern
)
mean
=
mean
.
dimshuffle
(
params_dimshuffle_pattern
)
var
=
var
.
dimshuffle
(
params_dimshuffle_pattern
)
var
=
var
.
dimshuffle
(
params_dimshuffle_pattern
)
else
:
else
:
gamma
=
at
.
addbroadcast
(
gamma
,
*
axes
)
gamma
=
specify_broadcastable
(
gamma
,
*
axes
)
beta
=
at
.
addbroadcast
(
beta
,
*
axes
)
beta
=
specify_broadcastable
(
beta
,
*
axes
)
mean
=
at
.
addbroadcast
(
mean
,
*
axes
)
mean
=
specify_broadcastable
(
mean
,
*
axes
)
var
=
at
.
addbroadcast
(
var
,
*
axes
)
var
=
specify_broadcastable
(
var
,
*
axes
)
batchnorm_op
=
AbstractBatchNormInference
(
axes
=
axes
)
batchnorm_op
=
AbstractBatchNormInference
(
axes
=
axes
)
return
batchnorm_op
(
inputs
,
gamma
,
beta
,
mean
,
var
,
epsilon
=
epsilon
)
return
batchnorm_op
(
inputs
,
gamma
,
beta
,
mean
,
var
,
epsilon
=
epsilon
)
...
@@ -609,7 +612,7 @@ class AbstractBatchNormInference(Op):
...
@@ -609,7 +612,7 @@ class AbstractBatchNormInference(Op):
)
)
scale
,
bias
,
est_mean
,
est_var
=
(
scale
,
bias
,
est_mean
,
est_var
=
(
at
.
addbroadcast
(
t
,
*
axes
)
for
t
in
(
scale
,
bias
,
est_mean
,
est_var
)
specify_broadcastable
(
t
,
*
axes
)
for
t
in
(
scale
,
bias
,
est_mean
,
est_var
)
)
)
# define helper expressions
# define helper expressions
...
...
aesara/tensor/nnet/conv.py
浏览文件 @
2066065d
...
@@ -26,13 +26,10 @@ import aesara
...
@@ -26,13 +26,10 @@ import aesara
from
aesara.graph.basic
import
Apply
from
aesara.graph.basic
import
Apply
from
aesara.link.c.op
import
OpenMPOp
from
aesara.link.c.op
import
OpenMPOp
from
aesara.tensor
import
blas
from
aesara.tensor
import
blas
from
aesara.tensor.basic
import
(
from
aesara.tensor.basic
import
as_tensor_variable
,
get_scalar_constant_value
as_tensor_variable
,
get_scalar_constant_value
,
patternbroadcast
,
)
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.nnet.abstract_conv
import
get_conv_output_shape
,
get_conv_shape_1axis
from
aesara.tensor.nnet.abstract_conv
import
get_conv_output_shape
,
get_conv_shape_1axis
from
aesara.tensor.shape
import
specify_broadcastable
from
aesara.tensor.type
import
discrete_dtypes
,
tensor
from
aesara.tensor.type
import
discrete_dtypes
,
tensor
...
@@ -1103,8 +1100,14 @@ class ConvOp(OpenMPOp):
...
@@ -1103,8 +1100,14 @@ class ConvOp(OpenMPOp):
# din and dw should have the same broadcasting pattern as the
# din and dw should have the same broadcasting pattern as the
# parameters they are the gradient of (resp. inputs and kerns).
# parameters they are the gradient of (resp. inputs and kerns).
din
=
patternbroadcast
(
din
,
inputs
.
broadcastable
)
if
din
.
type
.
broadcastable
!=
inputs
.
type
.
broadcastable
:
dw
=
patternbroadcast
(
dw
,
kerns
.
broadcastable
)
din
=
specify_broadcastable
(
din
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
inputs
.
type
.
broadcastable
)
if
b
)
)
if
dw
.
type
.
broadcastable
!=
kerns
.
type
.
broadcastable
:
dw
=
specify_broadcastable
(
dw
,
*
(
ax
for
(
ax
,
b
)
in
enumerate
(
kerns
.
type
.
broadcastable
)
if
b
)
)
return
[
din
,
dw
]
return
[
din
,
dw
]
def
c_headers
(
self
,
**
kwargs
):
def
c_headers
(
self
,
**
kwargs
):
...
...
aesara/tensor/shape.py
浏览文件 @
2066065d
...
@@ -12,7 +12,7 @@ from aesara.link.c.op import COp
...
@@ -12,7 +12,7 @@ from aesara.link.c.op import COp
from
aesara.link.c.params_type
import
ParamsType
from
aesara.link.c.params_type
import
ParamsType
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.scalar
import
int32
from
aesara.scalar
import
int32
from
aesara.tensor
import
_get_vector_length
from
aesara.tensor
import
_get_vector_length
,
as_tensor_variable
from
aesara.tensor
import
basic
as
at
from
aesara.tensor
import
basic
as
at
from
aesara.tensor
import
get_vector_length
from
aesara.tensor
import
get_vector_length
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.exceptions
import
NotScalarConstantError
...
@@ -891,3 +891,38 @@ register_shape_i_c_code(
...
@@ -891,3 +891,38 @@ register_shape_i_c_code(
"""
,
"""
,
version
=
3
,
version
=
3
,
)
)
def
specify_broadcastable
(
x
,
*
axes
):
"""
Specify the input as being broadcastable in the specified axes.
For example, specify_broadcastable(x, 0) will make the first dimension of
x broadcastable. When performing the function, if the length of
x along that dimension is not 1, a ValueError will be raised.
Parameters
----------
x : tensor_like
Input aesara tensor.
axis : an int or an iterable object such as list or tuple of int values
The dimension along which the tensor x should be broadcastable.
If the length of x along these dimensions is not 1, a ValueError will
be raised.
Returns
-------
tensor
A aesara tensor, which is broadcastable along the specified dimensions.
"""
x
=
as_tensor_variable
(
x
)
if
not
axes
:
return
x
if
max
(
axes
)
>=
x
.
type
.
ndim
:
raise
ValueError
(
"Trying to specify broadcastable of non-existent dimension"
)
shape_info
=
[
1
if
i
in
axes
else
None
for
i
in
range
(
len
(
x
.
type
.
shape
))]
return
specify_shape
(
x
,
shape_info
)
aesara/tensor/subtensor.py
浏览文件 @
2066065d
...
@@ -20,7 +20,7 @@ from aesara.misc.safe_asarray import _asarray
...
@@ -20,7 +20,7 @@ from aesara.misc.safe_asarray import _asarray
from
aesara.printing
import
Printer
,
pprint
,
set_precedence
from
aesara.printing
import
Printer
,
pprint
,
set_precedence
from
aesara.scalar.basic
import
ScalarConstant
from
aesara.scalar.basic
import
ScalarConstant
from
aesara.tensor
import
_get_vector_length
,
as_tensor_variable
,
get_vector_length
from
aesara.tensor
import
_get_vector_length
,
as_tensor_variable
,
get_vector_length
from
aesara.tensor.basic
import
a
ddbroadcast
,
a
lloc
,
get_scalar_constant_value
from
aesara.tensor.basic
import
alloc
,
get_scalar_constant_value
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.exceptions
import
(
from
aesara.tensor.exceptions
import
(
AdvancedIndexingError
,
AdvancedIndexingError
,
...
@@ -28,7 +28,7 @@ from aesara.tensor.exceptions import (
...
@@ -28,7 +28,7 @@ from aesara.tensor.exceptions import (
ShapeError
,
ShapeError
,
)
)
from
aesara.tensor.math
import
clip
from
aesara.tensor.math
import
clip
from
aesara.tensor.shape
import
Reshape
from
aesara.tensor.shape
import
Reshape
,
specify_broadcastable
from
aesara.tensor.type
import
(
from
aesara.tensor.type
import
(
TensorType
,
TensorType
,
bscalar
,
bscalar
,
...
@@ -1322,8 +1322,8 @@ def inc_subtensor(
...
@@ -1322,8 +1322,8 @@ def inc_subtensor(
# It is acceptable to try to increment a subtensor with a
# It is acceptable to try to increment a subtensor with a
# broadcastable dim with a tensor that is not broadcastable
# broadcastable dim with a tensor that is not broadcastable
# on that dimension. However, its length must then be 1.
# on that dimension. However, its length must then be 1.
# We insert a
Rebroadcast
Op to make sure it is the case.
# We insert a
SpecifyShape
Op to make sure it is the case.
y
=
addbroadcast
(
y
,
dim
)
y
=
specify_broadcastable
(
y
,
dim
)
if
not
x
.
owner
:
if
not
x
.
owner
:
raise
TypeError
(
"x must be the result of a subtensor operation"
)
raise
TypeError
(
"x must be the result of a subtensor operation"
)
...
...
tests/tensor/nnet/test_batchnorm.py
浏览文件 @
2066065d
...
@@ -8,6 +8,7 @@ import aesara.tensor as at
...
@@ -8,6 +8,7 @@ import aesara.tensor as at
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.tensor.math
import
sum
as
at_sum
from
aesara.tensor.math
import
sum
as
at_sum
from
aesara.tensor.nnet
import
batchnorm
from
aesara.tensor.nnet
import
batchnorm
from
aesara.tensor.shape
import
specify_broadcastable
from
aesara.tensor.type
import
(
from
aesara.tensor.type
import
(
TensorType
,
TensorType
,
matrix
,
matrix
,
...
@@ -219,8 +220,8 @@ def test_batch_normalization_train():
...
@@ -219,8 +220,8 @@ def test_batch_normalization_train():
x_mean2
=
x
.
mean
(
axis
=
axes2
,
keepdims
=
True
)
x_mean2
=
x
.
mean
(
axis
=
axes2
,
keepdims
=
True
)
x_var2
=
x
.
var
(
axis
=
axes2
,
keepdims
=
True
)
x_var2
=
x
.
var
(
axis
=
axes2
,
keepdims
=
True
)
x_invstd2
=
at
.
reciprocal
(
at
.
sqrt
(
x_var2
+
eps
))
x_invstd2
=
at
.
reciprocal
(
at
.
sqrt
(
x_var2
+
eps
))
scale2
=
at
.
addbroadcast
(
scale
,
*
axes2
)
scale2
=
specify_broadcastable
(
scale
,
*
axes2
)
bias2
=
at
.
addbroadcast
(
bias
,
*
axes2
)
bias2
=
specify_broadcastable
(
bias
,
*
axes2
)
out2
=
(
x
-
x_mean2
)
*
(
scale2
*
x_invstd2
)
+
bias2
out2
=
(
x
-
x_mean2
)
*
(
scale2
*
x_invstd2
)
+
bias2
m
=
at
.
cast
(
at
.
prod
(
x
.
shape
)
/
at
.
prod
(
scale
.
shape
),
aesara
.
config
.
floatX
)
m
=
at
.
cast
(
at
.
prod
(
x
.
shape
)
/
at
.
prod
(
scale
.
shape
),
aesara
.
config
.
floatX
)
out_running_mean2
=
(
out_running_mean2
=
(
...
@@ -597,7 +598,7 @@ def test_batch_normalization_test():
...
@@ -597,7 +598,7 @@ def test_batch_normalization_test():
else
:
else
:
axes2
=
axes
axes2
=
axes
scale2
,
bias2
,
mean2
,
var2
=
(
scale2
,
bias2
,
mean2
,
var2
=
(
at
.
addbroadcast
(
t
,
*
axes2
)
for
t
in
(
scale
,
bias
,
mean
,
var
)
specify_broadcastable
(
t
,
*
axes2
)
for
t
in
(
scale
,
bias
,
mean
,
var
)
)
)
out2
=
(
x
-
mean2
)
*
(
scale2
/
at
.
sqrt
(
var2
+
eps
))
+
bias2
out2
=
(
x
-
mean2
)
*
(
scale2
/
at
.
sqrt
(
var2
+
eps
))
+
bias2
# backward pass
# backward pass
...
...
tests/tensor/test_basic.py
浏览文件 @
2066065d
...
@@ -39,7 +39,6 @@ from aesara.tensor.basic import (
...
@@ -39,7 +39,6 @@ from aesara.tensor.basic import (
Split
,
Split
,
TensorFromScalar
,
TensorFromScalar
,
Tri
,
Tri
,
addbroadcast
,
alloc
,
alloc
,
arange
,
arange
,
as_tensor_variable
,
as_tensor_variable
,
...
@@ -69,7 +68,6 @@ from aesara.tensor.basic import (
...
@@ -69,7 +68,6 @@ from aesara.tensor.basic import (
nonzero_values
,
nonzero_values
,
ogrid
,
ogrid
,
ones_like
,
ones_like
,
patternbroadcast
,
permute_row_elements
,
permute_row_elements
,
roll
,
roll
,
scalar_from_tensor
,
scalar_from_tensor
,
...
@@ -3226,20 +3224,7 @@ class TestLongTensor:
...
@@ -3226,20 +3224,7 @@ class TestLongTensor:
class
TestBroadcast
:
class
TestBroadcast
:
def
test_addbroadcast_validation
(
self
):
def
test_unbroadcast
(
self
):
x
=
as_tensor_variable
(
np
.
zeros
((
2
,
3
)))
with
pytest
.
raises
(
ValueError
,
match
=
".*pattern does not.*"
):
addbroadcast
(
x
,
4
)
def
test_broadcast_bigdim
(
self
):
def
f
():
x
=
matrix
()
addbroadcast
(
x
,
2
)
with
pytest
.
raises
(
ValueError
):
f
()
def
test_unbroadcast_addbroadcast
(
self
):
# test that the unbroadcast fct don't insert not needed broadcast
# test that the unbroadcast fct don't insert not needed broadcast
# and fuse consecutive Rebroadcast op
# and fuse consecutive Rebroadcast op
...
@@ -3249,26 +3234,12 @@ class TestBroadcast:
...
@@ -3249,26 +3234,12 @@ class TestBroadcast:
assert
unbroadcast
(
x
,
1
,
0
)
is
x
assert
unbroadcast
(
x
,
1
,
0
)
is
x
assert
unbroadcast
(
x
,
0
,
1
)
is
x
assert
unbroadcast
(
x
,
0
,
1
)
is
x
assert
addbroadcast
(
x
,
0
)
is
not
x
assert
addbroadcast
(
x
,
1
)
is
not
x
assert
addbroadcast
(
x
,
1
,
0
)
.
owner
.
inputs
[
0
]
is
x
assert
unbroadcast
(
addbroadcast
(
x
,
0
),
0
)
is
x
assert
addbroadcast
(
unbroadcast
(
x
,
0
),
0
)
is
not
x
x
=
row
()
x
=
row
()
assert
unbroadcast
(
x
,
0
)
is
not
x
assert
unbroadcast
(
x
,
0
)
is
not
x
assert
unbroadcast
(
x
,
1
)
is
x
assert
unbroadcast
(
x
,
1
)
is
x
assert
unbroadcast
(
x
,
1
,
0
)
is
not
x
assert
unbroadcast
(
x
,
1
,
0
)
is
not
x
assert
unbroadcast
(
x
,
0
,
1
)
is
not
x
assert
unbroadcast
(
x
,
0
,
1
)
is
not
x
assert
addbroadcast
(
x
,
0
)
is
x
assert
addbroadcast
(
x
,
1
)
.
owner
.
inputs
[
0
]
is
x
assert
addbroadcast
(
x
,
1
,
0
)
.
owner
.
inputs
[
0
]
is
x
assert
addbroadcast
(
x
,
0
,
1
)
.
owner
.
inputs
[
0
]
is
x
assert
unbroadcast
(
addbroadcast
(
x
,
1
),
1
)
is
x
assert
addbroadcast
(
unbroadcast
(
x
,
1
),
1
)
is
not
x
# The first broadcast is remove the broadcast, so the second
# The first broadcast is remove the broadcast, so the second
# should not make one
# should not make one
assert
unbroadcast
(
unbroadcast
(
x
,
0
),
0
)
.
owner
.
inputs
[
0
]
is
x
assert
unbroadcast
(
unbroadcast
(
x
,
0
),
0
)
.
owner
.
inputs
[
0
]
is
x
...
@@ -3276,29 +3247,8 @@ class TestBroadcast:
...
@@ -3276,29 +3247,8 @@ class TestBroadcast:
# Test that consecutive Rebroadcast op are fused
# Test that consecutive Rebroadcast op are fused
x
=
TensorType
(
dtype
=
"float64"
,
shape
=
(
True
,
True
))()
x
=
TensorType
(
dtype
=
"float64"
,
shape
=
(
True
,
True
))()
assert
unbroadcast
(
unbroadcast
(
x
,
1
),
0
)
.
owner
.
inputs
[
0
]
is
x
assert
unbroadcast
(
unbroadcast
(
x
,
1
),
0
)
.
owner
.
inputs
[
0
]
is
x
assert
addbroadcast
(
unbroadcast
(
x
,
1
),
0
)
.
owner
.
inputs
[
0
]
is
x
assert
addbroadcast
(
unbroadcast
(
x
,
0
),
0
)
is
x
def
test_patternbroadcast
(
self
):
# Test that patternbroadcast with an empty broadcasting pattern works
x
=
scalar
(
"x"
)
m
=
matrix
(
"m"
)
s
=
patternbroadcast
(
m
,
x
.
broadcastable
)
assert
s
is
m
x2
=
patternbroadcast
(
x
,
x
.
broadcastable
)
assert
x2
is
x
def
test_infer_shape
(
self
):
def
test_infer_shape
(
self
):
x
=
matrix
()
y
=
addbroadcast
(
x
,
0
)
f
=
aesara
.
function
([
x
],
y
.
shape
)
assert
(
f
(
np
.
zeros
((
1
,
5
),
dtype
=
config
.
floatX
))
==
[
1
,
5
])
.
all
()
topo
=
f
.
maker
.
fgraph
.
toposort
()
if
config
.
mode
!=
"FAST_COMPILE"
:
assert
len
(
topo
)
==
2
assert
isinstance
(
topo
[
0
]
.
op
,
Shape_i
)
assert
isinstance
(
topo
[
1
]
.
op
,
MakeVector
)
x
=
matrix
()
x
=
matrix
()
y
=
unbroadcast
(
x
,
0
)
y
=
unbroadcast
(
x
,
0
)
f
=
aesara
.
function
([
x
],
y
.
shape
)
f
=
aesara
.
function
([
x
],
y
.
shape
)
...
...
tests/tensor/test_basic_opt.py
浏览文件 @
2066065d
...
@@ -1911,18 +1911,6 @@ class TestRebroadcast:
...
@@ -1911,18 +1911,6 @@ class TestRebroadcast:
assert
check_stack_trace
(
f
,
ops_to_check
=
"all"
)
assert
check_stack_trace
(
f
,
ops_to_check
=
"all"
)
def
test_rebroadcast_rebroadcast
(
self
):
mode
=
get_default_mode
()
.
including
(
"canonicalize"
)
m
=
matrix
()
s
=
at
.
addbroadcast
(
m
,
0
,
1
)
v
=
at
.
unbroadcast
(
s
,
1
)
f
=
function
([
m
],
v
,
mode
=
mode
)
f
([[
76
]])
e
=
f
.
maker
.
fgraph
.
toposort
()
rebroadcast_nodes
=
[
n
for
n
in
e
if
isinstance
(
n
.
op
,
Rebroadcast
)]
assert
len
(
rebroadcast_nodes
)
==
1
assert
rebroadcast_nodes
[
0
]
.
op
.
axis
==
{
0
:
True
}
class
TestUselessElemwise
:
class
TestUselessElemwise
:
def
setup_method
(
self
):
def
setup_method
(
self
):
...
...
tests/tensor/test_math.py
浏览文件 @
2066065d
...
@@ -1918,6 +1918,9 @@ class TestDot:
...
@@ -1918,6 +1918,9 @@ class TestDot:
# These examples should all work. All dimensions of all results have
# These examples should all work. All dimensions of all results have
# size 1.
# size 1.
#
#
def
is_super_shape
(
var1
,
var2
):
# Check that var1.type is a superset of var2.type, ignoring dtype
return
var1
.
type
.
is_super
(
var2
.
type
.
clone
(
dtype
=
var1
.
type
.
dtype
))
for
dtype0
in
(
"float32"
,
"float64"
,
"complex64"
):
for
dtype0
in
(
"float32"
,
"float64"
,
"complex64"
):
for
dtype1
in
(
"float32"
,
"complex64"
,
"complex128"
):
for
dtype1
in
(
"float32"
,
"complex64"
,
"complex128"
):
...
@@ -1944,9 +1947,9 @@ class TestDot:
...
@@ -1944,9 +1947,9 @@ class TestDot:
if
dtype0
.
startswith
(
"float"
)
and
dtype1
.
startswith
(
"float"
):
if
dtype0
.
startswith
(
"float"
)
and
dtype1
.
startswith
(
"float"
):
g
=
grad
(
z
.
sum
(),
x
)
g
=
grad
(
z
.
sum
(),
x
)
assert
g
.
broadcastable
==
x
.
broadcastable
assert
is_super_shape
(
x
,
g
)
g
=
grad
(
z
.
sum
(),
y
)
g
=
grad
(
z
.
sum
(),
y
)
assert
g
.
broadcastable
==
y
.
broadcastable
assert
is_super_shape
(
y
,
g
)
class
TestTensordot
:
class
TestTensordot
:
...
...
tests/tensor/test_opt_uncanonicalize.py
浏览文件 @
2066065d
...
@@ -19,7 +19,7 @@ from aesara.tensor.opt_uncanonicalize import (
...
@@ -19,7 +19,7 @@ from aesara.tensor.opt_uncanonicalize import (
local_dimshuffle_subtensor
,
local_dimshuffle_subtensor
,
local_reshape_dimshuffle
,
local_reshape_dimshuffle
,
)
)
from
aesara.tensor.shape
import
reshape
from
aesara.tensor.shape
import
reshape
,
specify_shape
from
aesara.tensor.type
import
dtensor4
,
iscalar
,
matrix
,
tensor
,
vector
from
aesara.tensor.type
import
dtensor4
,
iscalar
,
matrix
,
tensor
,
vector
from
tests.link.test_link
import
make_function
from
tests.link.test_link
import
make_function
...
@@ -179,7 +179,7 @@ def test_local_dimshuffle_subtensor():
...
@@ -179,7 +179,7 @@ def test_local_dimshuffle_subtensor():
dimshuffle_subtensor
=
out2in
(
local_dimshuffle_subtensor
)
dimshuffle_subtensor
=
out2in
(
local_dimshuffle_subtensor
)
x
=
dtensor4
(
"x"
)
x
=
dtensor4
(
"x"
)
x
=
at
.
patternbroadcast
(
x
,
(
False
,
True
,
False
,
Fals
e
))
x
=
specify_shape
(
x
,
(
None
,
1
,
None
,
Non
e
))
i
=
iscalar
(
"i"
)
i
=
iscalar
(
"i"
)
out
=
x
[:,
:,
10
:
30
,
::
i
]
.
dimshuffle
(
0
,
2
,
3
)
out
=
x
[:,
:,
10
:
30
,
::
i
]
.
dimshuffle
(
0
,
2
,
3
)
...
@@ -213,7 +213,7 @@ def test_local_dimshuffle_subtensor():
...
@@ -213,7 +213,7 @@ def test_local_dimshuffle_subtensor():
# Test a corner case that had Aesara return a bug.
# Test a corner case that had Aesara return a bug.
x
=
dtensor4
(
"x"
)
x
=
dtensor4
(
"x"
)
x
=
at
.
patternbroadcast
(
x
,
(
False
,
True
,
False
,
Fals
e
))
x
=
specify_shape
(
x
,
(
None
,
1
,
None
,
Non
e
))
assert
x
[:,
:,
0
:
3
,
::
-
1
]
.
dimshuffle
(
0
,
2
,
3
)
.
eval
(
assert
x
[:,
:,
0
:
3
,
::
-
1
]
.
dimshuffle
(
0
,
2
,
3
)
.
eval
(
{
x
:
np
.
ones
((
5
,
1
,
6
,
7
))}
{
x
:
np
.
ones
((
5
,
1
,
6
,
7
))}
...
...
tests/tensor/test_shape.py
浏览文件 @
2066065d
...
@@ -9,7 +9,7 @@ from aesara.graph.basic import Variable
...
@@ -9,7 +9,7 @@ from aesara.graph.basic import Variable
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.type
import
Type
from
aesara.graph.type
import
Type
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.tensor
import
as_tensor_variable
,
get_vector_length
from
aesara.tensor
import
as_tensor_variable
,
get_vector_length
,
row
from
aesara.tensor.basic
import
MakeVector
,
constant
from
aesara.tensor.basic
import
MakeVector
,
constant
from
aesara.tensor.basic_opt
import
ShapeFeature
from
aesara.tensor.basic_opt
import
ShapeFeature
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
...
@@ -21,6 +21,7 @@ from aesara.tensor.shape import (
...
@@ -21,6 +21,7 @@ from aesara.tensor.shape import (
reshape
,
reshape
,
shape
,
shape
,
shape_i
,
shape_i
,
specify_broadcastable
,
specify_shape
,
specify_shape
,
)
)
from
aesara.tensor.subtensor
import
Subtensor
from
aesara.tensor.subtensor
import
Subtensor
...
@@ -518,6 +519,23 @@ class TestSpecifyShape(utt.InferShapeTester):
...
@@ -518,6 +519,23 @@ class TestSpecifyShape(utt.InferShapeTester):
assert
isinstance
(
z_grad
.
owner
.
op
,
SpecifyShape
)
assert
isinstance
(
z_grad
.
owner
.
op
,
SpecifyShape
)
class
TestSpecifyBroadcastable
:
def
test_basic
(
self
):
x
=
matrix
()
assert
specify_broadcastable
(
x
,
0
)
.
type
.
shape
==
(
1
,
None
)
assert
specify_broadcastable
(
x
,
1
)
.
type
.
shape
==
(
None
,
1
)
assert
specify_broadcastable
(
x
,
0
,
1
)
.
type
.
shape
==
(
1
,
1
)
x
=
row
()
assert
specify_broadcastable
(
x
,
0
)
is
x
assert
specify_broadcastable
(
x
,
1
)
is
not
x
def
test_validation
(
self
):
x
=
matrix
()
with
pytest
.
raises
(
ValueError
,
match
=
"^Trying to specify broadcastable of*"
):
specify_broadcastable
(
x
,
2
)
class
TestRopLop
(
RopLopChecker
):
class
TestRopLop
(
RopLopChecker
):
def
test_shape
(
self
):
def
test_shape
(
self
):
self
.
check_nondiff_rop
(
self
.
x
.
shape
[
0
])
self
.
check_nondiff_rop
(
self
.
x
.
shape
[
0
])
...
...
tests/tensor/test_subtensor.py
浏览文件 @
2066065d
...
@@ -1501,11 +1501,11 @@ class TestIncSubtensor:
...
@@ -1501,11 +1501,11 @@ class TestIncSubtensor:
# This one should work
# This one should work
f
(
rng_randX
(
3
,
1
),
rng_randX
(
1
))
f
(
rng_randX
(
3
,
1
),
rng_randX
(
1
))
# These ones should not
# These ones should not
with
pytest
.
raises
(
Value
Error
):
with
pytest
.
raises
(
Assertion
Error
):
f
(
rng_randX
(
3
,
1
),
rng_randX
(
2
))
f
(
rng_randX
(
3
,
1
),
rng_randX
(
2
))
with
pytest
.
raises
(
Value
Error
):
with
pytest
.
raises
(
Assertion
Error
):
f
(
rng_randX
(
3
,
1
),
rng_randX
(
3
))
f
(
rng_randX
(
3
,
1
),
rng_randX
(
3
))
with
pytest
.
raises
(
Value
Error
):
with
pytest
.
raises
(
Assertion
Error
):
f
(
rng_randX
(
3
,
1
),
rng_randX
(
0
))
f
(
rng_randX
(
3
,
1
),
rng_randX
(
0
))
def
test_simple_3d
(
self
):
def
test_simple_3d
(
self
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论