Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
7f8af9bc
提交
7f8af9bc
authored
5月 09, 2022
作者:
Ricardo
提交者:
Brandon T. Willard
7月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Deprecate remaining uses of Rebroadcast in favor of Unbroadcast
上级
ac52d689
隐藏空白字符变更
内嵌
并排
正在显示
18 个修改的文件
包含
337 行增加
和
538 行删除
+337
-538
__init__.py
aesara/__init__.py
+1
-1
pfunc.py
aesara/compile/function/pfunc.py
+2
-2
ifelse.py
aesara/ifelse.py
+2
-3
dispatch.py
aesara/link/jax/dispatch.py
+5
-14
tensor_basic.py
aesara/link/numba/dispatch/tensor_basic.py
+5
-14
basic.py
aesara/scan/basic.py
+4
-4
basic.py
aesara/tensor/basic.py
+5
-214
basic_opt.py
aesara/tensor/basic_opt.py
+33
-76
shape.py
aesara/tensor/shape.py
+105
-0
subtensor_opt.py
aesara/tensor/subtensor_opt.py
+12
-11
test_jax.py
tests/link/test_jax.py
+2
-11
test_numba.py
tests/link/test_numba.py
+12
-33
test_printing.py
tests/scan/test_printing.py
+15
-15
test_basic.py
tests/tensor/test_basic.py
+1
-78
test_basic_opt.py
tests/tensor/test_basic_opt.py
+41
-28
test_shape.py
tests/tensor/test_shape.py
+63
-0
test_subtensor_opt.py
tests/tensor/test_subtensor_opt.py
+26
-32
test_rop.py
tests/test_rop.py
+3
-2
没有找到文件。
aesara/__init__.py
浏览文件 @
7f8af9bc
...
@@ -147,7 +147,7 @@ from aesara.updates import OrderedUpdates
...
@@ -147,7 +147,7 @@ from aesara.updates import OrderedUpdates
def
get_scalar_constant_value
(
v
):
def
get_scalar_constant_value
(
v
):
"""Return the constant scalar (i.e. 0-D) value underlying variable `v`.
"""Return the constant scalar (i.e. 0-D) value underlying variable `v`.
If `v` is the output of dim-shuffles, fills, allocs,
rebroadcasts, cast
If `v` is the output of dim-shuffles, fills, allocs,
cast, etc.
this function digs through them.
this function digs through them.
If ``aesara.sparse`` is also there, we will look over CSM `Op`.
If ``aesara.sparse`` is also there, we will look over CSM `Op`.
...
...
aesara/compile/function/pfunc.py
浏览文件 @
7f8af9bc
...
@@ -204,8 +204,8 @@ def rebuild_collect_shared(
...
@@ -204,8 +204,8 @@ def rebuild_collect_shared(
err_sug
=
(
err_sug
=
(
"If the difference is related to the broadcast pattern,"
"If the difference is related to the broadcast pattern,"
" you can call the"
" you can call the"
" tensor.unbroadcast(var, axis_to_unbroadcast[, ...])"
" tensor.
shape.
unbroadcast(var, axis_to_unbroadcast[, ...])"
" function to
remove
broadcastable dimensions."
" function to
mask
broadcastable dimensions."
)
)
raise
TypeError
(
err_msg
,
err_sug
)
raise
TypeError
(
err_msg
,
err_sug
)
...
...
aesara/ifelse.py
浏览文件 @
7f8af9bc
...
@@ -23,8 +23,7 @@ from aesara.configdefaults import config
...
@@ -23,8 +23,7 @@ from aesara.configdefaults import config
from
aesara.graph.basic
import
Apply
,
Variable
,
clone_replace
,
is_in_ancestors
from
aesara.graph.basic
import
Apply
,
Variable
,
clone_replace
,
is_in_ancestors
from
aesara.graph.op
import
_NoPythonOp
from
aesara.graph.op
import
_NoPythonOp
from
aesara.graph.opt
import
GlobalOptimizer
,
in2out
,
local_optimizer
from
aesara.graph.opt
import
GlobalOptimizer
,
in2out
,
local_optimizer
from
aesara.tensor
import
basic
from
aesara.tensor.shape
import
Reshape
,
Shape
,
SpecifyShape
,
Unbroadcast
from
aesara.tensor.shape
import
Reshape
,
Shape
,
SpecifyShape
__docformat__
=
"restructedtext en"
__docformat__
=
"restructedtext en"
...
@@ -451,7 +450,7 @@ acceptable_ops = (
...
@@ -451,7 +450,7 @@ acceptable_ops = (
Shape
,
Shape
,
SpecifyShape
,
SpecifyShape
,
Reshape
,
Reshape
,
basic
.
Re
broadcast
,
Un
broadcast
,
at
.
math
.
Dot
,
at
.
math
.
Dot
,
at
.
math
.
MaxAndArgmax
,
at
.
math
.
MaxAndArgmax
,
at
.
subtensor
.
Subtensor
,
at
.
subtensor
.
Subtensor
,
...
...
aesara/link/jax/dispatch.py
浏览文件 @
7f8af9bc
...
@@ -29,7 +29,6 @@ from aesara.tensor.basic import (
...
@@ -29,7 +29,6 @@ from aesara.tensor.basic import (
Eye
,
Eye
,
Join
,
Join
,
MakeVector
,
MakeVector
,
Rebroadcast
,
ScalarFromTensor
,
ScalarFromTensor
,
TensorFromScalar
,
TensorFromScalar
,
)
)
...
@@ -50,7 +49,7 @@ from aesara.tensor.math import Dot, MaxAndArgmax
...
@@ -50,7 +49,7 @@ from aesara.tensor.math import Dot, MaxAndArgmax
from
aesara.tensor.nlinalg
import
SVD
,
Det
,
Eig
,
Eigh
,
MatrixInverse
,
QRFull
from
aesara.tensor.nlinalg
import
SVD
,
Det
,
Eig
,
Eigh
,
MatrixInverse
,
QRFull
from
aesara.tensor.nnet.basic
import
LogSoftmax
,
Softmax
,
SoftmaxGrad
from
aesara.tensor.nnet.basic
import
LogSoftmax
,
Softmax
,
SoftmaxGrad
from
aesara.tensor.random.op
import
RandomVariable
from
aesara.tensor.random.op
import
RandomVariable
from
aesara.tensor.shape
import
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
from
aesara.tensor.shape
import
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
,
Unbroadcast
from
aesara.tensor.slinalg
import
Cholesky
,
Solve
,
SolveTriangular
from
aesara.tensor.slinalg
import
Cholesky
,
Solve
,
SolveTriangular
from
aesara.tensor.subtensor
import
(
from
aesara.tensor.subtensor
import
(
AdvancedIncSubtensor
,
AdvancedIncSubtensor
,
...
@@ -347,20 +346,12 @@ def jax_funcify_SpecifyShape(op, **kwargs):
...
@@ -347,20 +346,12 @@ def jax_funcify_SpecifyShape(op, **kwargs):
return
specifyshape
return
specifyshape
@jax_funcify.register
(
Rebroadcast
)
@jax_funcify.register
(
Unbroadcast
)
def
jax_funcify_Rebroadcast
(
op
,
**
kwargs
):
def
jax_funcify_Unbroadcast
(
op
,
**
kwargs
):
op_axis
=
op
.
axis
def
unbroadcast
(
x
):
def
rebroadcast
(
x
):
for
axis
,
value
in
op_axis
.
items
():
if
value
and
x
.
shape
[
axis
]
!=
1
:
raise
ValueError
(
"Dimension
%
s in Rebroadcast's input was"
" supposed to be 1 (got
%
s instead)"
%
(
axis
,
x
.
shape
[
axis
])
)
return
x
return
x
return
re
broadcast
return
un
broadcast
@jax_funcify.register
(
ViewOp
)
@jax_funcify.register
(
ViewOp
)
...
...
aesara/link/numba/dispatch/tensor_basic.py
浏览文件 @
7f8af9bc
...
@@ -14,10 +14,10 @@ from aesara.tensor.basic import (
...
@@ -14,10 +14,10 @@ from aesara.tensor.basic import (
Eye
,
Eye
,
Join
,
Join
,
MakeVector
,
MakeVector
,
Rebroadcast
,
ScalarFromTensor
,
ScalarFromTensor
,
TensorFromScalar
,
TensorFromScalar
,
)
)
from
aesara.tensor.shape
import
Unbroadcast
@numba_funcify.register
(
AllocEmpty
)
@numba_funcify.register
(
AllocEmpty
)
...
@@ -195,22 +195,13 @@ def makevector({", ".join(input_names)}):
...
@@ -195,22 +195,13 @@ def makevector({", ".join(input_names)}):
return
numba_basic
.
numba_njit
(
makevector_fn
)
return
numba_basic
.
numba_njit
(
makevector_fn
)
@numba_funcify.register
(
Rebroadcast
)
@numba_funcify.register
(
Unbroadcast
)
def
numba_funcify_Rebroadcast
(
op
,
**
kwargs
):
def
numba_funcify_Unbroadcast
(
op
,
**
kwargs
):
# Make sure op_axis only has ints. This way we can avoid literal_unroll
# which causes a segfault, see GH issue https://github.com/numba/numba/issues/8215
op_axis
=
tuple
((
axis
,
int
(
value
))
for
axis
,
value
in
op
.
axis
.
items
())
@numba_basic.numba_njit
@numba_basic.numba_njit
def
rebroadcast
(
x
):
def
unbroadcast
(
x
):
for
axis
,
value
in
op_axis
:
if
value
and
x
.
shape
[
axis
]
!=
1
:
raise
ValueError
(
(
"Dimension in Rebroadcast's input was supposed to be 1"
)
)
return
x
return
x
return
re
broadcast
return
un
broadcast
@numba_funcify.register
(
TensorFromScalar
)
@numba_funcify.register
(
TensorFromScalar
)
...
...
aesara/scan/basic.py
浏览文件 @
7f8af9bc
...
@@ -14,7 +14,7 @@ from aesara.scan.utils import expand_empty, safe_new, until
...
@@ -14,7 +14,7 @@ from aesara.scan.utils import expand_empty, safe_new, until
from
aesara.tensor.basic
import
get_scalar_constant_value
from
aesara.tensor.basic
import
get_scalar_constant_value
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.math
import
minimum
from
aesara.tensor.math
import
minimum
from
aesara.tensor.shape
import
shape_padleft
from
aesara.tensor.shape
import
shape_padleft
,
unbroadcast
from
aesara.tensor.type
import
TensorType
,
integer_dtypes
from
aesara.tensor.type
import
TensorType
,
integer_dtypes
from
aesara.updates
import
OrderedUpdates
from
aesara.updates
import
OrderedUpdates
...
@@ -751,7 +751,7 @@ def scan(
...
@@ -751,7 +751,7 @@ def scan(
# defined in scan utils
# defined in scan utils
sit_sot_scan_inputs
.
append
(
sit_sot_scan_inputs
.
append
(
expand_empty
(
expand_empty
(
at
.
unbroadcast
(
shape_padleft
(
actual_arg
),
0
),
unbroadcast
(
shape_padleft
(
actual_arg
),
0
),
actual_n_steps
,
actual_n_steps
,
)
)
)
)
...
@@ -881,7 +881,7 @@ def scan(
...
@@ -881,7 +881,7 @@ def scan(
# this will represent only a slice and it will have one
# this will represent only a slice and it will have one
# dimension less.
# dimension less.
if
isinstance
(
inner_out
.
type
,
TensorType
)
and
return_steps
.
get
(
pos
,
0
)
!=
1
:
if
isinstance
(
inner_out
.
type
,
TensorType
)
and
return_steps
.
get
(
pos
,
0
)
!=
1
:
outputs
[
pos
]
=
at
.
unbroadcast
(
shape_padleft
(
inner_out
),
0
)
outputs
[
pos
]
=
unbroadcast
(
shape_padleft
(
inner_out
),
0
)
if
not
return_list
and
len
(
outputs
)
==
1
:
if
not
return_list
and
len
(
outputs
)
==
1
:
outputs
=
outputs
[
0
]
outputs
=
outputs
[
0
]
...
@@ -1010,7 +1010,7 @@ def scan(
...
@@ -1010,7 +1010,7 @@ def scan(
sit_sot_inner_inputs
.
append
(
new_var
)
sit_sot_inner_inputs
.
append
(
new_var
)
sit_sot_scan_inputs
.
append
(
sit_sot_scan_inputs
.
append
(
expand_empty
(
expand_empty
(
at
.
unbroadcast
(
shape_padleft
(
input
.
variable
),
0
),
unbroadcast
(
shape_padleft
(
input
.
variable
),
0
),
actual_n_steps
,
actual_n_steps
,
)
)
)
)
...
...
aesara/tensor/basic.py
浏览文件 @
7f8af9bc
...
@@ -10,7 +10,7 @@ import warnings
...
@@ -10,7 +10,7 @@ import warnings
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
functools
import
partial
from
functools
import
partial
from
numbers
import
Number
from
numbers
import
Number
from
typing
import
Dict
,
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
cast
as
type_cast
from
typing
import
cast
as
type_cast
import
numpy
as
np
import
numpy
as
np
...
@@ -44,6 +44,7 @@ from aesara.tensor.exceptions import NotScalarConstantError
...
@@ -44,6 +44,7 @@ from aesara.tensor.exceptions import NotScalarConstantError
from
aesara.tensor.shape
import
(
from
aesara.tensor.shape
import
(
Shape
,
Shape
,
Shape_i
,
Shape_i
,
Unbroadcast
,
shape
,
shape
,
shape_padaxis
,
shape_padaxis
,
shape_padleft
,
shape_padleft
,
...
@@ -254,7 +255,7 @@ def get_scalar_constant_value(
...
@@ -254,7 +255,7 @@ def get_scalar_constant_value(
):
):
"""Return the constant scalar(0-D) value underlying variable `v`.
"""Return the constant scalar(0-D) value underlying variable `v`.
If `v` is the output of dimshuffles, fills, allocs,
rebroadcasts
,
If `v` is the output of dimshuffles, fills, allocs,
etc
,
cast, OutputGuard, DeepCopyOp, ScalarFromTensor, ScalarOp, Elemwise
cast, OutputGuard, DeepCopyOp, ScalarFromTensor, ScalarOp, Elemwise
and some pattern with Subtensor, this function digs through them.
and some pattern with Subtensor, this function digs through them.
...
@@ -323,7 +324,7 @@ def get_scalar_constant_value(
...
@@ -323,7 +324,7 @@ def get_scalar_constant_value(
(
(
Alloc
,
Alloc
,
DimShuffle
,
DimShuffle
,
Re
broadcast
,
Un
broadcast
,
# outputguard is only used in debugmode but we
# outputguard is only used in debugmode but we
# keep it here to avoid problems with old pickels.
# keep it here to avoid problems with old pickels.
compile
.
ops
.
OutputGuard
,
compile
.
ops
.
OutputGuard
,
...
@@ -495,7 +496,7 @@ def get_scalar_constant_value(
...
@@ -495,7 +496,7 @@ def get_scalar_constant_value(
gp_broadcastable
=
grandparent
.
type
.
broadcastable
gp_broadcastable
=
grandparent
.
type
.
broadcastable
ndim
=
grandparent
.
type
.
ndim
ndim
=
grandparent
.
type
.
ndim
if
grandparent
.
owner
and
isinstance
(
if
grandparent
.
owner
and
isinstance
(
grandparent
.
owner
.
op
,
Re
broadcast
grandparent
.
owner
.
op
,
Un
broadcast
):
):
ggp_broadcastable
=
grandparent
.
owner
.
inputs
[
0
]
.
broadcastable
ggp_broadcastable
=
grandparent
.
owner
.
inputs
[
0
]
.
broadcastable
l
=
[
l
=
[
...
@@ -616,185 +617,6 @@ class ScalarFromTensor(COp):
...
@@ -616,185 +617,6 @@ class ScalarFromTensor(COp):
scalar_from_tensor
=
ScalarFromTensor
()
scalar_from_tensor
=
ScalarFromTensor
()
class
Rebroadcast
(
COp
):
"""
Change the input's broadcastable fields in some predetermined way.
See Also
--------
unbroadcast <aesara.tensor.unbroadcast>
Notes
-----
Works inplace and works for CudaNdarrayType.
Examples
--------
``Rebroadcast((0, True), (1, False))(x)`` would make `x` broadcastable in
axis 0 and not broadcastable in axis 1.
"""
view_map
=
{
0
:
[
0
]}
_f16_ok
=
True
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version
:
Dict
=
{}
check_input
=
False
__props__
=
(
"axis"
,)
_f16_ok
=
True
def
__init__
(
self
,
*
axis
):
# Sort them to make sure we merge all possible case.
items
=
sorted
(
axis
)
self
.
axis
=
dict
(
items
)
for
axis
,
broad
in
self
.
axis
.
items
():
if
not
isinstance
(
axis
,
(
np
.
integer
,
int
)):
raise
TypeError
(
f
"Rebroadcast needs integer axes. Got {axis}"
)
if
not
isinstance
(
broad
,
(
np
.
bool_
,
bool
)):
raise
TypeError
(
f
"Rebroadcast needs bool for new broadcast pattern. Got {broad}"
)
def
__hash__
(
self
):
# Need special __hash__ as dict aren't hashable.
# no ambiguity because each item key is unique
items
=
sorted
(
self
.
axis
.
items
())
return
hash
((
type
(
self
),
tuple
(
items
)))
def
__str__
(
self
):
return
f
"{self.__class__.__name__}{{{','.join(str(i) for i in self.axis.items())}}}"
def
make_node
(
self
,
x
):
if
self
.
axis
.
keys
()
and
(
x
.
ndim
<=
max
(
self
.
axis
.
keys
())):
raise
ValueError
(
"Trying to rebroadcast non-existent dimension"
)
t
=
x
.
type
.
clone
(
shape
=
[
self
.
axis
.
get
(
i
,
b
)
for
i
,
b
in
enumerate
(
x
.
type
.
broadcastable
)]
)
return
Apply
(
self
,
[
x
],
[
t
()])
def
perform
(
self
,
node
,
inp
,
out_
):
(
x
,)
=
inp
(
out
,)
=
out_
for
axis
,
value
in
self
.
axis
.
items
():
if
value
and
x
.
shape
[
axis
]
!=
1
:
raise
ValueError
(
f
"Dimension {axis} in Rebroadcast's input was"
f
" supposed to be 1 (got {x.shape[axis]} instead)"
)
out
[
0
]
=
x
def
grad
(
self
,
inp
,
grads
):
(
x
,)
=
inp
(
gz
,)
=
grads
# restore the broadcasting pattern of the input
return
(
Rebroadcast
(
*
[
(
axis
,
x
.
type
.
broadcastable
[
axis
])
for
axis
,
value
in
self
.
axis
.
items
()
]
)(
gz
),
)
def
infer_shape
(
self
,
fgraph
,
node
,
ishapes
):
assert
len
(
ishapes
)
==
1
l
=
[]
one
=
aesara
.
tensor
.
basic
.
constant
(
1
)
for
ax
in
range
(
len
(
ishapes
[
0
])):
if
self
.
axis
.
get
(
ax
,
False
):
l
.
append
(
one
)
else
:
l
.
append
(
ishapes
[
0
][
ax
])
return
[
tuple
(
l
)]
def
R_op
(
self
,
inputs
,
eval_points
):
if
eval_points
[
0
]
is
None
:
return
[
None
]
return
self
(
*
eval_points
,
return_list
=
True
)
def
c_code
(
self
,
node
,
nodename
,
inp
,
out
,
sub
):
(
iname
,)
=
inp
(
oname
,)
=
out
fail
=
sub
[
"fail"
]
itype
=
node
.
inputs
[
0
]
.
type
.
__class__
if
itype
in
self
.
c_code_and_version
:
code
,
version
=
self
.
c_code_and_version
[
itype
]
final_code
=
""
for
axis
,
value
in
self
.
axis
.
items
():
if
value
:
final_code
+=
code
%
locals
()
return
(
final_code
+
f
"""
Py_XDECREF({oname});
{oname} = {iname};
Py_XINCREF({oname});
"""
)
raise
NotImplementedError
()
def
c_code_cache_version
(
self
):
version
=
[]
# If any of the c code is unversioned, we have to return ()
# Else, we will return a list of (type name, version) pairs.
for
t
,
(
c
,
v
)
in
sorted
(
self
.
c_code_and_version
.
items
(),
key
=
lambda
pair
:
str
(
pair
[
0
])
):
if
not
v
:
warnings
.
warn
(
f
"Type {t} has C code for Rebroadcast, but it "
"has no version. You should add a 'version' "
"keyword arg when calling "
"register_rebroadcast_c_code."
,
stacklevel
=
2
,
)
return
()
version
.
append
((
str
(
t
),
v
))
if
version
:
version
.
append
(
1
)
return
tuple
(
version
)
def
register_rebroadcast_c_code
(
typ
,
code
,
version
=
()):
"""
Tell Rebroadcast how to generate C code for an Aesara Type.
typ : Aesara type
It must be the Aesara class itself and not an instance of the class.
code : C code
That checks if the dimension
%(axis)
s is of shape 1 for the Aesara type
'typ'. Use
%(iname)
s and
%(oname)
s for the input and output C variable
names respectively, and
%(axis)
s for the axis that we need to check.
This code is put in a loop for all axes.
version
A number indicating the version of the code, for cache.
"""
Rebroadcast
.
c_code_and_version
[
typ
]
=
(
code
,
version
)
register_rebroadcast_c_code
(
TensorType
,
"""
if(PyArray_DIMS(
%(iname)
s)[
%(axis)
s] != 1){
PyErr_Format(PyExc_ValueError,
"Dimension
%(axis)
s in Rebroadcast's input was"
" supposed to be 1 (got
%%
d instead)",
PyArray_DIMS(
%(iname)
s)[
%(axis)
s]);
%(fail)
s
}
"""
,
version
=
1
,
)
# to be removed as we get the epydoc routine-documenting thing going
# to be removed as we get the epydoc routine-documenting thing going
# -JB 20080924
# -JB 20080924
def
_conversion
(
real_value
:
Op
,
name
:
str
)
->
Op
:
def
_conversion
(
real_value
:
Op
,
name
:
str
)
->
Op
:
...
@@ -2254,36 +2076,6 @@ class Split(COp):
...
@@ -2254,36 +2076,6 @@ class Split(COp):
)
)
def
unbroadcast
(
x
,
*
axes
):
"""
Make the input impossible to broadcast in the specified axes.
For example, unbroadcast(x, 0) will make the first dimension
of x not broadcastable. When performing the function, if the length
of x along that dimension is not 1, a ValueError will be raised.
We apply the opt here not to pollute the graph
Parameters
----------
x : tensor_like
Input aesara tensor.
axis : an int or an iterable object such as list or tuple of int values
The dimension along which the tensor x should be unbroadcastable.
If the length of x along these dimensions is not 1, a ValueError will
be raised.
Returns
-------
tensor
A aesara tensor, which is unbroadcastable along the specified dimensions.
"""
x
=
as_tensor_variable
(
x
)
rval
=
Rebroadcast
(
*
[(
axis
,
False
)
for
axis
in
axes
])(
x
)
return
aesara
.
tensor
.
basic_opt
.
apply_rebroadcast_opt
(
rval
)
class
Join
(
COp
):
class
Join
(
COp
):
r"""
r"""
Concatenate multiple `TensorVariable`\s along some axis.
Concatenate multiple `TensorVariable`\s along some axis.
...
@@ -4195,7 +3987,6 @@ __all__ = [
...
@@ -4195,7 +3987,6 @@ __all__ = [
"stack"
,
"stack"
,
"roll"
,
"roll"
,
"join"
,
"join"
,
"unbroadcast"
,
"split"
,
"split"
,
"transpose"
,
"transpose"
,
"extract_constant"
,
"extract_constant"
,
...
...
aesara/tensor/basic_opt.py
浏览文件 @
7f8af9bc
...
@@ -48,7 +48,6 @@ from aesara.tensor.basic import (
...
@@ -48,7 +48,6 @@ from aesara.tensor.basic import (
AllocEmpty
,
AllocEmpty
,
Join
,
Join
,
MakeVector
,
MakeVector
,
Rebroadcast
,
ScalarFromTensor
,
ScalarFromTensor
,
Split
,
Split
,
TensorFromScalar
,
TensorFromScalar
,
...
@@ -77,9 +76,11 @@ from aesara.tensor.shape import (
...
@@ -77,9 +76,11 @@ from aesara.tensor.shape import (
Shape
,
Shape
,
Shape_i
,
Shape_i
,
SpecifyShape
,
SpecifyShape
,
Unbroadcast
,
shape_i
,
shape_i
,
shape_padleft
,
shape_padleft
,
specify_shape
,
specify_shape
,
unbroadcast
,
)
)
from
aesara.tensor.sort
import
TopKOp
from
aesara.tensor.sort
import
TopKOp
from
aesara.tensor.subtensor
import
Subtensor
,
get_idx_list
from
aesara.tensor.subtensor
import
Subtensor
,
get_idx_list
...
@@ -2226,10 +2227,13 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
...
@@ -2226,10 +2227,13 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
@register_useless
@register_useless
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@local_optimizer
([
Rebroadcast
])
@local_optimizer
([
Unbroadcast
])
def
local_useless_rebroadcast
(
fgraph
,
node
):
def
local_useless_unbroadcast
(
fgraph
,
node
):
"""Remove `Rebroadcast` if it does not actually change the broadcasting pattern."""
"""Remove `Unbroadcast` if it does not actually change the broadcasting pattern.
if
isinstance
(
node
.
op
,
Rebroadcast
):
TODO: Implement equivalent rewrite for SpecifyShape
"""
if
isinstance
(
node
.
op
,
Unbroadcast
):
x
=
node
.
inputs
[
0
]
x
=
node
.
inputs
[
0
]
if
x
.
broadcastable
==
node
.
outputs
[
0
]
.
broadcastable
:
if
x
.
broadcastable
==
node
.
outputs
[
0
]
.
broadcastable
:
# No broadcastable flag was modified
# No broadcastable flag was modified
...
@@ -2238,15 +2242,12 @@ def local_useless_rebroadcast(fgraph, node):
...
@@ -2238,15 +2242,12 @@ def local_useless_rebroadcast(fgraph, node):
return
[
x
]
return
[
x
]
else
:
else
:
# Keep the flags that modify something
# Keep the flags that modify something
new_axis
=
{}
new_axes
=
tuple
(
ax
for
ax
in
node
.
op
.
axes
if
x
.
type
.
shape
[
ax
]
==
1
)
for
dim
,
bc
in
node
.
op
.
axis
.
items
():
if
new_axes
==
node
.
op
.
axes
:
if
x
.
broadcastable
[
dim
]
!=
bc
:
new_axis
[
dim
]
=
bc
if
new_axis
==
node
.
op
.
axis
:
# All flags are useful
# All flags are useful
return
return
None
else
:
else
:
r
=
Rebroadcast
(
*
new_axis
.
items
())(
x
)
r
=
unbroadcast
(
x
,
*
new_axes
)
# Copy over stacktrace from previous output
# Copy over stacktrace from previous output
copy_stack_trace
(
node
.
outputs
,
r
)
copy_stack_trace
(
node
.
outputs
,
r
)
return
[
r
]
return
[
r
]
...
@@ -2254,93 +2255,49 @@ def local_useless_rebroadcast(fgraph, node):
...
@@ -2254,93 +2255,49 @@ def local_useless_rebroadcast(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@local_optimizer
([
Re
broadcast
])
@local_optimizer
([
Un
broadcast
])
def
local_
re
broadcast_lift
(
fgraph
,
node
):
def
local_
un
broadcast_lift
(
fgraph
,
node
):
"""
"""
Lifts
Rebroadcast
through unary Elemwise operations,
Lifts
`Unbroadcast`
through unary Elemwise operations,
and merges consecutive
Rebroadcast
s.
and merges consecutive
`Unbroadcast`
s.
Rebroadcast(Elemwise(x)) => Elemwise(Re
broadcast(x))
Unbroadcast(Elemwise(x)) => Elemwise(Un
broadcast(x))
Rebroadcast(Rebroadcast(x)) => Re
broadcast(x)
Unbroadcast(Unbroadcast(x)) => Un
broadcast(x)
TODO: Implement equivalent Elemwise lift for SpecifyShape
"""
"""
op
=
node
.
op
op
=
node
.
op
if
not
isinstance
(
op
,
Re
broadcast
):
if
not
isinstance
(
op
,
Un
broadcast
):
return
False
return
False
inp
=
node
.
inputs
[
0
]
inp
=
node
.
inputs
[
0
]
inode
=
inp
.
owner
inode
=
inp
.
owner
if
inode
and
isinstance
(
inode
.
op
,
Elemwise
)
and
len
(
inode
.
inputs
)
==
1
:
if
inode
and
isinstance
(
inode
.
op
,
Elemwise
)
and
len
(
inode
.
inputs
)
==
1
:
# It may happen that `input` has no client because this optimization
# is called from `apply_rebroadcast_opt`, which in particular is used
# by the `unbroadcast` function before we are in the actual function
# compilation phase.
if
len
(
fgraph
.
clients
.
get
(
inp
,
()))
==
1
:
if
len
(
fgraph
.
clients
.
get
(
inp
,
()))
==
1
:
rebroadcasted
=
Rebroadcast
(
*
list
(
op
.
axis
.
items
()))(
inode
.
inputs
[
0
])
unbroadcasted
=
unbroadcast
(
inode
.
inputs
[
0
],
*
op
.
axes
)
# Copy over stacktrace from previous output (after rebroadcasting)
copy_stack_trace
(
node
.
outputs
,
unbroadcasted
)
# to new output, because an error in the new graph right after
# rebroadcasting must have been caused by the previous rebroadcasting.
copy_stack_trace
(
node
.
outputs
,
rebroadcasted
)
rval
=
inode
.
op
.
make_node
(
re
broadcasted
)
.
outputs
rval
=
inode
.
op
.
make_node
(
un
broadcasted
)
.
outputs
# Copy over stacktrace from previous output (after
re
broadcasting)
# Copy over stacktrace from previous output (after
un
broadcasting)
# and input (after elemwise operation) to new output, because an
# and input (after elemwise operation) to new output, because an
# error in the new graph could have been caused by either of the
# error in the new graph could have been caused by either of the
# two ops.
# two ops.
copy_stack_trace
(
node
.
outputs
+
node
.
inputs
,
rval
)
copy_stack_trace
(
node
.
outputs
+
node
.
inputs
,
rval
)
return
rval
return
rval
if
inode
and
isinstance
(
inode
.
op
,
Rebroadcast
):
# the "axis" specification in the outer Rebroadcast overrides
# the axis of the inner one
axis
=
inode
.
op
.
axis
.
copy
()
axis
.
update
(
op
.
axis
)
iinput
=
inode
.
inputs
[
0
]
rval
=
[
Rebroadcast
(
*
list
(
axis
.
items
()))(
iinput
)]
# Copy over stacktrace from previous output (after second rebroadcast)
if
inode
and
isinstance
(
inode
.
op
,
Unbroadcast
):
# and from previous input (after first rebroadcast op) because an error in
# Merge axis of each unbroadcast
# the new graph could have been caused by either of the two
axis
=
tuple
(
set
(
inode
.
op
.
axes
)
.
union
(
set
(
op
.
axes
)))
# rebroadcast ops.
iinput
=
inode
.
inputs
[
0
]
rval
=
[
unbroadcast
(
iinput
,
*
axis
)]
# Copy over stacktrace from previous output (after second unbroadcasting)
# and from previous input (after first unbroadcasting) because an error in
# the new graph could have been caused by either of the two Unbroadcast ops.
copy_stack_trace
(
node
.
outputs
+
node
.
inputs
,
rval
)
copy_stack_trace
(
node
.
outputs
+
node
.
inputs
,
rval
)
return
rval
return
rval
def
apply_rebroadcast_opt
(
rval
):
"""
Apply as many times as required the optimization local_useless_rebroadcast
and local_rebroadcast_lift.
Parameters
----------
rval: a Variable
Returns
-------
A Variable (the same if no optimization can be applied)
"""
fg
=
FunctionGraph
([],
[])
changed
=
True
while
changed
and
rval
.
owner
:
changed
=
False
rval2
=
local_useless_rebroadcast
.
transform
(
fg
,
rval
.
owner
)
if
rval2
:
assert
len
(
rval2
)
==
1
rval
=
rval2
[
0
]
changed
=
True
if
rval
.
owner
:
rval2
=
local_rebroadcast_lift
.
transform
(
fg
,
rval
.
owner
)
if
rval2
:
assert
len
(
rval2
)
==
1
rval
=
rval2
[
0
]
changed
=
True
return
rval
@register_specialize
@register_specialize
@register_canonicalize
@register_canonicalize
@register_useless
@register_useless
...
...
aesara/tensor/shape.py
浏览文件 @
7f8af9bc
...
@@ -926,3 +926,108 @@ def specify_broadcastable(x, *axes):
...
@@ -926,3 +926,108 @@ def specify_broadcastable(x, *axes):
shape_info
=
[
1
if
i
in
axes
else
None
for
i
in
range
(
len
(
x
.
type
.
shape
))]
shape_info
=
[
1
if
i
in
axes
else
None
for
i
in
range
(
len
(
x
.
type
.
shape
))]
return
specify_shape
(
x
,
shape_info
)
return
specify_shape
(
x
,
shape_info
)
class
Unbroadcast
(
COp
):
"""
Mask static broadcastable dimensions of input as `None`
See Also
--------
unbroadcast <aesara.tensor.shape.unbroadcast>
Examples
--------
``Unbroadcast((1,))(x)`` would make `x` second static dimension be `None`
"""
view_map
=
{
0
:
[
0
]}
_f16_ok
=
True
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version
:
Dict
=
{}
check_input
=
False
__props__
=
(
"axes"
,)
_f16_ok
=
True
def
__init__
(
self
,
*
axis
):
# Sort them to make sure we merge all possible case.
items
=
tuple
(
sorted
(
axis
))
self
.
axes
=
items
for
axis
in
self
.
axes
:
if
not
isinstance
(
axis
,
(
np
.
integer
,
int
)):
raise
TypeError
(
f
"Unbroadcast needs integer axes. Got {axis}"
)
def
__str__
(
self
):
return
f
"{self.__class__.__name__}{{{','.join(str(i) for i in self.axes)}}}"
def
make_node
(
self
,
x
):
x
=
as_tensor_variable
(
x
)
if
x
.
type
.
ndim
<=
max
(
self
.
axes
):
raise
ValueError
(
"Trying to unbroadcast of non-existent dimension"
)
shape
=
[
None
if
(
sh
==
1
and
i
in
self
.
axes
)
else
sh
for
i
,
sh
in
enumerate
(
x
.
type
.
shape
)
]
return
Apply
(
self
,
[
x
],
[
x
.
type
.
clone
(
shape
=
shape
)()])
def
perform
(
self
,
node
,
inp
,
out_
):
(
x
,)
=
inp
(
out
,)
=
out_
out
[
0
]
=
x
def
grad
(
self
,
inp
,
grads
):
(
x
,)
=
inp
(
gz
,)
=
grads
# restore the broadcasting pattern of the input
return
[
specify_shape
(
gz
,
x
.
type
.
shape
)]
def
infer_shape
(
self
,
fgraph
,
node
,
ishapes
):
assert
len
(
ishapes
)
==
1
return
[
tuple
(
ishapes
[
0
])]
def
R_op
(
self
,
inputs
,
eval_points
):
if
eval_points
[
0
]
is
None
:
return
[
None
]
return
self
(
*
eval_points
,
return_list
=
True
)
def
c_code
(
self
,
node
,
nodename
,
inp
,
out
,
sub
):
(
iname
,)
=
inp
(
oname
,)
=
out
return
f
"""
Py_XDECREF({oname});
{oname} = {iname};
Py_XINCREF({oname});
"""
def
c_code_cache_version
(
self
):
return
(
3
,)
def
unbroadcast
(
x
,
*
axes
):
"""
Mask static broadcastable dimensions of input as `None`
Parameters
----------
x : tensor_like
Input aesara tensor.
axis : an int or an iterable object such as list or tuple of int values
The broadcastable dimensions of x that should be unbroadcasted.
Returns
-------
tensor
A aesara tensor, with static broadcastable dimensions masked as `None`
"""
x
=
as_tensor_variable
(
x
)
unbroadcasted_axes
=
[
axis
for
axis
in
axes
if
x
.
type
.
shape
[
axis
]
==
1
]
if
not
unbroadcasted_axes
:
return
x
return
Unbroadcast
(
*
unbroadcasted_axes
)(
x
)
aesara/tensor/subtensor_opt.py
浏览文件 @
7f8af9bc
...
@@ -14,7 +14,6 @@ from aesara.tensor.basic import (
...
@@ -14,7 +14,6 @@ from aesara.tensor.basic import (
ARange
,
ARange
,
Join
,
Join
,
MakeVector
,
MakeVector
,
Rebroadcast
,
ScalarFromTensor
,
ScalarFromTensor
,
TensorFromScalar
,
TensorFromScalar
,
alloc
,
alloc
,
...
@@ -50,9 +49,11 @@ from aesara.tensor.math import (
...
@@ -50,9 +49,11 @@ from aesara.tensor.math import (
from
aesara.tensor.shape
import
(
from
aesara.tensor.shape
import
(
Shape
,
Shape
,
SpecifyShape
,
SpecifyShape
,
Unbroadcast
,
shape_padleft
,
shape_padleft
,
shape_tuple
,
shape_tuple
,
specify_shape
,
specify_shape
,
unbroadcast
,
)
)
from
aesara.tensor.sharedvar
import
TensorSharedVariable
from
aesara.tensor.sharedvar
import
TensorSharedVariable
from
aesara.tensor.subtensor
import
(
from
aesara.tensor.subtensor
import
(
...
@@ -370,7 +371,7 @@ def local_subtensor_lift(fgraph, node):
...
@@ -370,7 +371,7 @@ def local_subtensor_lift(fgraph, node):
Handles the following unary ops:
Handles the following unary ops:
elemwise(x,...)[idx] -> elemwise(x[idx],...)
elemwise(x,...)[idx] -> elemwise(x[idx],...)
when x,... are broadcasted scalar or not broadcasted at all
when x,... are broadcasted scalar or not broadcasted at all
rebroadcast(x)[idx] => re
broadcast(x[idx])
Unbroadcast(x)[idx] => Un
broadcast(x[idx])
"""
"""
if
isinstance
(
node
.
op
,
Subtensor
):
if
isinstance
(
node
.
op
,
Subtensor
):
...
@@ -429,34 +430,34 @@ def local_subtensor_lift(fgraph, node):
...
@@ -429,34 +430,34 @@ def local_subtensor_lift(fgraph, node):
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
ret
)
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
ret
)
return
[
ret
]
return
[
ret
]
if
isinstance
(
u
.
owner
.
op
,
Rebroadcast
):
if
isinstance
(
u
.
owner
.
op
,
Unbroadcast
):
# make sure that Rebroadcast has only 1 input
assert
len
(
u
.
owner
.
inputs
)
==
1
# Subtensor might reduce dim., adapt broadcast pattern accordingly
# Subtensor might reduce dim., adapt broadcast pattern accordingly
new_axis
=
[]
old_axes
=
u
.
owner
.
op
.
axes
new_axes
=
[]
# loop through indices being subtensor-ed
# loop through indices being subtensor-ed
# i indexes broadcastable pattern before subtensor
# i indexes broadcastable pattern before subtensor
# j indexes broadcastable pattern after subtensor
# j indexes broadcastable pattern after subtensor
j
=
0
j
=
0
for
(
i
,
x
)
in
enumerate
(
node
.
op
.
idx_list
):
for
(
i
,
x
)
in
enumerate
(
node
.
op
.
idx_list
):
# if its not a slice, it will reduce the dimension, should
# if it
i
s not a slice, it will reduce the dimension, should
# not appear in the broadcastable dimensions
# not appear in the broadcastable dimensions
if
isinstance
(
x
,
slice
):
if
isinstance
(
x
,
slice
):
new_axis
+=
[(
j
,
u
.
broadcastable
[
i
])]
if
i
in
old_axes
:
new_axes
.
append
(
j
)
j
+=
1
j
+=
1
# now keep the broadcastable pattern of all
# now keep the broadcastable pattern of all
# items not appearing in subtensor list
# items not appearing in subtensor list
for
i
in
range
(
len
(
node
.
op
.
idx_list
),
len
(
u
.
broadcastable
)):
for
i
in
range
(
len
(
node
.
op
.
idx_list
),
len
(
u
.
broadcastable
)):
new_axis
+=
[(
j
,
u
.
broadcastable
[
i
])]
if
i
in
old_axes
:
new_axes
.
append
(
j
)
j
+=
1
j
+=
1
subt_x
=
node
.
op
(
u
.
owner
.
inputs
[
0
],
*
node
.
inputs
[
1
:])
subt_x
=
node
.
op
(
u
.
owner
.
inputs
[
0
],
*
node
.
inputs
[
1
:])
# Copy over previous output stacktrace
# Copy over previous output stacktrace
copy_stack_trace
(
node
.
outputs
[
0
],
subt_x
)
copy_stack_trace
(
node
.
outputs
[
0
],
subt_x
)
rbcast_subt_x
=
Rebroadcast
(
*
new_axis
)(
subt_x
)
rbcast_subt_x
=
unbroadcast
(
subt_x
,
*
new_axes
)
# Copy over previous output stacktrace
# Copy over previous output stacktrace
# and stacktrace from previous unary operation
# and stacktrace from previous unary operation
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
rbcast_subt_x
)
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
rbcast_subt_x
)
...
...
tests/link/test_jax.py
浏览文件 @
7f8af9bc
...
@@ -39,7 +39,7 @@ from aesara.tensor.math import sum as at_sum
...
@@ -39,7 +39,7 @@ from aesara.tensor.math import sum as at_sum
from
aesara.tensor.nnet.basic
import
SoftmaxGrad
from
aesara.tensor.nnet.basic
import
SoftmaxGrad
from
aesara.tensor.random.basic
import
RandomVariable
,
normal
from
aesara.tensor.random.basic
import
RandomVariable
,
normal
from
aesara.tensor.random.utils
import
RandomStream
from
aesara.tensor.random.utils
import
RandomStream
from
aesara.tensor.shape
import
Shape
,
Shape_i
,
SpecifyShape
,
reshape
from
aesara.tensor.shape
import
Shape
,
Shape_i
,
SpecifyShape
,
Unbroadcast
,
reshape
from
aesara.tensor.type
import
(
from
aesara.tensor.type
import
(
dscalar
,
dscalar
,
dvector
,
dvector
,
...
@@ -201,20 +201,11 @@ def test_jax_compile_ops():
...
@@ -201,20 +201,11 @@ def test_jax_compile_ops():
compare_jax_and_py
(
x_fg
,
[])
compare_jax_and_py
(
x_fg
,
[])
x_np
=
np
.
zeros
((
20
,
1
,
1
))
x_np
=
np
.
zeros
((
20
,
1
,
1
))
x
=
at
.
Rebroadcast
((
0
,
False
),
(
1
,
True
),
(
2
,
False
)
)(
at
.
as_tensor_variable
(
x_np
))
x
=
Unbroadcast
(
0
,
2
)(
at
.
as_tensor_variable
(
x_np
))
x_fg
=
FunctionGraph
([],
[
x
])
x_fg
=
FunctionGraph
([],
[
x
])
compare_jax_and_py
(
x_fg
,
[])
compare_jax_and_py
(
x_fg
,
[])
with
config
.
change_flags
(
compute_test_value
=
"off"
):
x
=
at
.
Rebroadcast
((
0
,
True
),
(
1
,
False
),
(
2
,
False
))(
at
.
as_tensor_variable
(
x_np
)
)
x_fg
=
FunctionGraph
([],
[
x
])
with
pytest
.
raises
(
ValueError
):
compare_jax_and_py
(
x_fg
,
[])
x
=
ViewOp
()(
at
.
as_tensor_variable
(
x_np
))
x
=
ViewOp
()(
at
.
as_tensor_variable
(
x_np
))
x_fg
=
FunctionGraph
([],
[
x
])
x_fg
=
FunctionGraph
([],
[
x
])
...
...
tests/link/test_numba.py
浏览文件 @
7f8af9bc
...
@@ -40,7 +40,7 @@ from aesara.tensor import extra_ops, nlinalg, slinalg
...
@@ -40,7 +40,7 @@ from aesara.tensor import extra_ops, nlinalg, slinalg
from
aesara.tensor
import
subtensor
as
at_subtensor
from
aesara.tensor
import
subtensor
as
at_subtensor
from
aesara.tensor.elemwise
import
Elemwise
from
aesara.tensor.elemwise
import
Elemwise
from
aesara.tensor.math
import
All
,
Any
,
Max
,
Mean
,
Min
,
Prod
,
ProdWithoutZeros
,
Sum
from
aesara.tensor.math
import
All
,
Any
,
Max
,
Mean
,
Min
,
Prod
,
ProdWithoutZeros
,
Sum
from
aesara.tensor.shape
import
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
from
aesara.tensor.shape
import
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
,
Unbroadcast
class
MyType
(
Type
):
class
MyType
(
Type
):
...
@@ -769,39 +769,18 @@ def test_ScalarFromTensor(v):
...
@@ -769,39 +769,18 @@ def test_ScalarFromTensor(v):
)
)
@pytest.mark.parametrize
(
def
test_Unbroadcast
():
"v, axis, fails"
,
v
=
set_test_value
(
at
.
row
(),
np
.
array
([[
1.0
,
2.0
]],
dtype
=
config
.
floatX
))
[
g
=
Unbroadcast
(
0
)(
v
)
(
set_test_value
(
at
.
matrix
(),
np
.
array
([[
1.0
]],
dtype
=
config
.
floatX
)),
[(
0
,
True
),
(
1
,
True
)],
False
,
),
(
set_test_value
(
at
.
matrix
(),
np
.
array
([[
1.0
,
2.0
]],
dtype
=
config
.
floatX
)),
[(
0
,
True
),
(
1
,
False
)],
False
,
),
(
set_test_value
(
at
.
matrix
(),
np
.
array
([[
1.0
,
2.0
]],
dtype
=
config
.
floatX
)),
[(
0
,
True
),
(
1
,
True
)],
True
,
),
],
)
def
test_Rebroadcast
(
v
,
axis
,
fails
):
g
=
atb
.
Rebroadcast
(
*
axis
)(
v
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
not
fails
else
pytest
.
raises
(
ValueError
)
compare_numba_and_py
(
with
cm
:
g_fg
,
compare_numba_and_py
(
[
g_fg
,
i
.
tag
.
test_value
[
for
i
in
g_fg
.
inputs
i
.
tag
.
test_value
if
not
isinstance
(
i
,
(
SharedVariable
,
Constant
))
for
i
in
g_fg
.
inputs
],
if
not
isinstance
(
i
,
(
SharedVariable
,
Constant
))
)
],
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
...
tests/scan/test_printing.py
浏览文件 @
7f8af9bc
...
@@ -36,7 +36,7 @@ def test_debugprint_sitsot():
...
@@ -36,7 +36,7 @@ def test_debugprint_sitsot():
| | | | | |k [id D]
| | | | | |k [id D]
| | | | | |Subtensor{int64} [id H]
| | | | | |Subtensor{int64} [id H]
| | | | | |Shape [id I]
| | | | | |Shape [id I]
| | | | | | |
Rebroadcast{(0, False)
} [id J]
| | | | | | |
Unbroadcast{0
} [id J]
| | | | | | |InplaceDimShuffle{x,0} [id K]
| | | | | | |InplaceDimShuffle{x,0} [id K]
| | | | | | |Elemwise{second,no_inplace} [id L]
| | | | | | |Elemwise{second,no_inplace} [id L]
| | | | | | |A [id M]
| | | | | | |A [id M]
...
@@ -45,9 +45,9 @@ def test_debugprint_sitsot():
...
@@ -45,9 +45,9 @@ def test_debugprint_sitsot():
| | | | | |ScalarConstant{0} [id P]
| | | | | |ScalarConstant{0} [id P]
| | | | |Subtensor{int64} [id Q]
| | | | |Subtensor{int64} [id Q]
| | | | |Shape [id R]
| | | | |Shape [id R]
| | | | | |
Rebroadcast{(0, False)
} [id J]
| | | | | |
Unbroadcast{0
} [id J]
| | | | |ScalarConstant{1} [id S]
| | | | |ScalarConstant{1} [id S]
| | | |
Rebroadcast{(0, False)
} [id J]
| | | |
Unbroadcast{0
} [id J]
| | | |ScalarFromTensor [id T]
| | | |ScalarFromTensor [id T]
| | | |Subtensor{int64} [id H]
| | | |Subtensor{int64} [id H]
| | |A [id M] (outer_in_non_seqs-0)
| | |A [id M] (outer_in_non_seqs-0)
...
@@ -91,7 +91,7 @@ def test_debugprint_sitsot_no_extra_info():
...
@@ -91,7 +91,7 @@ def test_debugprint_sitsot_no_extra_info():
| | | | | |k [id D]
| | | | | |k [id D]
| | | | | |Subtensor{int64} [id H]
| | | | | |Subtensor{int64} [id H]
| | | | | |Shape [id I]
| | | | | |Shape [id I]
| | | | | | |
Rebroadcast{(0, False)
} [id J]
| | | | | | |
Unbroadcast{0
} [id J]
| | | | | | |InplaceDimShuffle{x,0} [id K]
| | | | | | |InplaceDimShuffle{x,0} [id K]
| | | | | | |Elemwise{second,no_inplace} [id L]
| | | | | | |Elemwise{second,no_inplace} [id L]
| | | | | | |A [id M]
| | | | | | |A [id M]
...
@@ -100,9 +100,9 @@ def test_debugprint_sitsot_no_extra_info():
...
@@ -100,9 +100,9 @@ def test_debugprint_sitsot_no_extra_info():
| | | | | |ScalarConstant{0} [id P]
| | | | | |ScalarConstant{0} [id P]
| | | | |Subtensor{int64} [id Q]
| | | | |Subtensor{int64} [id Q]
| | | | |Shape [id R]
| | | | |Shape [id R]
| | | | | |
Rebroadcast{(0, False)
} [id J]
| | | | | |
Unbroadcast{0
} [id J]
| | | | |ScalarConstant{1} [id S]
| | | | |ScalarConstant{1} [id S]
| | | |
Rebroadcast{(0, False)
} [id J]
| | | |
Unbroadcast{0
} [id J]
| | | |ScalarFromTensor [id T]
| | | |ScalarFromTensor [id T]
| | | |Subtensor{int64} [id H]
| | | |Subtensor{int64} [id H]
| | |A [id M]
| | |A [id M]
...
@@ -261,7 +261,7 @@ def test_debugprint_nested_scans():
...
@@ -261,7 +261,7 @@ def test_debugprint_nested_scans():
> | | | | | | |*3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1)
> | | | | | | |*3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1)
> | | | | | | |Subtensor{int64} [id BJ]
> | | | | | | |Subtensor{int64} [id BJ]
> | | | | | | |Shape [id BK]
> | | | | | | |Shape [id BK]
> | | | | | | | |
Rebroadcast{(0, False)
} [id BL]
> | | | | | | | |
Unbroadcast{0
} [id BL]
> | | | | | | | |InplaceDimShuffle{x,0} [id BM]
> | | | | | | | |InplaceDimShuffle{x,0} [id BM]
> | | | | | | | |Elemwise{second,no_inplace} [id BN]
> | | | | | | | |Elemwise{second,no_inplace} [id BN]
> | | | | | | | |*2-<TensorType(float64, (None,))> [id BO] -> [id W] (inner_in_non_seqs-0)
> | | | | | | | |*2-<TensorType(float64, (None,))> [id BO] -> [id W] (inner_in_non_seqs-0)
...
@@ -270,9 +270,9 @@ def test_debugprint_nested_scans():
...
@@ -270,9 +270,9 @@ def test_debugprint_nested_scans():
> | | | | | | |ScalarConstant{0} [id BR]
> | | | | | | |ScalarConstant{0} [id BR]
> | | | | | |Subtensor{int64} [id BS]
> | | | | | |Subtensor{int64} [id BS]
> | | | | | |Shape [id BT]
> | | | | | |Shape [id BT]
> | | | | | | |
Rebroadcast{(0, False)
} [id BL]
> | | | | | | |
Unbroadcast{0
} [id BL]
> | | | | | |ScalarConstant{1} [id BU]
> | | | | | |ScalarConstant{1} [id BU]
> | | | | |
Rebroadcast{(0, False)
} [id BL]
> | | | | |
Unbroadcast{0
} [id BL]
> | | | | |ScalarFromTensor [id BV]
> | | | | |ScalarFromTensor [id BV]
> | | | | |Subtensor{int64} [id BJ]
> | | | | |Subtensor{int64} [id BJ]
> | | | |*2-<TensorType(float64, (None,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
> | | | |*2-<TensorType(float64, (None,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
...
@@ -350,7 +350,7 @@ def test_debugprint_nested_scans():
...
@@ -350,7 +350,7 @@ def test_debugprint_nested_scans():
> | | | | | | |*3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1)
> | | | | | | |*3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1)
> | | | | | | |Subtensor{int64} [id BL]
> | | | | | | |Subtensor{int64} [id BL]
> | | | | | | |Shape [id BM]
> | | | | | | |Shape [id BM]
> | | | | | | | |
Rebroadcast{(0, False)
} [id BN]
> | | | | | | | |
Unbroadcast{0
} [id BN]
> | | | | | | | |InplaceDimShuffle{x,0} [id BO]
> | | | | | | | |InplaceDimShuffle{x,0} [id BO]
> | | | | | | | |Elemwise{second,no_inplace} [id BP]
> | | | | | | | |Elemwise{second,no_inplace} [id BP]
> | | | | | | | |*2-<TensorType(float64, (None,))> [id BA] (inner_in_non_seqs-0)
> | | | | | | | |*2-<TensorType(float64, (None,))> [id BA] (inner_in_non_seqs-0)
...
@@ -359,9 +359,9 @@ def test_debugprint_nested_scans():
...
@@ -359,9 +359,9 @@ def test_debugprint_nested_scans():
> | | | | | | |ScalarConstant{0} [id BS]
> | | | | | | |ScalarConstant{0} [id BS]
> | | | | | |Subtensor{int64} [id BT]
> | | | | | |Subtensor{int64} [id BT]
> | | | | | |Shape [id BU]
> | | | | | |Shape [id BU]
> | | | | | | |
Rebroadcast{(0, False)
} [id BN]
> | | | | | | |
Unbroadcast{0
} [id BN]
> | | | | | |ScalarConstant{1} [id BV]
> | | | | | |ScalarConstant{1} [id BV]
> | | | | |
Rebroadcast{(0, False)
} [id BN]
> | | | | |
Unbroadcast{0
} [id BN]
> | | | | |ScalarFromTensor [id BW]
> | | | | |ScalarFromTensor [id BW]
> | | | | |Subtensor{int64} [id BL]
> | | | | |Subtensor{int64} [id BL]
> | | | |*2-<TensorType(float64, (None,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
> | | | |*2-<TensorType(float64, (None,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
...
@@ -487,7 +487,7 @@ def test_debugprint_mitmot():
...
@@ -487,7 +487,7 @@ def test_debugprint_mitmot():
| | | | | | | |k [id G]
| | | | | | | |k [id G]
| | | | | | | |Subtensor{int64} [id K]
| | | | | | | |Subtensor{int64} [id K]
| | | | | | | |Shape [id L]
| | | | | | | |Shape [id L]
| | | | | | | | |
Rebroadcast{(0, False)
} [id M]
| | | | | | | | |
Unbroadcast{0
} [id M]
| | | | | | | | |InplaceDimShuffle{x,0} [id N]
| | | | | | | | |InplaceDimShuffle{x,0} [id N]
| | | | | | | | |Elemwise{second,no_inplace} [id O]
| | | | | | | | |Elemwise{second,no_inplace} [id O]
| | | | | | | | |A [id P]
| | | | | | | | |A [id P]
...
@@ -496,9 +496,9 @@ def test_debugprint_mitmot():
...
@@ -496,9 +496,9 @@ def test_debugprint_mitmot():
| | | | | | | |ScalarConstant{0} [id S]
| | | | | | | |ScalarConstant{0} [id S]
| | | | | | |Subtensor{int64} [id T]
| | | | | | |Subtensor{int64} [id T]
| | | | | | |Shape [id U]
| | | | | | |Shape [id U]
| | | | | | | |
Rebroadcast{(0, False)
} [id M]
| | | | | | | |
Unbroadcast{0
} [id M]
| | | | | | |ScalarConstant{1} [id V]
| | | | | | |ScalarConstant{1} [id V]
| | | | | |
Rebroadcast{(0, False)
} [id M]
| | | | | |
Unbroadcast{0
} [id M]
| | | | | |ScalarFromTensor [id W]
| | | | | |ScalarFromTensor [id W]
| | | | | |Subtensor{int64} [id K]
| | | | | |Subtensor{int64} [id K]
| | | | |A [id P] (outer_in_non_seqs-0)
| | | | |A [id P] (outer_in_non_seqs-0)
...
...
tests/tensor/test_basic.py
浏览文件 @
7f8af9bc
...
@@ -34,7 +34,6 @@ from aesara.tensor.basic import (
...
@@ -34,7 +34,6 @@ from aesara.tensor.basic import (
Join
,
Join
,
MakeVector
,
MakeVector
,
PermuteRowElements
,
PermuteRowElements
,
Rebroadcast
,
ScalarFromTensor
,
ScalarFromTensor
,
Split
,
Split
,
TensorFromScalar
,
TensorFromScalar
,
...
@@ -86,7 +85,6 @@ from aesara.tensor.basic import (
...
@@ -86,7 +85,6 @@ from aesara.tensor.basic import (
triu
,
triu
,
triu_indices
,
triu_indices
,
triu_indices_from
,
triu_indices_from
,
unbroadcast
,
vertical_stack
,
vertical_stack
,
zeros_like
,
zeros_like
,
)
)
...
@@ -104,7 +102,6 @@ from aesara.tensor.type import (
...
@@ -104,7 +102,6 @@ from aesara.tensor.type import (
dscalar
,
dscalar
,
dscalars
,
dscalars
,
dtensor3
,
dtensor3
,
dtensor4
,
dvector
,
dvector
,
fmatrix
,
fmatrix
,
fscalar
,
fscalar
,
...
@@ -337,7 +334,7 @@ TestAllocb4GradBroadcast = makeBroadcastTester(
...
@@ -337,7 +334,7 @@ TestAllocb4GradBroadcast = makeBroadcastTester(
)
)
# Partial un
broadcast of a dimshuffled input
# Partial unbroadcast of a dimshuffled input
TestAllocDimshuffleGradBroadcast
=
makeBroadcastTester
(
TestAllocDimshuffleGradBroadcast
=
makeBroadcastTester
(
name
=
"Allocb4GradTester"
,
name
=
"Allocb4GradTester"
,
op
=
lambda
x
:
alloc
(
x
.
dimshuffle
(
"x"
,
"x"
,
0
),
1
,
s2
,
s3
),
op
=
lambda
x
:
alloc
(
x
.
dimshuffle
(
"x"
,
"x"
,
0
),
1
,
s2
,
s3
),
...
@@ -3223,80 +3220,6 @@ class TestLongTensor:
...
@@ -3223,80 +3220,6 @@ class TestLongTensor:
constant
()[[
val
,
val
]]
constant
()[[
val
,
val
]]
class
TestBroadcast
:
def
test_unbroadcast
(
self
):
# test that the unbroadcast fct don't insert not needed broadcast
# and fuse consecutive Rebroadcast op
x
=
matrix
()
assert
unbroadcast
(
x
,
0
)
is
x
assert
unbroadcast
(
x
,
1
)
is
x
assert
unbroadcast
(
x
,
1
,
0
)
is
x
assert
unbroadcast
(
x
,
0
,
1
)
is
x
x
=
row
()
assert
unbroadcast
(
x
,
0
)
is
not
x
assert
unbroadcast
(
x
,
1
)
is
x
assert
unbroadcast
(
x
,
1
,
0
)
is
not
x
assert
unbroadcast
(
x
,
0
,
1
)
is
not
x
# The first broadcast is remove the broadcast, so the second
# should not make one
assert
unbroadcast
(
unbroadcast
(
x
,
0
),
0
)
.
owner
.
inputs
[
0
]
is
x
# Test that consecutive Rebroadcast op are fused
x
=
TensorType
(
dtype
=
"float64"
,
shape
=
(
True
,
True
))()
assert
unbroadcast
(
unbroadcast
(
x
,
1
),
0
)
.
owner
.
inputs
[
0
]
is
x
def
test_infer_shape
(
self
):
x
=
matrix
()
y
=
unbroadcast
(
x
,
0
)
f
=
aesara
.
function
([
x
],
y
.
shape
)
assert
(
f
(
np
.
zeros
((
2
,
5
),
dtype
=
config
.
floatX
))
==
[
2
,
5
])
.
all
()
topo
=
f
.
maker
.
fgraph
.
toposort
()
if
config
.
mode
!=
"FAST_COMPILE"
:
assert
len
(
topo
)
==
3
assert
isinstance
(
topo
[
0
]
.
op
,
Shape_i
)
assert
isinstance
(
topo
[
1
]
.
op
,
Shape_i
)
assert
isinstance
(
topo
[
2
]
.
op
,
MakeVector
)
x
=
row
()
y
=
unbroadcast
(
x
,
0
)
f
=
aesara
.
function
([
x
],
y
.
shape
)
assert
(
f
(
np
.
zeros
((
1
,
5
),
dtype
=
config
.
floatX
))
==
[
1
,
5
])
.
all
()
topo
=
f
.
maker
.
fgraph
.
toposort
()
if
config
.
mode
!=
"FAST_COMPILE"
:
assert
len
(
topo
)
==
2
assert
isinstance
(
topo
[
0
]
.
op
,
Shape_i
)
assert
isinstance
(
topo
[
1
]
.
op
,
MakeVector
)
class
TestRebroadcast
(
utt
.
InferShapeTester
):
def
test_rebroadcast
(
self
):
rng
=
np
.
random
.
default_rng
(
3453
)
# Rebroadcast
adtens4
=
dtensor4
()
adict
=
[(
0
,
False
),
(
1
,
True
),
(
2
,
False
),
(
3
,
True
)]
adtens4_val
=
rng
.
random
((
2
,
1
,
3
,
1
))
.
astype
(
config
.
floatX
)
self
.
_compile_and_check
(
[
adtens4
],
[
Rebroadcast
(
*
adict
)(
adtens4
)],
[
adtens4_val
],
Rebroadcast
,
warn
=
False
,
)
adtens4_bro
=
TensorType
(
"float64"
,
(
True
,
True
,
True
,
False
))()
bdict
=
[(
0
,
True
),
(
1
,
False
),
(
2
,
False
),
(
3
,
False
)]
adtens4_bro_val
=
rng
.
random
((
1
,
1
,
1
,
3
))
.
astype
(
config
.
floatX
)
self
.
_compile_and_check
(
[
adtens4_bro
],
[
Rebroadcast
(
*
bdict
)(
adtens4_bro
)],
[
adtens4_bro_val
],
Rebroadcast
,
)
def
test_len
():
def
test_len
():
for
shape_
in
[(
5
,),
(
3
,
4
),
(
7
,
4
,
6
)]:
for
shape_
in
[(
5
,),
(
3
,
4
),
(
7
,
4
,
6
)]:
x
=
tensor
(
dtype
=
"floatX"
,
shape
=
(
False
,)
*
len
(
shape_
))
x
=
tensor
(
dtype
=
"floatX"
,
shape
=
(
False
,)
*
len
(
shape_
))
...
...
tests/tensor/test_basic_opt.py
浏览文件 @
7f8af9bc
...
@@ -28,7 +28,6 @@ from aesara.tensor.basic import (
...
@@ -28,7 +28,6 @@ from aesara.tensor.basic import (
Alloc
,
Alloc
,
Join
,
Join
,
MakeVector
,
MakeVector
,
Rebroadcast
,
ScalarFromTensor
,
ScalarFromTensor
,
Split
,
Split
,
TensorFromScalar
,
TensorFromScalar
,
...
@@ -40,7 +39,6 @@ from aesara.tensor.basic import (
...
@@ -40,7 +39,6 @@ from aesara.tensor.basic import (
)
)
from
aesara.tensor.basic_opt
import
(
from
aesara.tensor.basic_opt
import
(
ShapeFeature
,
ShapeFeature
,
apply_rebroadcast_opt
,
assert_op
,
assert_op
,
local_alloc_sink_dimshuffle
,
local_alloc_sink_dimshuffle
,
local_dimshuffle_lift
,
local_dimshuffle_lift
,
...
@@ -92,9 +90,11 @@ from aesara.tensor.shape import (
...
@@ -92,9 +90,11 @@ from aesara.tensor.shape import (
Reshape
,
Reshape
,
Shape_i
,
Shape_i
,
SpecifyShape
,
SpecifyShape
,
Unbroadcast
,
reshape
,
reshape
,
shape
,
shape
,
specify_shape
,
specify_shape
,
unbroadcast
,
)
)
from
aesara.tensor.subtensor
import
(
from
aesara.tensor.subtensor
import
(
AdvancedIncSubtensor1
,
AdvancedIncSubtensor1
,
...
@@ -1898,18 +1898,46 @@ class TestTile:
...
@@ -1898,18 +1898,46 @@ class TestTile:
f
(
data
)
f
(
data
)
class
TestRebroadcast
:
class
TestUnbroadcast
:
def
test_local_useless_rebroadcast
(
self
):
def
setup_method
(
self
):
mode
=
get_default_mode
()
.
including
(
"canonicalize"
)
self
.
mode
=
get_default_mode
()
.
including
(
"canonicalize"
)
v1
=
vector
()
v2
=
vector
()
def
test_local_useless_unbroadcast
(
self
):
j
=
at
.
join
(
0
,
v1
,
v2
)
x1
=
tensor
(
"float64"
,
shape
=
(
1
,
2
))
f
=
function
([
v1
,
v2
],
j
,
mode
=
mode
)
x2
=
tensor
(
"float64"
,
shape
=
(
2
,
1
))
f
([
1
,
2
],
[
3
,
4
,
5
])
unbroadcast_op
=
Unbroadcast
(
0
)
e
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
([
n
for
n
in
e
if
isinstance
(
n
.
op
,
Rebroadcast
)])
==
0
f
=
function
([
x1
],
unbroadcast_op
(
x1
),
mode
=
self
.
mode
)
assert
(
sum
(
isinstance
(
node
.
op
,
Unbroadcast
)
for
node
in
f
.
maker
.
fgraph
.
toposort
())
==
1
)
f
=
function
([
x2
],
unbroadcast_op
(
x2
),
mode
=
self
.
mode
)
assert
(
sum
(
isinstance
(
node
.
op
,
Unbroadcast
)
for
node
in
f
.
maker
.
fgraph
.
toposort
())
==
0
)
def
test_local_unbroadcast_lift
(
self
):
x
=
tensor
(
"float64"
,
shape
=
(
1
,
1
))
y
=
unbroadcast
(
at
.
exp
(
unbroadcast
(
x
,
0
)),
1
)
assert
(
sum
(
isinstance
(
node
.
op
,
Unbroadcast
)
for
node
in
FunctionGraph
([
x
],
[
y
],
copy_inputs
=
False
)
.
toposort
()
)
==
2
)
f
=
function
([
x
],
y
,
mode
=
self
.
mode
)
assert
(
sum
(
isinstance
(
node
.
op
,
Unbroadcast
)
for
node
in
f
.
maker
.
fgraph
.
toposort
())
==
1
)
assert
check_stack_trace
(
f
,
ops_to_check
=
"all"
)
np
.
testing
.
assert_almost_equal
(
f
([[
1
]]),
np
.
exp
([[
1
]])
)
class
TestUselessElemwise
:
class
TestUselessElemwise
:
...
@@ -3167,21 +3195,6 @@ def test_local_useless_alloc():
...
@@ -3167,21 +3195,6 @@ def test_local_useless_alloc():
assert
isinstance
(
topo
[
-
1
]
.
op
,
Alloc
)
assert
isinstance
(
topo
[
-
1
]
.
op
,
Alloc
)
def
test_apply_rebroadcast_opt
():
# Test the `Elemwise` case in `local_rebroadcast_lift` with `fgraph=None`.
# This is called by in `apply_rebroadcast_opt`.
a
=
vector
(
dtype
=
"float32"
)
b
=
tensor
(
"float64"
,
[
True
])
x
=
b
.
astype
(
a
.
dtype
)
broadcastable
=
(
False
,)
axis
=
[(
i
,
broadcastable
[
i
])
for
i
in
range
(
len
(
broadcastable
))]
rval
=
Rebroadcast
(
*
axis
)(
x
)
res
=
apply_rebroadcast_opt
(
rval
)
assert
res
is
rval
@pytest.mark.parametrize
(
"return_index"
,
[
False
])
@pytest.mark.parametrize
(
"return_index"
,
[
False
])
@pytest.mark.parametrize
(
"return_counts"
,
[
False
])
@pytest.mark.parametrize
(
"return_counts"
,
[
False
])
@pytest.mark.parametrize
(
"return_inverse"
,
[
False
])
@pytest.mark.parametrize
(
"return_inverse"
,
[
False
])
...
...
tests/tensor/test_shape.py
浏览文件 @
7f8af9bc
...
@@ -17,12 +17,14 @@ from aesara.tensor.shape import (
...
@@ -17,12 +17,14 @@ from aesara.tensor.shape import (
Reshape
,
Reshape
,
Shape_i
,
Shape_i
,
SpecifyShape
,
SpecifyShape
,
Unbroadcast
,
_specify_shape
,
_specify_shape
,
reshape
,
reshape
,
shape
,
shape
,
shape_i
,
shape_i
,
specify_broadcastable
,
specify_broadcastable
,
specify_shape
,
specify_shape
,
unbroadcast
,
)
)
from
aesara.tensor.subtensor
import
Subtensor
from
aesara.tensor.subtensor
import
Subtensor
from
aesara.tensor.type
import
(
from
aesara.tensor.type
import
(
...
@@ -36,6 +38,7 @@ from aesara.tensor.type import (
...
@@ -36,6 +38,7 @@ from aesara.tensor.type import (
lscalar
,
lscalar
,
matrix
,
matrix
,
scalar
,
scalar
,
tensor
,
tensor3
,
tensor3
,
vector
,
vector
,
)
)
...
@@ -594,3 +597,63 @@ def test_get_vector_length():
...
@@ -594,3 +597,63 @@ def test_get_vector_length():
# Test `SpecifyShape`
# Test `SpecifyShape`
x
=
specify_shape
(
ivector
(),
(
10
,))
x
=
specify_shape
(
ivector
(),
(
10
,))
assert
get_vector_length
(
x
)
==
10
assert
get_vector_length
(
x
)
==
10
class
TestUnbroadcast
:
def
test_basic
(
self
):
x
=
matrix
()
assert
unbroadcast
(
x
,
0
)
is
x
assert
unbroadcast
(
x
,
1
)
is
x
assert
unbroadcast
(
x
,
1
,
0
)
is
x
assert
unbroadcast
(
x
,
0
,
1
)
is
x
x
=
row
()
assert
unbroadcast
(
x
,
0
)
is
not
x
assert
unbroadcast
(
x
,
1
)
is
x
assert
unbroadcast
(
x
,
1
,
0
)
is
not
x
assert
unbroadcast
(
x
,
0
,
1
)
is
not
x
assert
unbroadcast
(
unbroadcast
(
x
,
0
),
0
)
.
owner
.
inputs
[
0
]
is
x
def
test_infer_shape
(
self
):
x
=
matrix
()
y
=
unbroadcast
(
x
,
0
)
f
=
aesara
.
function
([
x
],
y
.
shape
)
assert
(
f
(
np
.
zeros
((
2
,
5
),
dtype
=
config
.
floatX
))
==
[
2
,
5
])
.
all
()
topo
=
f
.
maker
.
fgraph
.
toposort
()
if
config
.
mode
!=
"FAST_COMPILE"
:
assert
len
(
topo
)
==
3
assert
isinstance
(
topo
[
0
]
.
op
,
Shape_i
)
assert
isinstance
(
topo
[
1
]
.
op
,
Shape_i
)
assert
isinstance
(
topo
[
2
]
.
op
,
MakeVector
)
x
=
row
()
y
=
unbroadcast
(
x
,
0
)
f
=
aesara
.
function
([
x
],
y
.
shape
)
assert
(
f
(
np
.
zeros
((
1
,
5
),
dtype
=
config
.
floatX
))
==
[
1
,
5
])
.
all
()
topo
=
f
.
maker
.
fgraph
.
toposort
()
if
config
.
mode
!=
"FAST_COMPILE"
:
assert
len
(
topo
)
==
2
assert
isinstance
(
topo
[
0
]
.
op
,
Shape_i
)
assert
isinstance
(
topo
[
1
]
.
op
,
MakeVector
)
def
test_error_checks
(
self
):
with
pytest
.
raises
(
TypeError
,
match
=
"needs integer axes"
):
Unbroadcast
(
0.0
)
with
pytest
.
raises
(
ValueError
,
match
=
"^Trying to unbroadcast"
):
Unbroadcast
(
1
)(
vector
())
class
TestUnbroadcastInferShape
(
utt
.
InferShapeTester
):
def
test_basic
(
self
):
rng
=
np
.
random
.
default_rng
(
3453
)
adtens4
=
tensor
(
"float64"
,
shape
=
(
1
,
1
,
1
,
None
))
adtens4_val
=
rng
.
random
((
1
,
1
,
1
,
3
))
.
astype
(
config
.
floatX
)
self
.
_compile_and_check
(
[
adtens4
],
[
Unbroadcast
(
0
,
2
)(
adtens4
)],
[
adtens4_val
],
Unbroadcast
,
warn
=
False
,
)
tests/tensor/test_subtensor_opt.py
浏览文件 @
7f8af9bc
...
@@ -16,16 +16,10 @@ from aesara.graph.optdb import OptimizationQuery
...
@@ -16,16 +16,10 @@ from aesara.graph.optdb import OptimizationQuery
from
aesara.graph.type
import
Type
from
aesara.graph.type
import
Type
from
aesara.raise_op
import
Assert
from
aesara.raise_op
import
Assert
from
aesara.tensor
import
inplace
from
aesara.tensor
import
inplace
from
aesara.tensor.basic
import
(
from
aesara.tensor.basic
import
Alloc
,
MakeVector
,
_convert_to_int8
,
make_vector
Alloc
,
MakeVector
,
Rebroadcast
,
_convert_to_int8
,
make_vector
,
)
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
from
aesara.tensor.math
import
Dot
,
add
,
dot
,
exp
,
sqr
from
aesara.tensor.math
import
Dot
,
add
,
dot
,
exp
,
sqr
from
aesara.tensor.shape
import
SpecifyShape
,
_shape
,
shape
,
specify_shape
from
aesara.tensor.shape
import
SpecifyShape
,
Unbroadcast
,
_shape
,
shape
,
specify_shape
from
aesara.tensor.subtensor
import
(
from
aesara.tensor.subtensor
import
(
AdvancedIncSubtensor
,
AdvancedIncSubtensor
,
AdvancedIncSubtensor1
,
AdvancedIncSubtensor1
,
...
@@ -843,61 +837,61 @@ class TestLocalSubtensorLift:
...
@@ -843,61 +837,61 @@ class TestLocalSubtensorLift:
f
([
1
,
2
,
3
],
4
)
# let debugmode test something
f
([
1
,
2
,
3
],
4
)
# let debugmode test something
def
test_basic_8
(
self
):
def
test_basic_8
(
self
):
# Test that Subtensor(
Re
broadcast(x)) gets optimized into
# Test that Subtensor(
Un
broadcast(x)) gets optimized into
#
Re
broadcast(Subtensor(x)).
#
Un
broadcast(Subtensor(x)).
# test basic case
# test basic case
x
=
matrix
(
"x"
)
x
=
row
(
"x"
)
xval
=
np
.
random
.
random
((
1
,
10
))
.
astype
(
config
.
floatX
)
xval
=
np
.
random
.
random
((
1
,
10
))
.
astype
(
config
.
floatX
)
assert
x
.
broadcastable
==
(
Fals
e
,
False
)
assert
x
.
broadcastable
==
(
Tru
e
,
False
)
newx
=
Rebroadcast
((
0
,
True
),
(
1
,
False
)
)(
x
)
newx
=
Unbroadcast
(
0
)(
x
)
assert
newx
.
broadcastable
==
(
Tru
e
,
False
)
assert
newx
.
broadcastable
==
(
Fals
e
,
False
)
f1
=
function
([
x
],
newx
[:
2
,
:
5
],
mode
=
mode_opt
)
f1
=
function
([
x
],
newx
[:
2
,
:
5
],
mode
=
mode_opt
)
# Check stacktrace was copied over correctly after opt was applied
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f1
,
ops_to_check
=
[
Subtensor
,
Re
broadcast
])
assert
check_stack_trace
(
f1
,
ops_to_check
=
[
Subtensor
,
Un
broadcast
])
prog
=
f1
.
maker
.
fgraph
.
toposort
()
prog
=
f1
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
1
]
.
op
,
Re
broadcast
)
assert
isinstance
(
prog
[
1
]
.
op
,
Un
broadcast
)
assert
(
f1
(
xval
)
==
xval
[:
2
,
:
5
])
.
all
()
assert
(
f1
(
xval
)
==
xval
[:
2
,
:
5
])
.
all
()
# corner case 1:
re
broadcast changes dims which are dropped through subtensor
# corner case 1:
Un
broadcast changes dims which are dropped through subtensor
y
=
tensor
4
(
"x"
)
y
=
tensor
(
"float64"
,
shape
=
(
1
,
10
,
1
,
3
),
name
=
"x"
)
yval
=
np
.
random
.
random
((
1
,
10
,
1
,
3
))
.
astype
(
config
.
floatX
)
yval
=
np
.
random
.
random
((
1
,
10
,
1
,
3
))
.
astype
(
config
.
floatX
)
assert
y
.
broadcastable
==
(
False
,
False
,
Fals
e
,
False
)
assert
y
.
broadcastable
==
(
True
,
False
,
Tru
e
,
False
)
newy
=
Rebroadcast
((
0
,
True
),
(
2
,
True
)
)(
y
)
newy
=
Unbroadcast
(
0
,
2
)(
y
)
assert
newy
.
broadcastable
==
(
True
,
False
,
Tru
e
,
False
)
assert
newy
.
broadcastable
==
(
False
,
False
,
Fals
e
,
False
)
f2
=
function
([
y
],
newy
[:,
3
,
0
,
:],
mode
=
mode_opt
)
f2
=
function
([
y
],
newy
[:,
3
,
0
,
:],
mode
=
mode_opt
)
# Check stacktrace was copied over correctly after opt was applied
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f2
,
ops_to_check
=
[
Subtensor
,
Re
broadcast
])
assert
check_stack_trace
(
f2
,
ops_to_check
=
[
Subtensor
,
Un
broadcast
])
prog
=
f2
.
maker
.
fgraph
.
toposort
()
prog
=
f2
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
1
]
.
op
,
Re
broadcast
)
assert
isinstance
(
prog
[
1
]
.
op
,
Un
broadcast
)
assert
(
f2
(
yval
)
==
yval
[:,
3
,
0
,
:])
.
all
()
assert
(
f2
(
yval
)
==
yval
[:,
3
,
0
,
:])
.
all
()
# corner case 2: subtensor idx_list is shorter than resulting broadcast pattern
# corner case 2: subtensor idx_list is shorter than resulting broadcast pattern
f3
=
function
([
y
],
newy
[:,
3
,
0
],
mode
=
mode_opt
)
f3
=
function
([
y
],
newy
[:,
3
,
0
],
mode
=
mode_opt
)
# Check stacktrace was copied over correctly after opt was applied
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f3
,
ops_to_check
=
[
Subtensor
,
Re
broadcast
])
assert
check_stack_trace
(
f3
,
ops_to_check
=
[
Subtensor
,
Un
broadcast
])
prog
=
f3
.
maker
.
fgraph
.
toposort
()
prog
=
f3
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
1
]
.
op
,
Re
broadcast
)
assert
isinstance
(
prog
[
1
]
.
op
,
Un
broadcast
)
assert
(
f3
(
yval
)
==
yval
[:,
3
,
0
])
.
all
()
assert
(
f3
(
yval
)
==
yval
[:,
3
,
0
])
.
all
()
# corner case 3: subtensor idx_list is shorter than
re
broadcast.axis
# corner case 3: subtensor idx_list is shorter than
Un
broadcast.axis
z
=
tensor
4
(
"x"
)
z
=
tensor
(
"float64"
,
shape
=
(
4
,
10
,
3
,
1
),
name
=
"x"
)
zval
=
np
.
random
.
random
((
4
,
10
,
3
,
1
))
.
astype
(
config
.
floatX
)
zval
=
np
.
random
.
random
((
4
,
10
,
3
,
1
))
.
astype
(
config
.
floatX
)
assert
z
.
broadcastable
==
(
False
,
False
,
False
,
Fals
e
)
assert
z
.
broadcastable
==
(
False
,
False
,
False
,
Tru
e
)
newz
=
Rebroadcast
((
3
,
True
)
)(
z
)
newz
=
Unbroadcast
(
3
)(
z
)
assert
newz
.
broadcastable
==
(
False
,
False
,
False
,
Tru
e
)
assert
newz
.
broadcastable
==
(
False
,
False
,
False
,
Fals
e
)
f4
=
function
([
z
],
newz
[:,
3
,
0
],
mode
=
mode_opt
)
f4
=
function
([
z
],
newz
[:,
3
,
0
],
mode
=
mode_opt
)
# Check stacktrace was copied over correctly after opt was applied
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f4
,
ops_to_check
=
[
Subtensor
,
Re
broadcast
])
assert
check_stack_trace
(
f4
,
ops_to_check
=
[
Subtensor
,
Un
broadcast
])
prog
=
f4
.
maker
.
fgraph
.
toposort
()
prog
=
f4
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
1
]
.
op
,
Re
broadcast
)
assert
isinstance
(
prog
[
1
]
.
op
,
Un
broadcast
)
assert
(
f4
(
zval
)
==
zval
[:,
3
,
0
])
.
all
()
assert
(
f4
(
zval
)
==
zval
[:,
3
,
0
])
.
all
()
...
...
tests/test_rop.py
浏览文件 @
7f8af9bc
...
@@ -26,6 +26,7 @@ from aesara.graph.op import Op
...
@@ -26,6 +26,7 @@ from aesara.graph.op import Op
from
aesara.tensor.math
import
argmax
,
dot
from
aesara.tensor.math
import
argmax
,
dot
from
aesara.tensor.math
import
max
as
at_max
from
aesara.tensor.math
import
max
as
at_max
from
aesara.tensor.nnet
import
conv
,
conv2d
from
aesara.tensor.nnet
import
conv
,
conv2d
from
aesara.tensor.shape
import
unbroadcast
from
aesara.tensor.signal.pool
import
Pool
from
aesara.tensor.signal.pool
import
Pool
from
aesara.tensor.type
import
TensorType
,
matrix
,
vector
from
aesara.tensor.type
import
TensorType
,
matrix
,
vector
from
tests
import
unittest_tools
as
utt
from
tests
import
unittest_tools
as
utt
...
@@ -237,11 +238,11 @@ class TestRopLop(RopLopChecker):
...
@@ -237,11 +238,11 @@ class TestRopLop(RopLopChecker):
# vector
# vector
self
.
check_rop_lop
(
self
.
x
[:
4
]
.
dimshuffle
(
"x"
,
0
)
.
sum
(
axis
=
0
),
(
4
,))
self
.
check_rop_lop
(
self
.
x
[:
4
]
.
dimshuffle
(
"x"
,
0
)
.
sum
(
axis
=
0
),
(
4
,))
def
test_
re
broadcast
(
self
):
def
test_
un
broadcast
(
self
):
# I need the sum, because the setup expects the output to be a
# I need the sum, because the setup expects the output to be a
# vector
# vector
self
.
check_rop_lop
(
self
.
check_rop_lop
(
at
.
unbroadcast
(
self
.
x
[:
4
]
.
dimshuffle
(
"x"
,
0
),
0
)
.
sum
(
axis
=
1
),
(
1
,)
unbroadcast
(
self
.
x
[:
4
]
.
dimshuffle
(
"x"
,
0
),
0
)
.
sum
(
axis
=
1
),
(
1
,)
)
)
@pytest.mark.slow
@pytest.mark.slow
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论