Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
63f52536
提交
63f52536
authored
8月 22, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
8月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Split aesara.tensor.rewriting.basic rewrites by their aesara.tensor modules
上级
9704ed42
隐藏空白字符变更
内嵌
并排
正在显示
21 个修改的文件
包含
4458 行增加
和
4282 行删除
+4458
-4282
builders.py
aesara/compile/builders.py
+1
-1
rewriting.py
aesara/scan/rewriting.py
+9
-8
basic.py
aesara/tensor/basic.py
+18
-2
basic_opt.py
aesara/tensor/basic_opt.py
+3
-0
blas.py
aesara/tensor/blas.py
+1
-1
__init__.py
aesara/tensor/rewriting/__init__.py
+3
-0
basic.py
aesara/tensor/rewriting/basic.py
+31
-2299
elemwise.py
aesara/tensor/rewriting/elemwise.py
+946
-0
extra_ops.py
aesara/tensor/rewriting/extra_ops.py
+177
-0
math.py
aesara/tensor/rewriting/math.py
+1
-2
shape.py
aesara/tensor/rewriting/shape.py
+1193
-0
utils.py
aesara/tensor/utils.py
+3
-1
test_builders.py
tests/compile/test_builders.py
+1
-1
test_basic.py
tests/tensor/random/test_basic.py
+1
-1
test_basic.py
tests/tensor/rewriting/test_basic.py
+8
-1963
test_elemwise.py
tests/tensor/rewriting/test_elemwise.py
+1204
-0
test_extra_ops.py
tests/tensor/rewriting/test_extra_ops.py
+302
-0
test_math.py
tests/tensor/rewriting/test_math.py
+1
-1
test_shape.py
tests/tensor/rewriting/test_shape.py
+553
-0
test_elemwise.py
tests/tensor/test_elemwise.py
+1
-1
test_shape.py
tests/tensor/test_shape.py
+1
-1
没有找到文件。
aesara/compile/builders.py
浏览文件 @
63f52536
...
...
@@ -26,7 +26,7 @@ from aesara.graph.null_type import NullType
from
aesara.graph.op
import
HasInnerGraph
,
Op
from
aesara.graph.rewriting.basic
import
in2out
,
node_rewriter
from
aesara.graph.utils
import
MissingInputError
from
aesara.tensor.rewriting.
basic
import
ShapeFeature
from
aesara.tensor.rewriting.
shape
import
ShapeFeature
def
infer_shape
(
outs
,
inputs
,
input_shapes
):
...
...
aesara/scan/rewriting.py
浏览文件 @
63f52536
...
...
@@ -45,8 +45,9 @@ from aesara.tensor.basic import Alloc, AllocEmpty, get_scalar_constant_value
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.math
import
Dot
,
dot
,
maximum
,
minimum
from
aesara.tensor.rewriting
import
basic
as
basic_opt
from
aesara.tensor.rewriting
import
math
as
math_opt
from
aesara.tensor.rewriting.basic
import
constant_folding
,
local_useless_switch
from
aesara.tensor.rewriting.elemwise
import
local_upcast_elemwise_constant_inputs
from
aesara.tensor.rewriting.math
import
local_abs_merge
,
local_mul_switch_sink
from
aesara.tensor.shape
import
shape
from
aesara.tensor.subtensor
import
(
IncSubtensor
,
...
...
@@ -60,11 +61,11 @@ from aesara.tensor.var import TensorConstant, get_unique_value
list_opt_slice
=
[
math_opt
.
local_abs_merge
,
math_opt
.
local_mul_switch_sink
,
basic_opt
.
local_upcast_elemwise_constant_inputs
,
basic_opt
.
local_useless_switch
,
basic_opt
.
constant_folding
,
local_abs_merge
,
local_mul_switch_sink
,
local_upcast_elemwise_constant_inputs
,
local_useless_switch
,
constant_folding
,
]
...
...
@@ -2432,7 +2433,7 @@ scan_seqopt1.register(
scan_eqopt2
.
register
(
"constant_folding_for_scan2"
,
in2out
(
basic_opt
.
constant_folding
,
ignore_newtrees
=
True
),
in2out
(
constant_folding
,
ignore_newtrees
=
True
),
"fast_run"
,
"scan"
,
)
...
...
aesara/tensor/basic.py
浏览文件 @
63f52536
...
...
@@ -29,7 +29,7 @@ from aesara.graph.type import Type
from
aesara.link.c.op
import
COp
from
aesara.link.c.params_type
import
ParamsType
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.printing
import
min_informative_str
,
pprint
from
aesara.printing
import
Printer
,
min_informative_str
,
pprint
,
set_precedence
from
aesara.raise_op
import
CheckAndRaise
,
assert_op
from
aesara.scalar
import
int32
from
aesara.scalar.basic
import
ScalarConstant
,
ScalarVariable
...
...
@@ -1335,7 +1335,8 @@ def infer_broadcastable(shape):
`shape` will be validated and constant folded in order to determine
which dimensions are broadcastable (i.e. equal to ``1``).
"""
from
aesara.tensor.rewriting.basic
import
ShapeFeature
,
topo_constant_folding
from
aesara.tensor.rewriting.basic
import
topo_constant_folding
from
aesara.tensor.rewriting.shape
import
ShapeFeature
def
check_type
(
s
):
if
s
.
type
.
dtype
in
integer_dtypes
:
...
...
@@ -1709,6 +1710,21 @@ class MakeVector(COp):
make_vector
=
MakeVector
()
class
MakeVectorPrinter
(
Printer
):
def
process
(
self
,
r
,
pstate
):
if
r
.
owner
is
None
:
raise
TypeError
(
"Can only print make_vector."
)
elif
isinstance
(
r
.
owner
.
op
,
MakeVector
):
with
set_precedence
(
pstate
):
s
=
[
pstate
.
pprinter
.
process
(
inp
)
for
inp
in
r
.
owner
.
inputs
]
return
f
"[{', '.join(s)}]"
else
:
raise
TypeError
(
"Can only print make_vector."
)
pprint
.
assign
(
MakeVector
,
MakeVectorPrinter
())
@_get_vector_length.register
(
MakeVector
)
def
_get_vector_length_MakeVector
(
op
,
var
):
return
len
(
var
.
owner
.
inputs
)
...
...
aesara/tensor/basic_opt.py
浏览文件 @
63f52536
...
...
@@ -8,3 +8,6 @@ warnings.warn(
)
from
aesara.tensor.rewriting.basic
import
*
# noqa: F401 E402 F403
from
aesara.tensor.rewriting.elemwise
import
*
# noqa: F401 E402 F403
from
aesara.tensor.rewriting.extra_ops
import
*
# noqa: F401 E402 F403
from
aesara.tensor.rewriting.shape
import
*
# noqa: F401 E402 F403
aesara/tensor/blas.py
浏览文件 @
63f52536
...
...
@@ -163,7 +163,7 @@ from aesara.tensor.blas_headers import blas_header_text, blas_header_version
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.math
import
Dot
,
add
,
mul
,
neg
,
sub
from
aesara.tensor.rewriting.
basic
import
local_dimshuffle_lift
from
aesara.tensor.rewriting.
elemwise
import
local_dimshuffle_lift
from
aesara.tensor.shape
import
specify_broadcastable
from
aesara.tensor.type
import
(
DenseTensorType
,
...
...
aesara/tensor/rewriting/__init__.py
浏览文件 @
63f52536
import
aesara.tensor.rewriting.basic
import
aesara.tensor.rewriting.elemwise
import
aesara.tensor.rewriting.extra_ops
import
aesara.tensor.rewriting.math
import
aesara.tensor.rewriting.shape
import
aesara.tensor.rewriting.subtensor
import
aesara.tensor.rewriting.uncanonicalize
aesara/tensor/rewriting/basic.py
浏览文件 @
63f52536
""" Tensor optimizations addressing the ops in basic.py."""
import
logging
import
sys
import
time
import
traceback
from
collections
import
defaultdict
from
io
import
StringIO
from
typing
import
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
import
numpy
as
np
import
aesara
import
aesara.scalar.basic
as
aes
from
aesara
import
compile
from
aesara.compile.ops
import
ViewOp
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
(
Apply
,
Constant
,
Variable
,
ancestors
,
equal_computations
,
io_toposort
,
)
from
aesara.graph.features
import
AlreadyThere
,
Feature
,
ReplaceValidate
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.op
import
compute_test_value
,
get_test_value
from
aesara.graph.basic
import
Constant
,
Variable
from
aesara.graph.rewriting.basic
import
(
GraphRewriter
,
NodeRewriter
,
RemovalNodeRewriter
,
Rewriter
,
check_chain
,
copy_stack_trace
,
in2out
,
node_rewriter
,
)
from
aesara.graph.rewriting.db
import
RewriteDatabase
,
SequenceDB
from
aesara.graph.utils
import
(
InconsistencyError
,
MethodNotDefined
,
TestValueError
,
get_variable_trace_string
,
)
from
aesara.printing
import
Printer
,
pprint
,
set_precedence
from
aesara.graph.rewriting.db
import
RewriteDatabase
from
aesara.raise_op
import
Assert
,
CheckAndRaise
,
assert_op
from
aesara.tensor.basic
import
(
Alloc
,
...
...
@@ -56,53 +29,31 @@ from aesara.tensor.basic import (
alloc
,
as_tensor_variable
,
cast
,
constant
,
extract_constant
,
fill
,
get_scalar_constant_value
,
join
,
ones_like
,
stack
,
switch
,
tensor_copy
,
zeros
,
zeros_like
,
)
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
from
aesara.tensor.exceptions
import
NotScalarConstantError
,
ShapeError
from
aesara.tensor.extra_ops
import
(
BroadcastTo
,
Repeat
,
Unique
,
broadcast_shape
,
broadcast_to
,
)
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.extra_ops
import
broadcast_shape
,
broadcast_to
from
aesara.tensor.math
import
all
as
at_all
from
aesara.tensor.math
import
eq
from
aesara.tensor.shape
import
(
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
,
Unbroadcast
,
shape_i
,
shape_padleft
,
specify_shape
,
unbroadcast
,
)
from
aesara.tensor.shape
import
Shape_i
from
aesara.tensor.sort
import
TopKOp
from
aesara.tensor.subtensor
import
Subtensor
,
get_idx_list
from
aesara.tensor.type
import
(
DenseTensorType
,
TensorType
,
discrete_dtypes
,
integer_dtypes
,
)
from
aesara.tensor.type_other
import
NoneConst
from
aesara.tensor.type
import
DenseTensorType
,
TensorType
from
aesara.tensor.var
import
TensorConstant
from
aesara.utils
import
NoDuplicateOptWarningFilter
if
TYPE_CHECKING
:
from
aesara.tensor.rewriting.shape
import
ShapeFeature
_logger
=
logging
.
getLogger
(
"aesara.tensor.rewriting.basic"
)
_logger
.
addFilter
(
NoDuplicateOptWarningFilter
())
...
...
@@ -164,320 +115,6 @@ def broadcast_like(value, template, fgraph, dtype=None):
return
rval
class
InplaceElemwiseOptimizer
(
GraphRewriter
):
r"""
This is parameterized so that it works for `Elemwise` `Op`\s.
"""
def
__init__
(
self
,
OP
):
self
.
op
=
OP
def
add_requirements
(
self
,
fgraph
):
from
aesara.graph.destroyhandler
import
DestroyHandler
fgraph
.
attach_feature
(
DestroyHandler
())
@classmethod
def
print_profile
(
cls
,
stream
,
prof
,
level
=
0
):
blanc
=
" "
*
level
print
(
blanc
,
cls
.
__name__
,
prof
[
"opt"
]
.
op
,
file
=
stream
)
for
k
in
[
"node_before"
,
"nb_call_replace"
,
"nb_call_validate"
,
"nb_inconsistent"
,
]:
print
(
blanc
,
k
,
prof
[
k
],
file
=
stream
)
ndim
=
prof
[
"ndim"
]
if
ndim
:
print
(
blanc
,
"ndim"
,
"nb"
,
file
=
stream
)
for
n
in
sorted
(
ndim
.
keys
()):
print
(
blanc
,
n
,
ndim
[
n
],
file
=
stream
)
def
apply
(
self
,
fgraph
):
r"""
Attempts to replace all `Elemwise`\s by versions of them that operate
inplace. It operates greedily: for each `Elemwise` that is encountered,
for each output, it tries each input to see if it can operate inplace
on that input. If so, it makes the change and goes to the next output
or `Elemwise`.
Examples
--------
x + y + z -> x += y += z
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
"""
# We should not validate too often as this takes too much time to
# execute!
# It is the _dfs_toposort() fct in aesara/graph/destroyhandler.py
# that takes so much time.
# Should we try to use another lib that does toposort?
# igraph: http://igraph.sourceforge.net/
# networkx: https://networkx.lanl.gov/
# Should we try to use cython?
# Compiling only that fct is not enough, should we try to add the
# deque class too?
# And init the deque and other list to an upper bound number of
# elements?
# Maybe Aesara should do online toposort as in
# http://code.google.com/p/acyclic
#
# The next longest rewriter is the canonizer phase.
# Then I think it is the [io_?]toposort (need to validate) so check if
# the solution is also applicable there.
# We execute `validate` after this number of change.
prof
=
{
"opt"
:
self
,
"node_before"
:
len
(
fgraph
.
apply_nodes
),
"nb_call_replace"
:
0
,
"nb_call_validate"
:
0
,
"nb_inconsistent"
:
0
,
"ndim"
:
defaultdict
(
lambda
:
0
),
}
check_each_change
=
config
.
tensor__insert_inplace_optimizer_validate_nb
if
check_each_change
==
-
1
:
if
len
(
fgraph
.
apply_nodes
)
>
500
:
check_each_change
=
10
else
:
check_each_change
=
1
nb_change_no_validate
=
0
chk
=
fgraph
.
checkpoint
()
if
fgraph
.
update_mapping
:
update_outs
=
[
fgraph
.
outputs
[
i
]
for
i
in
fgraph
.
update_mapping
]
else
:
update_outs
=
[]
protected_inputs
=
[
f
.
protected
for
f
in
fgraph
.
_features
if
isinstance
(
f
,
aesara
.
compile
.
function
.
types
.
Supervisor
)
]
protected_inputs
=
sum
(
protected_inputs
,
[])
# flatten the list
protected_inputs
.
extend
(
fgraph
.
outputs
)
for
node
in
list
(
io_toposort
(
fgraph
.
inputs
,
fgraph
.
outputs
)):
op
=
node
.
op
if
not
isinstance
(
op
,
self
.
op
):
continue
# If big graph and the outputs are scalar, do not make it
# inplace.
if
(
check_each_change
!=
1
and
# If multiple outputs, they must all have the same size,
# so only check the first.
getattr
(
node
.
outputs
[
0
]
.
type
,
"ndim"
,
-
1
)
==
0
):
continue
if
op
.
inplace_pattern
:
# Maybe this isn't needed anymore, but I don't want to
# rish regression now. This case only happen if the
# original node add already some inplace patter and we
# still try to add more pattern.
baseline
=
op
.
inplace_pattern
candidate_outputs
=
[
i
for
i
in
range
(
len
(
node
.
outputs
))
if
i
not
in
baseline
]
# node inputs that are Constant, already destroyed,
# or fgraph protected inputs and fgraph outputs can't be used as
# inplace target.
# Remove here as faster.
candidate_inputs
=
[
i
for
i
in
range
(
len
(
node
.
inputs
))
if
i
not
in
baseline
.
values
()
and
not
isinstance
(
node
.
inputs
[
i
],
Constant
)
and
# the next line should not be costly most of the time.
not
fgraph
.
has_destroyers
([
node
.
inputs
[
i
]])
and
node
.
inputs
[
i
]
not
in
protected_inputs
]
else
:
baseline
=
[]
candidate_outputs
=
list
(
range
(
len
(
node
.
outputs
)))
# node inputs that are Constant, already destroyed,
# fgraph protected inputs and fgraph outputs can't be used as inplace
# target.
# Remove here as faster.
candidate_inputs
=
[
i
for
i
in
range
(
len
(
node
.
inputs
))
if
not
isinstance
(
node
.
inputs
[
i
],
Constant
)
and
not
fgraph
.
has_destroyers
([
node
.
inputs
[
i
]])
and
node
.
inputs
[
i
]
not
in
protected_inputs
]
verbose
=
False
raised_warning
=
not
verbose
for
candidate_output
in
candidate_outputs
:
# If the output of the node can be established as an update
# output of the fgraph, visit the candidate_inputs in an order
# that will improve the chances of making the node operate
# inplace on the input it's meant to update
candidate_out_var
=
node
.
outputs
[
candidate_output
]
sorted_candidate_inputs
=
candidate_inputs
if
candidate_out_var
in
update_outs
:
# The candidate output is an update. Sort the
# variables in candidate_inputs in the following order:
# - Vars corresponding to the actual updated input
# (best case scenario is for the node that procudes
# an update to operate inplace on the variable to
# update)
# - Vars computed inplace on the updates input (second
# best scenario if for the node to work inplace on
# a variable obtained by a chain of inplace on the
# variable to update. In some cases, this will be
# equivalent to operating inplace on the variable to
# update)
# - Remaining variables
updated_inputs
=
[]
for
i
,
f_out
in
enumerate
(
fgraph
.
outputs
):
if
f_out
is
candidate_out_var
and
i
in
fgraph
.
update_mapping
:
updated_inp_idx
=
fgraph
.
update_mapping
[
i
]
updated_inputs
.
append
(
fgraph
.
inputs
[
updated_inp_idx
])
updated_vars
=
[]
vars_from_inplace
=
[]
other_vars
=
[]
for
inp_idx
in
candidate_inputs
:
inp
=
node
.
inputs
[
inp_idx
]
if
inp
in
updated_inputs
:
# the candidate input is the actual updated input
updated_vars
.
append
(
inp_idx
)
elif
(
hasattr
(
fgraph
,
"destroy_handler"
)
and
inp
.
owner
and
any
(
fgraph
.
destroy_handler
.
root_destroyer
.
get
(
up_inp
,
None
)
is
inp
.
owner
for
up_inp
in
updated_inputs
)
):
# the candidate input is a variable computed
# inplace on the updated input via a sequence of
# one or more inplace operations
vars_from_inplace
.
append
(
inp_idx
)
else
:
other_vars
.
append
(
inp_idx
)
sorted_candidate_inputs
=
(
updated_vars
+
vars_from_inplace
+
other_vars
)
for
candidate_input
in
sorted_candidate_inputs
:
# remove inputs that don't have the same dtype as the output
if
(
node
.
inputs
[
candidate_input
]
.
type
!=
node
.
outputs
[
candidate_output
]
.
type
):
continue
inplace_pattern
=
dict
(
baseline
)
inplace_pattern
[
candidate_output
]
=
candidate_input
try
:
if
hasattr
(
op
.
scalar_op
,
"make_new_inplace"
):
new_scal
=
op
.
scalar_op
.
make_new_inplace
(
aes
.
transfer_type
(
*
[
inplace_pattern
.
get
(
i
,
o
.
dtype
)
for
i
,
o
in
enumerate
(
node
.
outputs
)
]
)
)
else
:
new_scal
=
op
.
scalar_op
.
__class__
(
aes
.
transfer_type
(
*
[
inplace_pattern
.
get
(
i
,
None
)
for
i
in
range
(
len
(
node
.
outputs
))
]
)
)
new_outputs
=
self
.
op
(
new_scal
,
inplace_pattern
)(
*
node
.
inputs
,
return_list
=
True
)
new_node
=
new_outputs
[
0
]
.
owner
for
r
,
new_r
in
zip
(
node
.
outputs
,
new_outputs
):
prof
[
"nb_call_replace"
]
+=
1
fgraph
.
replace
(
r
,
new_r
,
reason
=
"inplace_elemwise_optimizer"
)
nb_change_no_validate
+=
1
prof
[
"ndim"
][
candidate_out_var
.
ndim
]
+=
1
if
nb_change_no_validate
>=
check_each_change
:
prof
[
"nb_call_validate"
]
+=
1
fgraph
.
validate
()
chk
=
fgraph
.
checkpoint
()
nb_change_no_validate
=
0
except
(
ValueError
,
InconsistencyError
)
as
e
:
prof
[
"nb_inconsistent"
]
+=
1
if
check_each_change
!=
1
and
not
raised_warning
:
print
(
(
"Some inplace rewriting was not "
"performed due to an unexpected error:"
),
file
=
sys
.
stderr
,
)
print
(
e
,
file
=
sys
.
stderr
)
raised_warning
=
True
fgraph
.
revert
(
chk
)
continue
candidate_inputs
.
remove
(
candidate_input
)
node
=
new_node
baseline
=
inplace_pattern
break
if
nb_change_no_validate
>
0
:
try
:
fgraph
.
validate
()
except
Exception
:
if
not
raised_warning
:
print
(
(
"Some inplace rewriting was not "
"performed due to an unexpected error"
),
file
=
sys
.
stderr
,
)
fgraph
.
revert
(
chk
)
return
prof
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
(
f
"{' ' * level}{self.__class__.__name__} ({self.op})"
,
file
=
stream
,
)
return
inplace_elemwise_optimizer
inplace_elemwise_optimizer
=
InplaceElemwiseOptimizer
(
Elemwise
)
compile
.
optdb
.
register
(
"inplace_elemwise_opt"
,
inplace_elemwise_optimizer
,
"inplace_opt"
,
# for historic reason
"inplace_elemwise_optimizer"
,
"fast_run"
,
"inplace"
,
position
=
75
,
)
def
register_useless
(
node_rewriter
:
Union
[
RewriteDatabase
,
NodeRewriter
,
str
],
*
tags
,
**
kwargs
):
...
...
@@ -585,159 +222,6 @@ def register_specialize_device(
return
node_rewriter
def
apply_local_dimshuffle_lift
(
fgraph
,
var
):
"""
lift recursively
"""
if
not
var
.
owner
:
return
var
new
=
local_dimshuffle_lift
.
transform
(
fgraph
,
var
.
owner
)
if
new
:
return
new
[
0
]
return
var
def
is_dimshuffle_useless
(
new_order
,
input
):
"""
Checks for two types of useless dimshuffles:
1 - dimshuffle all dimensions in order.
2 - dimshuffle a broadcastable dimension.
"""
is_useless
=
True
if
len
(
new_order
)
==
input
.
type
.
ndim
:
all_broadcastable_dims
=
[
i
for
(
i
,
is_broadcastable
)
in
enumerate
(
input
.
type
.
broadcastable
)
if
is_broadcastable
]
+
[
"x"
]
for
i
in
range
(
input
.
type
.
ndim
):
if
new_order
[
i
]
==
i
or
(
i
in
all_broadcastable_dims
and
new_order
[
i
]
in
all_broadcastable_dims
):
is_useless
=
True
else
:
is_useless
=
False
break
else
:
is_useless
=
False
return
is_useless
@register_canonicalize
@register_specialize
@node_rewriter
([
DimShuffle
])
def
local_dimshuffle_lift
(
fgraph
,
node
):
"""
"Lifts" DimShuffle through Elemwise operations and merges
consecutive DimShuffles. Basically, applies the following
transformations on the whole graph:
DimShuffle(Elemwise(x, y)) => Elemwise(DimShuffle(x), DimShuffle(y))
DimShuffle(DimShuffle(x)) => DimShuffle(x)
DimShuffle{0,1,...}(x) => x (when the dimshuffle do nothing)
After this transform, clusters of Elemwise operations are
void of DimShuffle operations.
"""
op
=
node
.
op
if
not
isinstance
(
op
,
DimShuffle
):
return
False
inp
=
node
.
inputs
[
0
]
inode
=
inp
.
owner
new_order
=
op
.
new_order
if
inode
and
isinstance
(
inode
.
op
,
Elemwise
)
and
(
len
(
fgraph
.
clients
[
inp
])
==
1
):
# Don't use make_node to have tag.test_value set.
new_inputs
=
[]
for
inp
in
inode
.
inputs
:
new_inp
=
op
.
__class__
(
inp
.
type
.
broadcastable
,
op
.
new_order
)(
inp
)
new_inputs
.
append
(
apply_local_dimshuffle_lift
(
fgraph
,
new_inp
))
copy_stack_trace
(
node
.
outputs
[
0
],
new_inputs
)
ret
=
inode
.
op
(
*
new_inputs
,
return_list
=
True
)
return
ret
if
inode
and
isinstance
(
inode
.
op
,
DimShuffle
):
new_order
=
[
x
==
"x"
and
"x"
or
inode
.
op
.
new_order
[
x
]
for
x
in
new_order
]
inp
=
inode
.
inputs
[
0
]
if
is_dimshuffle_useless
(
new_order
,
inp
):
return
[
inp
]
elif
inode
and
isinstance
(
inode
.
op
,
DimShuffle
):
ret
=
op
.
__class__
(
inp
.
type
.
broadcastable
,
new_order
)(
inp
)
ret
=
apply_local_dimshuffle_lift
(
fgraph
,
ret
)
copy_stack_trace
(
node
.
outputs
[
0
],
ret
)
return
[
ret
]
@register_canonicalize
@register_specialize
@node_rewriter
([
DimShuffle
])
def
local_useless_dimshuffle_makevector
(
fgraph
,
node
):
r"""Remove `DimShuffle`\s that drop one dimensional broadcastable `MakeVector`s.
This rewrite is needed in order to clean up after
`local_subtensor_remove_broadcastable_index`, which produces a
not-so-intuitive canonical form for `x[0]` when `x.shape == (1,)`
(i.e. one broadcastable dimension): i.e. `x.dimshuffle(())`.
"""
# The `DimShuffle` should be removing the single broadcastable dimension
if
node
.
op
.
new_order
!=
():
return
makevector_out
=
node
.
inputs
[
0
]
if
(
not
makevector_out
.
owner
or
not
isinstance
(
makevector_out
.
owner
.
op
,
MakeVector
)
or
not
makevector_out
.
broadcastable
==
(
True
,)
):
return
assert
len
(
makevector_out
.
owner
.
inputs
)
==
1
return
[
makevector_out
.
owner
.
inputs
[
0
]]
@register_canonicalize
@node_rewriter
([
Reshape
])
def
local_useless_dimshuffle_in_reshape
(
fgraph
,
node
):
"""
Removes useless DimShuffle operation inside Reshape:
reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
reshape(col.dimshuffle(0), shp) => reshape(col, shp)
"""
op
=
node
.
op
if
not
isinstance
(
op
,
Reshape
):
return
False
if
not
(
node
.
inputs
[
0
]
.
owner
is
not
None
and
isinstance
(
node
.
inputs
[
0
]
.
owner
.
op
,
DimShuffle
)
):
return
False
new_order
=
node
.
inputs
[
0
]
.
owner
.
op
.
new_order
inp
=
node
.
inputs
[
0
]
.
owner
.
inputs
[
0
]
broadcastables
=
node
.
inputs
[
0
]
.
broadcastable
new_order_of_nonbroadcast
=
[]
for
i
,
bd
in
zip
(
new_order
,
broadcastables
):
if
not
bd
:
new_order_of_nonbroadcast
.
append
(
i
)
no_change_in_order
=
all
(
new_order_of_nonbroadcast
[
i
]
<=
new_order_of_nonbroadcast
[
i
+
1
]
for
i
in
range
(
len
(
new_order_of_nonbroadcast
)
-
1
)
)
if
no_change_in_order
:
shape
=
node
.
inputs
[
1
]
ret
=
op
.
__class__
(
node
.
outputs
[
0
]
.
ndim
)(
inp
,
shape
)
copy_stack_trace
(
node
.
outputs
[
0
],
ret
)
return
[
ret
]
@register_canonicalize
@register_specialize
@node_rewriter
([
TensorFromScalar
])
...
...
@@ -766,722 +250,6 @@ def local_scalar_tensor_scalar(fgraph, node):
return
[
s
]
class MakeVectorPrinter(Printer):
    """Pretty-printer rendering a `MakeVector` result as ``[a, b, ...]``."""

    def process(self, r, pstate):
        # Only variables produced by a `MakeVector` node can be printed here.
        if r.owner is None or not isinstance(r.owner.op, MakeVector):
            raise TypeError("Can only print make_vector.")
        with set_precedence(pstate):
            printed_inputs = [
                pstate.pprinter.process(inp) for inp in r.owner.inputs
            ]
        return f"[{', '.join(printed_inputs)}]"


pprint.assign(MakeVector, MakeVectorPrinter())
class ShapeFeature(Feature):
    r"""A `Feature` that tracks shape information in a graph.

    This `Feature` aids in the replacement of all `Shape`\s and `Subtensor`\s of `Shape`\s with
    `Shape_i` and `MakeVector` `Op`\s.

    This `Feature` and its associated rewrites have several goals:

    1. to "lift" `Shape`\s to as close to the inputs as possible,
    2. to infer the shape of every node in the graph in terms of the
       input shapes, and
    3. remove fill `Op`\s (e.g. `Second`) from the graph.

    Lifting shapes as close to the inputs as possible is important for
    canonicalization because it is very bad form to have to compute
    something just to know how big it will be.  Firstly, it is a waste
    of time to compute such outputs.  But it is important to get rid
    of these outputs as early as possible in the compilation process
    because the extra computations make it appear as if many internal
    graph nodes have multiple clients.  Many rewrites refuse to
    work on nodes with multiple clients.

    Lifting is done by using an `<Op>.infer_shape` function if one is
    present, or else using a conservative default.  An Op that
    supports shape-lifting should define a infer_shape(self, fgraph, node,
    input_shapes) function.  The argument input_shapes is a tuple of
    tuples... there is an interior tuple for each input to the node.
    The tuple has as many elements as dimensions.  The element in
    position i of tuple j represents the i'th shape component of the
    j'th input.  The function should return a tuple of tuples.  One
    output tuple for each node.output.  Again, the i'th element of the
    j'th output tuple represents the output[j].shape[i] of the
    function.  If an output is not a TensorType, then None should be
    returned instead of a tuple for that output.

    For example the infer_shape for a matrix-matrix product would accept
    input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),).

    Inferring the shape of internal nodes in the graph is important
    for doing size-driven rewrites.  If we know how big various
    intermediate results will be, we can estimate the cost of many Ops
    accurately, and generate c-code that is specific [e.g. unrolled]
    to particular sizes.

    In cases where you cannot figure out the shape, raise a ShapeError.

    Notes
    -----
    Right now there is only the ConvOp that could really take
    advantage of this shape inference, but it is worth it even
    just for the ConvOp.  All that's necessary to do shape
    inference is 1) to mark shared inputs as having a particular
    shape, either via a .tag or some similar hacking; and 2) to
    add an optional In() argument to promise that inputs will
    have a certain shape (or even to have certain shapes in
    certain dimensions).

    We can't automatically infer the shape of shared variables as they can
    change of shape during the execution by default.

    To use this shape information in rewrites, use the
    ``shape_of`` dictionary.

    For example:

    .. code-block:: python

        try:
            shape_of = fgraph.shape_feature.shape_of
        except AttributeError:
            # This can happen when the mode doesn't include the ShapeFeature.
            return

        shape_of_output_zero = shape_of[node.output[0]]

    The ``shape_of_output_zero`` symbol will contain a tuple, whose
    elements are either integers or symbolic integers.

    TODO: check to see if the symbols are necessarily
    non-constant... or are integer literals sometimes Aesara
    constants?? That would be confusing.

    """

    def get_node_infer_shape(self, node):
        """Compute the output shapes of `node` via its `Op.infer_shape`.

        Falls back to `default_infer_shape` when the `Op` has no
        `infer_shape`, or when shape inference raises `ShapeError` (or any
        other exception, depending on ``config.on_shape_error``).
        """
        try:
            shape_infer = node.op.infer_shape
        except AttributeError:
            shape_infer = self.default_infer_shape

        try:
            o_shapes = shape_infer(
                self.fgraph, node, [self.shape_of[r] for r in node.inputs]
            )
        except ShapeError:
            o_shapes = self.default_infer_shape(
                self.fgraph, node, [self.shape_of[r] for r in node.inputs]
            )
        except NotImplementedError as e:
            raise NotImplementedError(
                "Code called by infer_shape failed raising a "
                "NotImplementedError. Raising NotImplementedError to "
                "indicate that a shape cannot be computed is no longer "
                "supported, and one should now use ShapeError "
                f"instead. The original exception message is: {e}"
            ).with_traceback(e.__traceback__)
        except Exception as e:
            msg = (
                f"Failed to infer_shape from Op {node.op}.\nInput shapes: "
                f"{[self.shape_of[r] for r in node.inputs]}\nException encountered during infer_shape: "
                f"{type(e)}\nException message: {str(e)}\nTraceback: {traceback.format_exc()}"
            )
            if config.on_shape_error == "raise":
                raise Exception(msg).with_traceback(e.__traceback__)
            else:
                _logger.warning(msg)
            # Best-effort fallback after logging the failure.
            o_shapes = self.default_infer_shape(
                self.fgraph, node, [self.shape_of[r] for r in node.inputs]
            )

        return o_shapes

    def get_shape(self, var, idx):
        """Rewrites can call this to get a `Shape_i`.

        It is better to call this then use directly ``shape_of[var][idx]``
        as this method should update `shape_of` if needed.

        TODO: Up to now, we don't update it in all cases. Update in all cases.
        """
        r = self.shape_of[var][idx]
        if (
            r.owner
            and isinstance(r.owner.op, Shape_i)
            and r.owner.inputs[0] not in self.fgraph.variables
        ):
            assert var.owner
            node = var.owner
            # recur on inputs
            for i in node.inputs:
                # NOTE(review): if `i.type` has no `ndim`, `None > 0` raises
                # TypeError — presumably all inputs here are tensor-like; verify.
                if getattr(i.type, "ndim", None) > 0:
                    self.get_shape(i, 0)
            o_shapes = self.get_node_infer_shape(node)
            assert len(o_shapes) == len(node.outputs)

            # Only change the variables and dimensions that would introduce
            # extra computation
            for new_shps, out in zip(o_shapes, node.outputs):
                if not hasattr(out.type, "ndim"):
                    continue

                merged_shps = list(self.shape_of[out])
                changed = False
                for i in range(out.type.ndim):
                    n_r = merged_shps[i]
                    if (
                        n_r.owner
                        and isinstance(n_r.owner.op, Shape_i)
                        and n_r.owner.inputs[0] not in self.fgraph.variables
                    ):
                        changed = True
                        merged_shps[i] = new_shps[i]
                if changed:
                    self.set_shape(out, merged_shps, override=True)
            r = self.shape_of[var][idx]
        return r

    def shape_ir(self, i, r):
        """Return symbolic r.shape[i] for tensor variable r, int i."""
        if hasattr(r.type, "shape") and r.type.shape[i] is not None:
            # The dimension is statically known: emit a constant.
            return constant(r.type.shape[i], dtype="int64")
        else:
            # Do not call make_node for test_value
            s = Shape_i(i)(r)
            try:
                s = get_scalar_constant_value(s)
            except NotScalarConstantError:
                pass
            return s

    def shape_tuple(self, r):
        """Return a tuple of symbolic shape vars for tensor variable r."""
        if not hasattr(r.type, "ndim"):
            # This happen for NoneConst.
            return None
        return tuple(self.shape_ir(i, r) for i in range(r.type.ndim))

    def default_infer_shape(self, fgraph, node, i_shapes):
        """Return a list of shape tuple or None for the outputs of node.

        This function is used for Ops that don't implement infer_shape.
        Ops that do implement infer_shape should use the i_shapes parameter,
        but this default implementation ignores it.

        """
        rval = []
        for r in node.outputs:
            try:
                rval.append(self.shape_tuple(r))
            except AttributeError:
                rval.append(None)
        return rval

    def unpack(self, s_i, var):
        """Return a symbolic integer scalar for the shape element s_i.

        The s_i argument was produced by the infer_shape() of an Op subclass.

        var: the variable that correspond to s_i. This is just for
        error reporting.

        """
        assert s_i is not None

        if s_i == 1:
            return self.lscalar_one
        if isinstance(s_i, float) and int(s_i) == s_i:
            s_i = int(s_i)
        if isinstance(s_i, (np.integer, int)) or (
            isinstance(s_i, np.ndarray) and s_i.ndim == 0
        ):
            # this shape is a constant
            if s_i < 0:
                msg = "There is a negative shape in the graph!"
                msg += get_variable_trace_string(var)
                # The rest of the pipeline don't handle correctly this
                # case.  So we have 2 choices, stop compilation or
                # consider the shape as unknown.  As we have more
                # chance to give the stack trace here then later, I
                # choose that options as it would give better error
                # message.
                raise AssertionError(msg)
            return constant(s_i, dtype="int64")
        if isinstance(s_i, (tuple, list)):
            # this dimension is the same as many of the inputs
            # which tells us that if one of the inputs is known,
            # the others all become known.
            # TODO: should be implemented in Elemwise, and Dot
            #
            # worst case, we loop over shape_of and replace things
            raise NotImplementedError(s_i)

        # s_i is x.shape[i] for some x, we change it to shape_of[x][i]
        if (
            s_i.owner
            and isinstance(s_i.owner.op, Subtensor)
            and s_i.owner.inputs[0].owner
            and isinstance(s_i.owner.inputs[0].owner.op, Shape)
        ):
            assert s_i.type.ndim == 0
            assert len(s_i.owner.op.idx_list) == 1

            # The current Subtensor always put constant index in the graph.
            # This was not True in the past. So call the Subtensor function
            # that will return the right index.
            idx = get_idx_list(s_i.owner.inputs, s_i.owner.op.idx_list)
            assert len(idx) == 1
            idx = idx[0]
            try:
                i = get_scalar_constant_value(idx)
            except NotScalarConstantError:
                pass
            else:
                # Executed only if no exception was raised
                x = s_i.owner.inputs[0].owner.inputs[0]
                # x should already have been imported, and should be in shape_of.
                s_i = self.shape_of[x][i]

        if s_i.type.dtype in integer_dtypes:
            if getattr(s_i.type, "ndim", 0):
                raise TypeError("Shape element must be scalar", s_i)
            return s_i
        else:
            raise TypeError(
                "Unsupported shape element", s_i, type(s_i), getattr(s_i, "type", None)
            )

    def set_shape(self, r, s, override=False):
        """Assign the shape `s` to previously un-shaped variable `r`.

        Parameters
        ----------
        r : a variable
        s : None or a tuple of symbolic integers
        override : If False, it mean r is a new object in the fgraph.
            If True, it mean r is already in the fgraph and we want to
            override its shape.

        """
        if not override:
            assert r not in self.shape_of, "r already in shape_of"
        if s is None:
            self.shape_of[r] = s
        else:
            if not isinstance(s, (tuple, list)):
                raise TypeError("shapes must be tuple/list", (r, s))

            if r.type.ndim != len(s):
                sio = StringIO()
                aesara.printing.debugprint(r, file=sio, print_type=True)
                raise AssertionError(
                    f"Something inferred a shape with {len(s)} dimensions "
                    f"for a variable with {int(r.type.ndim)} dimensions"
                    f" for the variable:\n{sio.getvalue()}"
                )

            shape_vars = []
            for i in range(r.type.ndim):
                if hasattr(r.type, "shape") and r.type.shape[i] is not None:
                    # Prefer the statically-known dimension over the inferred one.
                    shape_vars.append(constant(r.type.shape[i], dtype="int64"))
                else:
                    shape_vars.append(self.unpack(s[i], r))
            # Broadcastable dimensions must map to the constant 1.
            assert all(
                not hasattr(r.type, "broadcastable")
                or not r.type.broadcastable[i]
                or self.lscalar_one.equals(shape_vars[i])
                or self.lscalar_one.equals(extract_constant(shape_vars[i]))
                for i in range(r.type.ndim)
            )
            self.shape_of[r] = tuple(shape_vars)
            for sv in shape_vars:
                self.shape_of_reverse_index.setdefault(sv, set()).add(r)

    def update_shape(self, r, other_r):
        """Replace shape of r by shape of other_r.

        If, on some dimensions, the shape of other_r is not informative,
        keep the shape of r on those dimensions.

        """
        # other_r should already have a shape
        assert other_r in self.shape_of, ("other_r not in shape_of", other_r)
        other_shape = self.shape_of[other_r]

        # If other_shape has no information, call is pointless.
        if other_shape is None:
            return

        if r in self.shape_of:
            r_shape = self.shape_of[r]
        else:
            # If no info is known on r's shape, use other_shape
            self.set_shape(r, other_shape)
            return
        if (
            other_r.owner
            and r.owner
            and other_r.owner.inputs == r.owner.inputs
            and other_r.owner.op == r.owner.op
        ):
            # We are doing a merge, so the two shape graphs will be the
            # same.  This is only done so that we call `ancestors` less
            # frequently.
            return

        # Merge other_shape with r_shape, giving the priority to other_shape
        merged_shape = []
        for i, ps in enumerate(other_shape):
            if r_shape is None and other_shape:
                merged_shape.append(other_shape[i])
            elif (
                ps.owner
                and isinstance(getattr(ps.owner, "op", None), Shape_i)
                and ps.owner.op.i == i
                and ps.owner.inputs[0] in (r, other_r)
            ):
                # If other_shape[i] is uninformative, use r_shape[i].
                # For now, we consider 2 cases of uninformative other_shape[i]:
                #  - Shape_i(i)(other_r);
                #  - Shape_i(i)(r).
                merged_shape.append(r_shape[i])
            elif isinstance(r_shape[i], (Constant, int)):
                # We do this to call less often ancestors and make
                # sure we have the simplest shape possible.
                merged_shape.append(r_shape[i])
            elif isinstance(other_shape[i], (Constant, int)):
                # We do this to call less often ancestors and make
                # sure we have the simplest shape possible.
                merged_shape.append(other_shape[i])
            elif other_shape[i] == r_shape[i]:
                # This mean the shape is equivalent
                # We do not want to do the ancestor check in those cases
                merged_shape.append(r_shape[i])
            elif r_shape[i] in ancestors([other_shape[i]]):
                # Another case where we want to use r_shape[i] is when
                # other_shape[i] actually depends on r_shape[i]. In that case,
                # we do not want to substitute an expression with another that
                # is strictly more complex. Such a substitution could also lead
                # to cycles: if (in the future) r_shape[i] gets replaced by an
                # expression of other_shape[i], other_shape[i] may end up
                # depending on itself.
                merged_shape.append(r_shape[i])
            else:
                merged_shape.append(other_shape[i])
        assert all(
            (
                not hasattr(r.type, "broadcastable")
                or not r.type.broadcastable[i]
                and not other_r.type.broadcastable[i]
            )
            or self.lscalar_one.equals(merged_shape[i])
            or self.lscalar_one.equals(
                extract_constant(merged_shape[i], only_process_constants=True)
            )
            for i in range(r.type.ndim)
        )
        self.shape_of[r] = tuple(merged_shape)
        for sv in self.shape_of[r]:
            self.shape_of_reverse_index.setdefault(sv, set()).add(r)

    def set_shape_i(self, r, i, s_i):
        """Replace element i of shape_of[r] by s_i"""
        assert r in self.shape_of
        prev_shape = self.shape_of[r]
        # prev_shape is a tuple, so we cannot change it inplace,
        # so we build another one.
        new_shape = []
        for j, s_j in enumerate(prev_shape):
            if j == i:
                new_shape.append(self.unpack(s_i, r))
            else:
                new_shape.append(s_j)
        assert all(
            not hasattr(r.type, "broadcastable")
            or not r.type.broadcastable[idx]
            or self.lscalar_one.equals(new_shape[idx])
            or self.lscalar_one.equals(extract_constant(new_shape[idx]))
            for idx in range(r.type.ndim)
        )
        self.shape_of[r] = tuple(new_shape)
        for sv in self.shape_of[r]:
            self.shape_of_reverse_index.setdefault(sv, set()).add(r)

    def init_r(self, r):
        """Register r's shape in the shape_of dictionary."""
        if r not in self.shape_of:
            self.set_shape(r, self.shape_tuple(r))

    def make_vector_shape(self, r):
        """Return `r`'s tracked shape as a single 1-d ``int64`` vector."""
        return as_tensor_variable(self.shape_of[r], ndim=1, dtype="int64")

    def on_attach(self, fgraph):
        """Attach this feature to `fgraph` and seed shapes for all nodes."""
        if hasattr(fgraph, "shape_feature"):
            raise AlreadyThere("This FunctionGraph already has a ShapeFeature")

        if hasattr(self, "fgraph") and self.fgraph != fgraph:
            raise Exception("This ShapeFeature is already attached to a graph")

        self.fgraph = fgraph

        fgraph.shape_feature = self
        # Must be local to the object as otherwise we reuse the same
        # variable for multiple fgraph!
        self.lscalar_one = constant(1, dtype="int64")
        assert self.lscalar_one.type.dtype == "int64"

        # NOTE(review): `self.fgraph` was already assigned above; this second
        # assignment is redundant but harmless.
        self.fgraph = fgraph
        # Variable -> tuple(scalars) or None  (All tensor vars map to tuple)
        self.shape_of = {}
        # Variable ->
        self.scheduled = {}
        # shape var -> graph v
        self.shape_of_reverse_index = {}

        for node in fgraph.toposort():
            self.on_import(fgraph, node, reason="on_attach")

    def on_detach(self, fgraph):
        """Drop all tracked state and unhook this feature from `fgraph`."""
        self.shape_of = {}
        self.scheduled = {}
        self.shape_of_reverse_index = {}
        self.fgraph = None
        del fgraph.shape_feature

    def on_import(self, fgraph, node, reason):
        """Record (inferred) shapes for all of `node`'s outputs."""
        if node.outputs[0] in self.shape_of:
            # this is a revert, not really an import
            for r in node.outputs + node.inputs:
                assert r in self.shape_of
            return

        for i, r in enumerate(node.inputs):
            # make sure we have shapes for the inputs
            self.init_r(r)

        o_shapes = self.get_node_infer_shape(node)

        # this is packed information
        # an element of o_shapes is either None or a tuple
        #   elements of the tuple can be either strings, or ints
        if len(o_shapes) != len(node.outputs):
            raise Exception(
                (
                    f'The infer_shape method for the Op "{node.op}" returned a list '
                    f"with the wrong number of element: len(o_shapes) = {len(o_shapes)} "
                    f" != len(node.outputs) = {len(node.outputs)}"
                )
            )

        # Ensure shapes are in 'int64'. This is to make sure the assert
        # found in the `local_useless_subtensor` rewrite does not fail.
        for sh_idx, sh in enumerate(o_shapes):
            if sh is None:
                continue
            if not isinstance(sh, (list, tuple)):
                raise ValueError(
                    f"infer_shape of {node} didn't return a list of"
                    f" list. It returned '{o_shapes}'"
                )
            new_shape = []
            for i, d in enumerate(sh):
                # Note: we ignore any shape element that is not typed (i.e.,
                # does not have a 'dtype' attribute). This means there may
                # still remain int elements that are int32 on 32-bit platforms,
                # but this works with `local_useless_subtensor`, so for now we
                # keep it this way. See #266 for a better long-term fix.
                if getattr(d, "dtype", "int64") != "int64":
                    assert d.dtype in discrete_dtypes, (node, d.dtype)
                    assert str(d.dtype) != "uint64", node
                    new_shape += sh[len(new_shape) : i + 1]
                    if isinstance(d, Constant):
                        casted_d = constant(d.data, dtype="int64")
                    else:
                        casted_d = cast(d, "int64")
                    new_shape[i] = casted_d
            if new_shape:
                # We replace the shape with wrong dtype by the one with
                # 'int64'.
                new_shape += sh[len(new_shape) :]
                o_shapes[sh_idx] = tuple(new_shape)

        for r, s in zip(node.outputs, o_shapes):
            self.set_shape(r, s)

    def on_change_input(self, fgraph, node, i, r, new_r, reason):
        """Propagate shape information when `r` is replaced by `new_r`."""
        if new_r not in self.shape_of:
            # It happen that the fgraph didn't called on_import for some
            # new_r.  This happen when new_r don't have an
            # owner(i.e. it is a constant or an input of the graph)
            # update_shape suppose that r and new_r are in shape_of.
            self.init_r(new_r)

        # This tells us that r and new_r must have the same shape if
        # we didn't know that the shapes are related, now we do.
        self.update_shape(new_r, r)

        # change_input happens in two cases:
        # 1) we are trying to get rid of r, or
        # 2) we are putting things back after a failed transaction.

        # In case 1, if r has a shape_i client, we will want to
        # replace the shape_i of r with the shape of new_r.  Say that
        # r is *scheduled*.
        # At that point, node is no longer a client of r, but of new_r
        for (shpnode, idx) in fgraph.clients[r] + [(node, i)]:
            if isinstance(getattr(shpnode, "op", None), Shape_i):
                idx = shpnode.op.i
                repl = self.shape_of[new_r][idx]
                if repl.owner is shpnode:
                    # This mean the replacement shape object is
                    # exactly the same as the current shape object. So
                    # no need for replacement.
                    continue
                if (
                    repl.owner
                    and repl.owner.inputs[0] is shpnode.inputs[0]
                    and isinstance(repl.owner.op, Shape_i)
                    and repl.owner.op.i == shpnode.op.i
                ):
                    # The replacement is a shape_i of the same
                    # input. So no need to do this equivalent
                    # replacement.
                    continue

                if shpnode.outputs[0] in ancestors([repl]):
                    raise InconsistencyError(
                        "This substitution would insert a cycle in the graph:"
                        f"node: {node}, i: {i}, r: {r}, new_r: {new_r}"
                    )

                self.scheduled[shpnode] = new_r
        # In case 2, if r is a variable that we've scheduled for shape update,
        # then we should cancel it.
        unscheduled = [k for k, v in self.scheduled.items() if v == r]
        for k in unscheduled:
            del self.scheduled[k]

        # In either case, r could be in shape_of.values(), that is, r itself
        # is the shape of something. In that case, we want to update
        # the value in shape_of, to keep it up-to-date.
        for v in self.shape_of_reverse_index.get(r, []):
            # The reverse index is only approximate. It is not updated on
            # deletion of variables, or on change_input so it might be the
            # case that there are a few extra `v`'s in it that no longer have
            # a shape of r or possibly have been deleted from shape_of
            # entirely. The important thing is that it permits to recall
            # all variables with r in their shape.
            for ii, svi in enumerate(self.shape_of.get(v, [])):
                if svi == r:
                    self.set_shape_i(v, ii, new_r)
        self.shape_of_reverse_index[r] = set()

    def same_shape(
        self,
        x: Variable,
        y: Variable,
        dim_x: Optional[int] = None,
        dim_y: Optional[int] = None,
    ) -> bool:
        """Return ``True`` if `x` and `y` have the same shape.

        Parameters
        ==========
        x
            The `Variable` for which its shape is to be compared with `y`'s shape.
        y
            The `Variable` for which its shape is to be compared with `x`'s shape.
        dim_x
            If non ``None``, compare only the dimension of `x` equal to
            `dim_x`.
        dim_y
            If non ``None``, compare only the dimension of `y` equal to
            `dim_y`.

        """
        sx = self.shape_of[x]
        sy = self.shape_of[y]

        if sx is None or sy is None:
            return False

        if dim_x is not None:
            sx = [sx[dim_x]]

        if dim_y is not None:
            sy = [sy[dim_y]]

        if len(sx) != len(sy):
            return False

        # Canonicalize the graphs so that comparisons are reasonable
        # TODO FIXME: This should *not* need to be performed manually here.
        # Instead, the shape information in `self.shape_of` should be operated
        # upon alongside all the other elements in a `FunctionGraph` (e.g. as
        # if `self.shape_of.values()` were additional outputs).
        shapes_fg = FunctionGraph(
            outputs=sx + sy,
            # features=[self],
            clone=True,
            # copy_inputs=False,
        )
        from aesara.graph.rewriting.utils import rewrite_graph

        canon_shapes = rewrite_graph(
            shapes_fg, custom_rewrite=topo_constant_folding
        ).outputs

        sx = canon_shapes[: len(sx)]
        sy = canon_shapes[len(sx) :]

        for dx, dy in zip(sx, sy):
            if not equal_computations([dx], [dy]):
                return False

        return True

    def clone(self):
        """Return a fresh, unattached instance of this feature class."""
        return type(self)()
class ShapeOptimizer(GraphRewriter):
    """Rewriter that adds `ShapeFeature` as a feature."""

    def add_requirements(self, fgraph):
        # Attaching the feature is the whole purpose of this rewriter.
        fgraph.attach_feature(ShapeFeature())

    def apply(self, fgraph):
        # Nothing to rewrite directly: the work happens via the attached
        # `ShapeFeature` and the node rewrites that consult it.
        pass
class UnShapeOptimizer(GraphRewriter):
    """Rewriter that removes `ShapeFeature` as a feature."""

    def apply(self, fgraph):
        """Detach every `ShapeFeature` instance from `fgraph`.

        Iterate over a snapshot of ``fgraph._features``: `remove_feature`
        mutates that list, and removing entries while iterating the live
        list would skip elements.
        """
        for feature in list(fgraph._features):
            if isinstance(feature, ShapeFeature):
                fgraph.remove_feature(feature)
# Register it after merge1 optimization at 0. We don't want to track
# the shape of merged node.
aesara.compile.mode.optdb.register(
    "ShapeOpt", ShapeOptimizer(), "fast_run", "fast_compile", position=0.1
)
# Not enabled by default for now. Some crossentropy opt use the
# shape_feature. They are at step 2.01. uncanonicalize is at step
# 3. After it goes to 48.5 that move to the gpu. So 10 seems reasonable.
aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)
@register_specialize
(
"local_alloc_elemwise"
)
@node_rewriter
([
Elemwise
])
def
local_elemwise_alloc
(
fgraph
,
node
):
...
...
@@ -1815,43 +583,6 @@ compile.optdb.register(
)
@register_specialize
@register_canonicalize
@node_rewriter([Shape])
def local_shape_to_shape_i(fgraph, node):
    """Replace a `Shape` node by the `ShapeFeature`'s symbolic shape vector.

    Requires a `ShapeFeature` to be attached to `fgraph`; otherwise the
    rewrite does not apply.
    """
    if not isinstance(node.op, Shape):
        return
    if not hasattr(fgraph, "shape_feature"):
        return
    shape_feature = fgraph.shape_feature
    ret = shape_feature.make_vector_shape(node.inputs[0])
    # Preserve debugging information from the replaced output.
    copy_stack_trace(node.outputs[0], ret)
    return [ret]
@register_specialize
@register_canonicalize
@node_rewriter([Shape_i])
def local_track_shape_i(fgraph, node):
    """Replace a scheduled `Shape_i` node by the tracked shape of its replacement.

    The `ShapeFeature` schedules `Shape_i` nodes whose input variable was
    swapped out; this rewrite performs the corresponding substitution.
    """
    if not isinstance(node.op, Shape_i):
        return False

    try:
        shape_feature = fgraph.shape_feature
    except AttributeError:
        # No ShapeFeature attached: nothing scheduled, nothing to do.
        return False

    if node not in shape_feature.scheduled:
        return False

    # Don't unschedule node as it could be reinserted in the
    # fgraph as we don't change it in the shapefeature internal
    # structure.
    replacement = shape_feature.scheduled[node]
    return [shape_feature.shape_of[replacement][node.op.i]]
@register_useless
@register_canonicalize
(
"fast_compile"
)
@register_specialize
...
...
@@ -2130,153 +861,6 @@ compile.optdb["useless"].register(
)
@register_canonicalize
@node_rewriter([Elemwise])
def local_upcast_elemwise_constant_inputs(fgraph, node):
    """This explicitly upcasts constant inputs to elemwise Ops, when
    those Ops do implicit upcasting anyway.

    Rationale: it helps merge things like (1-x) and (1.0 - x).
    """
    if len(node.outputs) > 1:
        return
    try:
        shape_i = fgraph.shape_feature.shape_i
    except AttributeError:
        # No ShapeFeature attached; only the broadcast-constant path below
        # can be used in that case.
        shape_i = None
    if isinstance(node.op, Elemwise):
        scalar_op = node.op.scalar_op
        # print "aa", scalar_op.output_types_preference
        if getattr(scalar_op, "output_types_preference", None) in (
            aes.upgrade_to_float,
            aes.upcast_out,
        ):
            # this is the kind of op that we can screw with the input
            # dtypes by upcasting explicitly
            output_dtype = node.outputs[0].type.dtype
            new_inputs = []
            for i in node.inputs:
                if i.type.dtype == output_dtype:
                    new_inputs.append(i)
                else:
                    try:
                        # works only for scalars
                        cval_i = get_scalar_constant_value(
                            i, only_process_constants=True
                        )
                        if all(i.broadcastable):
                            # Fully broadcastable constant: cast and pad to
                            # the input's rank.
                            new_inputs.append(
                                shape_padleft(cast(cval_i, output_dtype), i.ndim)
                            )
                        else:
                            if shape_i is None:
                                return
                            # Materialize the cast constant at the input's
                            # (symbolic) shape.
                            new_inputs.append(
                                alloc(
                                    cast(cval_i, output_dtype),
                                    *[shape_i(d)(i) for d in range(i.ndim)],
                                )
                            )
                            # print >> sys.stderr, "AAA",
                            # *[Shape_i(d)(i) for d in range(i.ndim)]
                    except NotScalarConstantError:
                        # for the case of a non-scalar
                        if isinstance(i, TensorConstant):
                            new_inputs.append(cast(i, output_dtype))
                        else:
                            new_inputs.append(i)
            if new_inputs != node.inputs:
                rval = [node.op(*new_inputs)]
                if not node.outputs[0].type.is_super(rval[0].type):
                    # This can happen for example when floatX=float32
                    # and we do the true division between and int64
                    # and a constant that will get typed as int8.

                    # As this is just to allow merging more case, if
                    # the upcast don't work, we can just skip it.
                    return

                # Copy over output stacktrace from before upcasting
                copy_stack_trace(node.outputs[0], rval)
                return rval
@register_useless
@register_canonicalize
@register_specialize
@node_rewriter([Unbroadcast])
def local_useless_unbroadcast(fgraph, node):
    """Remove `Unbroadcast` if it does not actually change the broadcasting pattern.

    TODO: Implement equivalent rewrite for SpecifyShape
    """
    if not isinstance(node.op, Unbroadcast):
        return None

    x = node.inputs[0]
    if x.broadcastable == node.outputs[0].broadcastable:
        # The op changed nothing: bypass it entirely.  No need to copy the
        # stack trace, because `x` should already carry one.
        return [x]

    # Keep only the axes on which the op actually has an effect
    # (i.e. those with a static length of 1).
    effective_axes = tuple(ax for ax in node.op.axes if x.type.shape[ax] == 1)
    if effective_axes == node.op.axes:
        # Every axis is useful: nothing to simplify.
        return None

    r = unbroadcast(x, *effective_axes)
    # Copy over stacktrace from previous output
    copy_stack_trace(node.outputs, r)
    return [r]
@register_canonicalize
@register_specialize
@node_rewriter([Unbroadcast])
def local_unbroadcast_lift(fgraph, node):
    """
    Lifts `Unbroadcast` through unary Elemwise operations,
    and merges consecutive `Unbroadcast`s.

    Unbroadcast(Elemwise(x)) => Elemwise(Unbroadcast(x))
    Unbroadcast(Unbroadcast(x)) => Unbroadcast(x)

    TODO: Implement equivalent Elemwise lift for SpecifyShape
    """
    outer_op = node.op
    if not isinstance(outer_op, Unbroadcast):
        return False

    inner_var = node.inputs[0]
    inner_node = inner_var.owner

    # Case 1: lift through a unary Elemwise, but only when the intermediate
    # Elemwise result has no other clients.
    if (
        inner_node is not None
        and isinstance(inner_node.op, Elemwise)
        and len(inner_node.inputs) == 1
        and len(fgraph.clients.get(inner_var, ())) == 1
    ):
        lifted = unbroadcast(inner_node.inputs[0], *outer_op.axes)
        copy_stack_trace(node.outputs, lifted)

        rval = inner_node.op.make_node(lifted).outputs

        # Copy over stacktrace from previous output (after unbroadcasting)
        # and input (after elemwise operation) to new output, because an
        # error in the new graph could have been caused by either of the
        # two ops.
        copy_stack_trace(node.outputs + node.inputs, rval)
        return rval

    # Case 2: collapse two stacked Unbroadcasts into one over the union of
    # their axes.
    if inner_node is not None and isinstance(inner_node.op, Unbroadcast):
        merged_axes = tuple(set(inner_node.op.axes) | set(outer_op.axes))
        rval = [unbroadcast(inner_node.inputs[0], *merged_axes)]
        # Copy over stacktrace from previous output (after second unbroadcasting)
        # and from previous input (after first unbroadcasting) because an error in
        # the new graph could have been caused by either of the two Unbroadcast ops.
        copy_stack_trace(node.outputs + node.inputs, rval)
        return rval
@register_specialize
@register_canonicalize
@register_useless
...
...
@@ -2412,7 +996,7 @@ def local_useless_switch(fgraph, node):
if
not
isinstance
(
node
.
op
.
scalar_op
,
aes
.
Switch
):
return
False
shape_feature
:
Optional
[
ShapeFeature
]
=
getattr
(
fgraph
,
"shape_feature"
,
None
)
shape_feature
:
Optional
[
"ShapeFeature"
]
=
getattr
(
fgraph
,
"shape_feature"
,
None
)
if
shape_feature
is
None
:
return
False
...
...
@@ -2537,225 +1121,6 @@ def local_useless_split(fgraph, node):
return
[
out2
]
def local_reshape_chain(op):
    """Build a rewriter collapsing ``op(op(x, shape1), shape2)`` into ``op(x, shape2)``."""

    @node_rewriter([op])
    def f(fgraph, node):
        """
        Reshape(Reshape(shape1),shape2) -> Reshape(shape2)
        """
        if not check_chain(node, op, op):
            return False

        # TODO: this can permit a failing program to run by eliminating
        # the lower reshape
        inner_input = node.inputs[0].owner.inputs[0]
        outer_shape = node.inputs[1]
        rval = node.op(inner_input, outer_shape)

        # Copy over stacktrace from previous output node, as any error
        # in new computational graph would have been caused by last op
        # in the old computational graph.
        copy_stack_trace(node.outputs, rval)

        # It might happen that the desired output of this node has a
        # broadcastable pattern that does not match that of 'rval'. This is
        # when originally, we were able to figure out that one of the
        # dimensions of the reshape is one, but some other transformation
        # replaced the shape by one for which this cannot be guessed.
        # We should try to figure out why we lost the information about this
        # constant value... but in the meantime, better not apply this
        # rewrite.
        if rval.broadcastable != node.outputs[0].broadcastable:
            return False
        return [rval]

    return f


register_canonicalize(local_reshape_chain(Reshape), name="local_reshape_chain")
@register_useless
@register_canonicalize
@register_stabilize
@node_rewriter([Reshape])
def local_useless_reshape(fgraph, node):
    """
    Remove two kinds of useless reshape.

    Remove Reshape when both the input and output have a single dimension.
    Remove Reshape when reshaping to the shape of the input.

    """
    op = node.op
    if not isinstance(op, Reshape):
        return False

    inp = node.inputs[0]
    output = node.outputs[0]
    output_shape = node.inputs[1]

    # A reshape can only be removed when it does not change the rank.
    if inp.ndim != output.ndim:
        return False

    # Simple case: both input and output have a single dimension.
    # This could hide errors if the user provides inconsistent shapes.
    if (
        inp.ndim == 1
        and output.ndim == 1
        and inp.broadcastable == output.broadcastable
    ):
        return [inp]

    # Second case: all the shapes match the input shape
    # Match Reshape(x, x.shape)
    if output_shape.owner and isinstance(output_shape.owner.op, Shape):
        shape_input = output_shape.owner.inputs[0]
        if shape_input == inp:
            return [inp]

    # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
    # broadcastable and constant dimensions
    if output_shape.owner and isinstance(output_shape.owner.op, MakeVector):
        output_shape_is = output_shape.owner.inputs

        shape_feature = getattr(fgraph, "shape_feature", None)

        # Number of `-1` entries seen; NumPy reshape semantics allow at most one.
        nb_m1 = 0
        shape_match = [False] * inp.ndim
        for dim in range(inp.ndim):
            outshp_i = output_shape_is[dim]
            # Match Shape_i{dim}(input)
            if (
                outshp_i.owner
                and isinstance(outshp_i.owner.op, Shape_i)
                and outshp_i.owner.op.i == dim
                and outshp_i.owner.inputs[0] == inp
            ):
                shape_match[dim] = True
                continue

            # Match Shape(input)[dim]
            if (
                outshp_i.owner
                and isinstance(outshp_i.owner.op, Subtensor)
                and len(outshp_i.owner.inputs) == 2
                and extract_constant(outshp_i.owner.inputs[1]) == dim
            ):
                subtensor_inp = outshp_i.owner.inputs[0]
                if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape):
                    shape_input_i = subtensor_inp.owner.inputs[0]
                    if shape_input_i == inp:
                        shape_match[dim] = True
                        continue

            # Match 1 if input.broadcastable[dim] is True
            cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
            if inp.broadcastable[dim] and cst_outshp_i == 1:
                shape_match[dim] = True
                continue

            # Match -1
            if cst_outshp_i == -1:
                shape_match[dim] = True
                nb_m1 += 1
                continue

            # Match shape_of[input][dim] or its constant equivalent
            if shape_feature:
                inpshp_i = shape_feature.get_shape(inp, dim)
                if inpshp_i == outshp_i or (
                    extract_constant(inpshp_i, only_process_constants=1)
                    == extract_constant(outshp_i, only_process_constants=1)
                ):
                    shape_match[dim] = True
                    continue

        if all(shape_match) and nb_m1 <= 1:
            return [inp]

        # TODO later: if all the shapes except one match, we may want to
        # consider it useless as well, like we do in the 1-dim case.
        return False
@register_canonicalize
@node_rewriter([Reshape])
def local_reshape_to_dimshuffle(fgraph, node):
    """
    Broadcastable dimensions in Reshape are replaced with dimshuffle.

    The goal is to avoid using reshape to add or remove broadcastable
    dimensions, but use dimshuffle instead, so dimshuffles can cancel out
    or be removed later on.

    For example:
        - reshape(x, (1, n)) --> dimshuffle{x,0}(reshape(x, (n,))
        - reshape(x, (1, m, 1, n, 1, 1))
          --> dimshuffle{x,0,x,1,x,x}(reshape(x, (m, n)))
    """
    op = node.op
    if not isinstance(op, Reshape):
        return False

    inp = node.inputs[0]
    output = node.outputs[0]
    output_shape = node.inputs[1]

    dimshuffle_new_order = []
    new_output_shape = []
    index = 0  # index over the output of the new reshape
    for i in range(output.ndim):
        # Since output_shape is a symbolic vector, we trust extract_constant
        # to go through however it is formed to see if its i-th element is 1.
        # We need only_process_constants=False for that.
        dim = extract_constant(
            output_shape[i], only_process_constants=False, elemwise=False
        )
        if dim == 1:
            # Length-1 target dimensions become broadcastable axes added by
            # the `DimShuffle` instead of by the `Reshape`.
            dimshuffle_new_order.append("x")
        else:
            dimshuffle_new_order.append(index)
            new_output_shape.append(dim)
            index = index + 1
    # Only rewrite when at least one length-1 dimension was removed from the
    # reshape target; otherwise the node is left untouched.
    if index != output.ndim:
        inner = op.__class__(len(new_output_shape))(inp, new_output_shape)
        copy_stack_trace(output, inner)
        new_node = [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)]
        copy_stack_trace(output, new_node)
        return new_node
@register_canonicalize
@register_stabilize
@node_rewriter([Reshape])
def local_reshape_lift(fgraph, node):
    """
    Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x))

    Notes
    -----
    This rewrite is needed by `log1msigm_to_softplus` in order to get applied
    when there is a reshape.

    """
    if (
        isinstance(node.op, Reshape)
        and node.inputs[0].owner
        and isinstance(node.inputs[0].owner.op, Elemwise)
        # Only unary `Elemwise` nodes are lifted through.
        and len(node.inputs[0].owner.inputs) == 1
    ):
        r = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
        # Copy stacktrace from previous Reshape op, as an error in new
        # Reshape op could only have been caused by old one.
        copy_stack_trace(node.outputs, r)

        e = node.inputs[0].owner.op(r)
        # Copy stacktrace from both previous Reshape and UnaryElemwise op
        # because an error in new cg could have been caused by either ops.
        copy_stack_trace(node.outputs + node.inputs, e)
        return [e]
register_canonicalize
(
RemovalNodeRewriter
(
tensor_copy
),
name
=
"remove_tensor_copy"
)
@node_rewriter
(
None
)
def
constant_folding
(
fgraph
,
node
):
...
...
@@ -2817,431 +1182,6 @@ register_stabilize(topo_constant_folding, "fast_compile", final_rewriter=True)
register_specialize
(
topo_constant_folding
,
"fast_compile"
,
final_rewriter
=
True
)
def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None):
    r"""Create a recursive function that fuses `Elemwise` `Op`\s.

    The basic idea is that we loop through an `Elemwise` node's inputs, find
    other `Elemwise` nodes, determine the scalars input types for all of the
    `Elemwise` `Op`\s, construct a new scalar `Op` using the scalar input types
    and each `Elemwise`'s scalar `Op`, and use the composite scalar `Op` in a
    new "fused" `Elemwise`.

    It's parameterized in order to work for `Elemwise` `Op`\s.

    Parameters
    ----------
    op_class : type
        `Elemwise` class (the one that we want to fuse)
    max_input_fct : callable
        A function that returns the maximum number of inputs that this `Elemwise`
        can take.
        On the CPU we limit to 32 input variables since that is the maximum
        NumPy support.
    maker: callable
        A function with the signature ``(node, *args)`` that constructs an
        `op_class` instance (e.g. ``op_class(*args)``).

    """
    if maker is None:

        def maker(node, scalar_op):
            return op_class(scalar_op)

    def local_fuse(fgraph, node):
        r"""Fuse `Elemwise` `Op`\s in a node.

        As part of specialization, we fuse two consecutive `Elemwise` `Op`\s of the
        same shape.

        For mixed dtype, we let the `Composite` `Op` do the cast. It lets the C
        compiler do the cast.

        The number of dimensions is validated at call time by Aesara itself.

        """
        # TODO: use broadcast flag?

        # TODO: don't do this rewrite as a `NodeRewriter`.
        # Analyze the graph in terms of elemwise subgraphs, and then
        # replace each subgraph with a Composite version.

        # TODO: use malloc and copy to transfer arguments that don't
        # fit within the parameter space of 256 bytes
        #
        # TODO: Merge with multiple output to merge when an inputs
        # have multiple clients. This can't be done with a `NodeRewriter`

        # TODO: Related: Support composites with multiple outputs

        # TODO: Use Composite to combine Elemwise and Reduce
        # operations.  We have to loop over the data anyway... might
        # as well sum it up while we're at it (this can be trickier
        # than i'm making it seound here. The data-traversal should be
        # done contiguously, and the summing-up might not be easy or
        # worthwhile if the summation axis doesn't line up with a
        # contiguous dimension)

        if type(node.op) is not op_class:
            return False

        if len(node.outputs) > 1:
            # We don't support fusion for nodes with multiple outputs.
            return

        inputs = []  # inputs of the new Elemwise op.
        s_inputs = []  # inputs of the new scalar op used by the Composite.
        # Inputs of the new scalar op that represents the current node.
        s_g = []

        # There is a hard limit of 256 bytes for the formal argument list to a
        # GPU kernel function.
        max_nb_input = max_input_fct(node)
        # The number of inputs to the new fused op if we do not fuse more
        # inputs.
        new_nb_input = len(node.inputs)
        # Did we fuse something?
        # Needed as we can fuse unary op that don't change the number of
        # inputs.
        # And there is a case where the inputs are the same as the current
        # node. That won't change the number of inputs of the new op.
        fused = False

        for i in node.inputs:
            scalar_node: Optional[Apply] = None
            # Will store inputs of the fused node that are not currently inputs
            # of the node we want to create (to avoid duplicating inputs).
            tmp_input = []
            # Same as tmp_input, but for scalars.
            tmp_scalar = []

            # We should not check the number of inputs here
            # As fusing op don't always change the number of input.
            # If a variable is used as multiple into to the same node,
            # we still want to fusion. So we take the set.
            if (
                i.owner
                and isinstance(i.owner.op, op_class)
                and len({n for n, idx in fgraph.clients[i]}) == 1
                and
                # Do not merge elemwise that don't have the same
                # broadcastable pattern to don't redo duplicate
                # computation due to broadcast.
                i.owner.outputs[0].broadcastable == node.outputs[0].broadcastable
            ):
                try:
                    tmp_s_input = []
                    # we should not put duplicate input into s_inputs and inputs
                    for ii in i.owner.inputs:
                        if ii in inputs:
                            tmp_s_input.append(s_inputs[inputs.index(ii)])
                        elif ii in tmp_input:
                            tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
                        else:
                            tmp = aes.get_scalar_type(ii.type.dtype).make_variable()

                            try:
                                tv = get_test_value(ii)
                                # Sometimes the original inputs have
                                # zero-valued shapes in some dimensions, which
                                # implies that this whole scalar thing doesn't
                                # make sense (i.e. we're asking for the scalar
                                # value of an entry in a zero-dimensional
                                # array).
                                # This will eventually lead to an error in the
                                # `compute_test_value` call below when/if
                                # `config.compute_test_value_opt` is enabled
                                # (for debugging, more or less)
                                tmp.tag.test_value = tv.item()
                            except (TestValueError, ValueError):
                                pass

                            tmp_s_input.append(tmp)
                            tmp_input.append(ii)
                            tmp_scalar.append(tmp_s_input[-1])

                    # Use the `Op.make_node` interface in case `Op.__call__`
                    # has been customized
                    scalar_node = i.owner.op.scalar_op.make_node(*tmp_s_input)

                    if config.compute_test_value_opt != "off":
                        # This is required because `Op.make_node` won't do it
                        compute_test_value(scalar_node)

                    # If the scalar_op doesn't have a C implementation, we skip
                    # its fusion to allow fusion of the other ops
                    i.owner.op.scalar_op.c_code(
                        scalar_node,
                        "test_presence_of_c_code",
                        ["x" for x in i.owner.inputs],
                        ["z" for z in i.owner.outputs],
                        {"fail": "%(fail)s"},
                    )
                except (NotImplementedError, MethodNotDefined):
                    _logger.warning(
                        (
                            "Rewrite warning: "
                            f"The Op {i.owner.op.scalar_op} does not provide a C implementation."
                            " As well as being potentially slow, this also disables "
                            "loop fusion."
                        )
                    )
                    scalar_node = None

            # Compute the number of inputs in case we fuse this input.
            # We subtract 1 because we replace the existing input with the new
            # inputs from `tmp_input`.
            new_nb_input_ = new_nb_input + len(tmp_input) - 1

            # If the new input is already an input of the current node, it was
            # already counted when `new_nb_input` was initialized to
            # len(node.inputs).
            # This can happen when a variable is used both by the Elemwise to
            # fuse and the current node.
            for x in tmp_input:
                if x in node.inputs:
                    new_nb_input_ -= 1

            if scalar_node and (new_nb_input_ <= max_nb_input):
                fused = True
                new_nb_input = new_nb_input_
                inputs.extend(tmp_input)
                s_inputs.extend(tmp_scalar)
                s_g.extend(scalar_node.outputs)
            else:
                # We must support the case where the same variable appears many
                # times within the inputs
                if inputs.count(i) == node.inputs.count(i):
                    s = s_inputs[inputs.index(i)]
                else:
                    s = aes.get_scalar_type(i.type.dtype).make_variable()
                    if config.compute_test_value_opt != "off":
                        try:
                            v = get_test_value(i)
                            # See the zero-dimensional test value situation
                            # described above.
                            s.tag.test_value = v.item()
                        except (TestValueError, ValueError):
                            pass

                    inputs.append(i)
                    s_inputs.append(s)
                s_g.append(s)

        if not fused:
            return False

        if new_nb_input != len(inputs) or len(s_inputs) != len(inputs):
            # TODO FIXME: This shouldn't be a generic `Exception`
            raise Exception(
                "Something has gone wrong with the elemwise fusion rewrite; skipping."
            )

        s_new_out = node.op.scalar_op(*s_g, return_list=True)
        try:
            s_new_out[0].owner.op.c_code(
                s_new_out[0].owner,
                "test_presence_of_c_code",
                ["x" for x in s_g],
                ["z" for x in s_new_out],
                {"fail": "%(fail)s"},
            )
        except (NotImplementedError, MethodNotDefined):
            name = str(s_new_out[0].owner.op)
            _logger.warning(
                (
                    "Rewrite warning: "
                    f"The Op {name} does not provide a C implementation."
                    " As well as being potentially slow, this also disables "
                    "loop fusion."
                )
            )
            return False

        # create the composite op.
        composite_op = aes.Composite(s_inputs, s_new_out)

        # create the new node.
        # Do not call make_node to have test_value
        new_node = maker(node, composite_op)(*inputs).owner

        assert len(new_node.outputs) == 1
        assert node.outputs[0].type.dtype == new_node.outputs[0].type.dtype

        if len(new_node.inputs) > max_nb_input:
            _logger.warning(
                "Loop fusion failed because the resulting node "
                "would exceed the kernel argument limit."
            )
            return False

        # we fuse as many that we can at the same time to make debug mode faster
        # debug mode will be faster as it won't test all intermediate step.
        while True:
            ret = local_fuse(fgraph, new_node)
            if ret is not False and ret is not None:
                assert len(ret) == len(new_node.outputs)
                assert len(ret) == 1
                new_node = ret[0].owner
            else:
                break

        return new_node.outputs

    return local_fuse
def elemwise_max_input_fct(node):
    """Return the maximum number of inputs a fused `Elemwise` node may take."""
    # Without a C++ compiler, `Elemwise.perform` falls back to NumPy ufuncs,
    # and those are limited to 31 inputs; with C code we allow many more.
    return 1024 if config.cxx else 31


local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fct)
class FusionOptimizer(GraphRewriter):
    """Graph rewriter that simply runs node fusion operations.

    TODO: This is basically an `EquilibriumGraphRewriter`; we should just use that.

    """

    def __init__(self, node_rewriter):
        # The node-level fusion rewriter (e.g. `local_elemwise_fusion`) that
        # `apply` will call on every node.
        super().__init__()
        self.node_rewriter = node_rewriter

    def add_requirements(self, fgraph):
        # Replacements below go through `replace_all_validate`, which needs
        # this feature to be attached.
        fgraph.attach_feature(ReplaceValidate())

    def apply(self, fgraph):
        """Repeatedly apply ``self.node_rewriter`` until a fixed point is reached.

        Returns a profiling tuple consumed by `print_profile`.
        """
        did_something = True
        nb_iter = 0
        nb_replacement = 0
        nb_inconsistency_replace = 0
        time_toposort = 0
        if fgraph.profile:
            validate_before = fgraph.profile.validate_time
            callbacks_before = fgraph.execute_callbacks_times.copy()
            callback_before = fgraph.execute_callbacks_time
        while did_something:
            t0 = time.time()
            nodelist = list(fgraph.toposort())
            time_toposort += time.time() - t0
            nodelist.reverse()
            did_something = False
            for node in nodelist:
                # Don't try to fuse node that have already been fused.
                if node in fgraph.apply_nodes:
                    new_outputs = self.node_rewriter(fgraph, node)
                    if new_outputs:
                        assert len(new_outputs) == len(node.outputs)
                        try:
                            fgraph.replace_all_validate(
                                list(zip(node.outputs, new_outputs)),
                                reason=self.__class__.__name__,
                            )
                            did_something = True
                            nb_replacement += 1
                        except InconsistencyError:
                            nb_inconsistency_replace += 1
            nb_iter += 1

        if fgraph.profile:
            validate_time = fgraph.profile.validate_time - validate_before
            callback_time = fgraph.execute_callbacks_time - callback_before
            callbacks_time = {}
            for k, v in fgraph.execute_callbacks_times.items():
                if k in callbacks_before:
                    callbacks_time[k] = v - callbacks_before[k]
                else:
                    callbacks_time[k] = v
        else:
            validate_time = None
            callback_time = None
            callbacks_time = {}
        return (
            self,
            nb_iter,
            nb_replacement,
            nb_inconsistency_replace,
            validate_time,
            callback_time,
            callbacks_time,
            time_toposort,
        )

    @classmethod
    def print_profile(cls, stream, prof, level=0):
        # `prof` is the tuple returned by `apply` above.
        blanc = "    " * level
        print(blanc, cls.__name__, file=stream)
        print(blanc, " nb_iter", prof[1], file=stream)
        print(blanc, " nb_replacement", prof[2], file=stream)
        print(blanc, " nb_inconsistency_replace", prof[3], file=stream)
        print(blanc, " validate_time", prof[4], file=stream)
        print(blanc, " callback_time", prof[5], file=stream)
        if prof[5] is not None and prof[5] > 1:
            print(blanc, " callbacks_time", file=stream)
            for i in sorted(prof[6].items(), key=lambda a: a[1])[::-1]:
                if i[1] > 0:
                    print(blanc, "     ", i)
        print(blanc, " time_toposort", prof[7], file=stream)
# Register the fusion pass in the global rewrite database.  When the
# `tensor__local_elemwise_fusion` flag is set, fusion is part of `fast_run`;
# otherwise it is only reachable through the explicit "fusion" tags.
if config.tensor__local_elemwise_fusion:
    _logger.debug("Enabling Elemwise fusion rewriters in fast_run")
    # Must be after gpu(48.5) and before AddDestroyHandler(49.5)
    fuse_seqopt = SequenceDB()
    fuse_seqopt.register(
        "composite_elemwise_fusion",
        FusionOptimizer(local_elemwise_fusion),
        "fast_run",
        "fusion",
        position=1,
    )
    compile.optdb.register(
        "elemwise_fusion",
        fuse_seqopt,
        "fast_run",
        "fusion",
        "local_elemwise_fusion",
        "FusionOptimizer",
        position=49,
    )
else:
    _logger.debug("Not enabling Elemwise fusion rewriters in fast_run")
    compile.optdb.register(
        "elemwise_fusion",
        FusionOptimizer(local_elemwise_fusion),
        "fusion",
        "local_elemwise_fusion",
        "FusionOptimizer",
        position=49,
    )
@register_canonicalize
@node_rewriter([Elemwise])
def local_useless_composite(fgraph, node):
    """For elemwise Composite that have multiple outputs, remove the
    outputs that are not used.

    """
    if not isinstance(node.op, Elemwise) or not isinstance(
        node.op.scalar_op, aes.Composite
    ):
        return
    comp = node.op.scalar_op
    # Indices of outputs that still have clients in the graph.
    idx = [i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]]
    if len(idx) < len(node.outputs):
        # Rebuild the Composite with only the live outputs and map the old
        # (used) outputs onto the new node's outputs.
        new_outputs = [comp.outputs[i] for i in idx]
        c = aes.Composite(inputs=comp.inputs, outputs=new_outputs)
        e = Elemwise(scalar_op=c)(*node.inputs, return_list=True)
        return dict(zip([node.outputs[i] for i in idx], e))
@register_canonicalize
(
"fast_compile"
)
@register_useless
(
"fast_compile"
)
@node_rewriter
(
None
)
...
...
@@ -3325,240 +1265,32 @@ def local_useless_topk(fgraph, node):
return
{
old_output
:
new_output
}
@register_useless
@register_canonicalize
@node_rewriter([SpecifyShape])
def local_merge_consecutive_specify_shape(fgraph, node):
    """Replace ``specify_shape(specify_shape(x, s1), s2)`` with ``specify_shape(x, s3)``,
    where s3 is the union of specified dimensions in s1 and s2, with preference given to s2.
    """

    if not isinstance(node.op, SpecifyShape):
        return False

    obj = node.inputs[0]
    if not (obj.owner and isinstance(obj.owner.op, SpecifyShape)):
        return False

    inner_obj, *shape = obj.owner.inputs
    # Overwrite inner dimensions with the outer ones wherever the outer
    # `SpecifyShape` actually specifies a value (i.e. is not `NoneConst`).
    for dim, sh in enumerate(node.inputs[1:]):
        if not NoneConst.equals(sh):
            shape[dim] = sh

    # TODO: We could make sure that the overlapping shapes of the two `SpecifyShape`s are
    # the same.

    return [specify_shape(inner_obj, shape)]
@register_useless
@register_canonicalize
@node_rewriter([Shape])
def local_Shape_of_SpecifyShape(fgraph, node):
    """Replace ``specify_shape(x, s).shape`` with ``s``."""

    if not isinstance(node.op, Shape):
        return False

    specified_shape = node.inputs[0]

    if not isinstance(getattr(specified_shape.owner, "op", None), SpecifyShape):
        return False

    x, *shape = specified_shape.owner.inputs

    # Replace `NoneConst` by `shape_i`
    for i, sh in enumerate(shape):
        if NoneConst.equals(sh):
            shape[i] = shape_i(x, i, fgraph)

    return [stack(shape).astype(np.int64)]
@register_useless
@register_canonicalize
@node_rewriter([Shape_i])
def local_Shape_i_of_broadcastable(fgraph, node):
    """Replace ``shape_i(x, i)`` with ``1`` when ``x.broadcastable[i]`` is ``True``."""
    if not isinstance(node.op, Shape_i):
        return False

    shape_arg = node.inputs[0]

    # Only `TensorType` variables carry a broadcastable pattern we can query.
    if not isinstance(shape_arg.type, TensorType):
        return False

    if shape_arg.broadcastable[node.op.i]:
        # A broadcastable dimension is guaranteed to have length one.
        return [as_tensor_variable(1, dtype=np.int64)]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_scalar(fgraph, node):
    """Convert ``unique(x)`` to ``x`` when ``x`` is a scalar."""
    if not isinstance(node.op, Unique):
        return False

    # Only the plain value-returning form is handled; the auxiliary outputs
    # (indices, inverse, counts) would not be trivial for a scalar.
    if node.op.return_index or node.op.return_inverse or node.op.return_counts:
        return False

    uniqued_var = node.inputs[0]

    if uniqued_var.ndim != 0:
        return False

    old_out = node.outputs[0]
    res = as_tensor_variable(uniqued_var, ndim=old_out.ndim, dtype=old_out.dtype)
    return [res]


def import_ShapeFeature():
    # Deferred import so the deprecation shim only loads the relocated module
    # when the deprecated name is actually requested.
    from aesara.tensor.rewriting.shape import ShapeFeature

    return ShapeFeature


# Names that used to live in this module, mapped to a deprecation message and
# a loader returning the relocated object (consumed by the module-level
# ``__getattr__``).
DEPRECATED_NAMES = {
    "ShapeFeature": (
        "`ShapeFeature` is now located in `aesara.tensor.rewriting.shape`.",
        import_ShapeFeature,
    ),
}
def
__getattr__
(
name
):
"""Intercept module-level attribute access of deprecated symbols.
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_Alloc_lift(fgraph, node):
    """Convert ``unique(alloc(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    Adapted from https://stackoverflow.com/a/55139609/3006474.

    This isn't really so much a lift as a "reduction/consumption".
    """
    if not isinstance(node.op, Unique):
        return False

    # Only the flattened, plain-value form is safe to rewrite.
    if (
        node.op.return_index
        or node.op.return_inverse
        or node.op.return_counts
        or node.op.axis is not None
    ):
        return False

    alloc_var = node.inputs[0]

    if not (alloc_var.owner and isinstance(alloc_var.owner.op, Alloc)):
        return False

    alloced_var, *alloc_shape = alloc_var.owner.inputs

    # `Alloc` only tiles/broadcasts values, so the unique values of its output
    # equal the unique values of its input.
    new_unique, *_ = node.op.make_node(alloced_var).outputs

    old_out = node.outputs[0]
    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
    return [new_x]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_BroadcastTo_lift(fgraph, node):
    """Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    op = node.op
    if not isinstance(op, Unique):
        return False

    # Only the plain flattened-unique form is rewritten; the auxiliary
    # outputs and axis-wise unique depend on the broadcasted layout.
    if (
        op.return_index
        or op.return_inverse
        or op.return_counts
        or op.axis is not None
    ):
        return False

    bcast_var = node.inputs[0]
    bcast_apply = bcast_var.owner
    if bcast_apply is None or not isinstance(bcast_apply.op, BroadcastTo):
        return False

    # Broadcasting merely repeats values, so the set of unique values of the
    # broadcasted variable equals that of the pre-broadcast variable.
    pre_bcast_var = bcast_apply.inputs[0]
    lifted_unique = op.make_node(pre_bcast_var).outputs[0]

    old_out = node.outputs[0]
    replacement = as_tensor_variable(
        lifted_unique, ndim=old_out.ndim, dtype=old_out.dtype
    )
    return [replacement]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_Repeat_lift(fgraph, node):
    """Convert ``unique(repeat(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    op = node.op
    if not isinstance(op, Unique):
        return False

    # Only the plain flattened-unique form is rewritten.
    if (
        op.return_index
        or op.return_inverse
        or op.return_counts
        or op.axis is not None
    ):
        return False

    candidate = node.inputs[0]
    repeat_apply = candidate.owner
    if repeat_apply is None or not isinstance(repeat_apply.op, Repeat):
        return False

    # Repeating only duplicates values; the set of unique values is unchanged.
    repeated_input = repeat_apply.inputs[0]
    lifted_unique = op.make_node(repeated_input).outputs[0]

    old_out = node.outputs[0]
    replacement = as_tensor_variable(
        lifted_unique, ndim=old_out.ndim, dtype=old_out.dtype
    )
    return [replacement]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_second(fgraph, node):
    """Convert ``unique(second(x, ...), axis=None)`` to ``second(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    if not isinstance(node.op, Unique):
        return False

    # Only the flattened, plain-value form is safe to rewrite.
    if (
        node.op.return_index
        or node.op.return_inverse
        or node.op.return_counts
        or node.op.axis is not None
    ):
        return False

    second_var = node.inputs[0]

    # Match an `Elemwise` node whose scalar op is `Second` (i.e. "fill").
    if not (
        second_var.owner
        and isinstance(second_var.owner.op, Elemwise)
        and isinstance(second_var.owner.op.scalar_op, aes.Second)
    ):
        return False

    shape_var, seconded_var = second_var.owner.inputs

    # `second` broadcasts `seconded_var` to a shape, which cannot introduce
    # new values, so `unique` can be applied to the un-broadcast variable.
    new_unique, *_ = node.op.make_node(seconded_var).outputs

    old_out = node.outputs[0]
    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
    return [new_x]
@register_useless
@register_canonicalize
@node_rewriter([BroadcastTo])
def local_remove_scalar_BroadcastTo(fgraph, node):
    # A `BroadcastTo` with no target-shape inputs is an identity on a scalar:
    # broadcasting to zero dimensions cannot change the value.
    bcast_shape = node.inputs[1:]

    if not bcast_shape:
        bcasted_var = node.inputs[0]
        # If this isn't true, the graph is invalid
        assert bcasted_var.ndim == 0
        return [bcasted_var]


def __getattr__(name):
    """Intercept module-level attribute access of deprecated symbols."""
    from warnings import warn

    res = DEPRECATED_NAMES.get(name)
    if res:
        # Emit the deprecation message and return the relocated object via
        # its deferred-import loader.
        msg, fn = res
        warn(msg, DeprecationWarning, stacklevel=2)
        return fn()

    raise AttributeError(f"module {__name__} has no attribute {name}")
aesara/tensor/rewriting/elemwise.py
0 → 100644
浏览文件 @
63f52536
import
sys
import
time
from
collections
import
defaultdict
from
typing
import
Optional
from
warnings
import
warn
import
aesara
import
aesara.scalar.basic
as
aes
from
aesara
import
compile
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Apply
,
Constant
,
io_toposort
from
aesara.graph.features
import
ReplaceValidate
from
aesara.graph.op
import
compute_test_value
,
get_test_value
from
aesara.graph.rewriting.basic
import
GraphRewriter
,
copy_stack_trace
,
node_rewriter
from
aesara.graph.rewriting.db
import
SequenceDB
from
aesara.graph.utils
import
InconsistencyError
,
MethodNotDefined
,
TestValueError
from
aesara.tensor.basic
import
MakeVector
,
alloc
,
cast
,
get_scalar_constant_value
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.rewriting.basic
import
register_canonicalize
,
register_specialize
from
aesara.tensor.shape
import
shape_padleft
from
aesara.tensor.var
import
TensorConstant
class InplaceElemwiseOptimizer(GraphRewriter):
    r"""
    This is parameterized so that it works for `Elemwise` `Op`\s.
    """

    def __init__(self, OP):
        # `OP` is the `Elemwise`-like class whose nodes this rewriter will
        # attempt to convert to in-place variants.
        self.op = OP

    def add_requirements(self, fgraph):
        # In-place rewrites need destroy-map tracking to stay valid.
        from aesara.graph.destroyhandler import DestroyHandler

        fgraph.attach_feature(DestroyHandler())

    @classmethod
    def print_profile(cls, stream, prof, level=0):
        # `prof` is the dict returned by `apply`; `level` controls indentation.
        blanc = "    " * level
        print(blanc, cls.__name__, prof["opt"].op, file=stream)
        for k in [
            "node_before",
            "nb_call_replace",
            "nb_call_validate",
            "nb_inconsistent",
        ]:
            print(blanc, k, prof[k], file=stream)
        ndim = prof["ndim"]
        if ndim:
            print(blanc, "ndim", "nb", file=stream)
            for n in sorted(ndim.keys()):
                print(blanc, n, ndim[n], file=stream)

    def apply(self, fgraph):
        r"""
        Attempts to replace all `Elemwise`\s by versions of them that operate
        inplace. It operates greedily: for each `Elemwise` that is encountered,
        for each output, it tries each input to see if it can operate inplace
        on that input. If so, it makes the change and goes to the next output
        or `Elemwise`.

        Examples
        --------

            x + y + z -> x += y += z
            (x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)

        """
        # We should not validate too often as this takes too much time to
        # execute!
        # It is the _dfs_toposort() fct in aesara/graph/destroyhandler.py
        # that takes so much time.
        # Should we try to use another lib that does toposort?
        #   igraph: http://igraph.sourceforge.net/
        #   networkx: https://networkx.lanl.gov/
        # Should we try to use cython?
        #   Compiling only that fct is not enough, should we try to add the
        #   deque class too?
        #   And init the deque and other list to an upper bound number of
        #   elements?
        # Maybe Aesara should do online toposort as in
        #   http://code.google.com/p/acyclic
        #
        # The next longest rewriter is the canonizer phase.
        # Then I think it is the [io_?]toposort (need to validate) so check if
        # the solution is also applicable there.

        # We execute `validate` after this number of change.
        prof = {
            "opt": self,
            "node_before": len(fgraph.apply_nodes),
            "nb_call_replace": 0,
            "nb_call_validate": 0,
            "nb_inconsistent": 0,
            "ndim": defaultdict(lambda: 0),
        }

        check_each_change = config.tensor__insert_inplace_optimizer_validate_nb
        if check_each_change == -1:
            # Auto mode: validate less often on large graphs to save time.
            if len(fgraph.apply_nodes) > 500:
                check_each_change = 10
            else:
                check_each_change = 1

        nb_change_no_validate = 0
        chk = fgraph.checkpoint()

        if fgraph.update_mapping:
            update_outs = [fgraph.outputs[i] for i in fgraph.update_mapping]
        else:
            update_outs = []

        protected_inputs = [
            f.protected
            for f in fgraph._features
            if isinstance(f, aesara.compile.function.types.Supervisor)
        ]
        protected_inputs = sum(protected_inputs, [])  # flatten the list
        protected_inputs.extend(fgraph.outputs)
        for node in list(io_toposort(fgraph.inputs, fgraph.outputs)):
            op = node.op
            if not isinstance(op, self.op):
                continue
            # If big graph and the outputs are scalar, do not make it
            # inplace.
            if (
                check_each_change != 1
                and
                # If multiple outputs, they must all have the same size,
                # so only check the first.
                getattr(node.outputs[0].type, "ndim", -1) == 0
            ):
                continue

            if op.inplace_pattern:
                # Maybe this isn't needed anymore, but I don't want to
                # rish regression now. This case only happen if the
                # original node add already some inplace patter and we
                # still try to add more pattern.

                baseline = op.inplace_pattern
                candidate_outputs = [
                    i for i in range(len(node.outputs)) if i not in baseline
                ]
                # node inputs that are Constant, already destroyed,
                # or fgraph protected inputs and fgraph outputs can't be used as
                # inplace target.
                # Remove here as faster.
                candidate_inputs = [
                    i
                    for i in range(len(node.inputs))
                    if i not in baseline.values()
                    and not isinstance(node.inputs[i], Constant)
                    and
                    # the next line should not be costly most of the time.
                    not fgraph.has_destroyers([node.inputs[i]])
                    and node.inputs[i] not in protected_inputs
                ]
            else:
                baseline = []
                candidate_outputs = list(range(len(node.outputs)))
                # node inputs that are Constant, already destroyed,
                # fgraph protected inputs and fgraph outputs can't be used as inplace
                # target.
                # Remove here as faster.
                candidate_inputs = [
                    i
                    for i in range(len(node.inputs))
                    if not isinstance(node.inputs[i], Constant)
                    and not fgraph.has_destroyers([node.inputs[i]])
                    and node.inputs[i] not in protected_inputs
                ]

            verbose = False

            raised_warning = not verbose

            for candidate_output in candidate_outputs:

                # If the output of the node can be established as an update
                # output of the fgraph, visit the candidate_inputs in an order
                # that will improve the chances of making the node operate
                # inplace on the input it's meant to update
                candidate_out_var = node.outputs[candidate_output]
                sorted_candidate_inputs = candidate_inputs

                if candidate_out_var in update_outs:

                    # The candidate output is an update. Sort the
                    # variables in candidate_inputs in the following order:
                    # - Vars corresponding to the actual updated input
                    #   (best case scenario is for the node that procudes
                    #   an update to operate inplace on the variable to
                    #   update)
                    # - Vars computed inplace on the updates input (second
                    #   best scenario if for the node to work inplace on
                    #   a variable obtained by a chain of inplace on the
                    #   variable to update. In some cases, this will be
                    #   equivalent to operating inplace on the variable to
                    #   update)
                    # - Remaining variables
                    updated_inputs = []
                    for i, f_out in enumerate(fgraph.outputs):
                        if f_out is candidate_out_var and i in fgraph.update_mapping:
                            updated_inp_idx = fgraph.update_mapping[i]
                            updated_inputs.append(fgraph.inputs[updated_inp_idx])

                    updated_vars = []
                    vars_from_inplace = []
                    other_vars = []
                    for inp_idx in candidate_inputs:
                        inp = node.inputs[inp_idx]
                        if inp in updated_inputs:
                            # the candidate input is the actual updated input
                            updated_vars.append(inp_idx)
                        elif (
                            hasattr(fgraph, "destroy_handler")
                            and inp.owner
                            and any(
                                fgraph.destroy_handler.root_destroyer.get(up_inp, None)
                                is inp.owner
                                for up_inp in updated_inputs
                            )
                        ):
                            # the candidate input is a variable computed
                            # inplace on the updated input via a sequence of
                            # one or more inplace operations
                            vars_from_inplace.append(inp_idx)
                        else:
                            other_vars.append(inp_idx)

                    sorted_candidate_inputs = (
                        updated_vars + vars_from_inplace + other_vars
                    )

                for candidate_input in sorted_candidate_inputs:
                    # remove inputs that don't have the same dtype as the output
                    if (
                        node.inputs[candidate_input].type
                        != node.outputs[candidate_output].type
                    ):
                        continue

                    inplace_pattern = dict(baseline)
                    inplace_pattern[candidate_output] = candidate_input
                    try:
                        if hasattr(op.scalar_op, "make_new_inplace"):
                            new_scal = op.scalar_op.make_new_inplace(
                                aes.transfer_type(
                                    *[
                                        inplace_pattern.get(i, o.dtype)
                                        for i, o in enumerate(node.outputs)
                                    ]
                                )
                            )
                        else:
                            new_scal = op.scalar_op.__class__(
                                aes.transfer_type(
                                    *[
                                        inplace_pattern.get(i, None)
                                        for i in range(len(node.outputs))
                                    ]
                                )
                            )
                        new_outputs = self.op(new_scal, inplace_pattern)(
                            *node.inputs, return_list=True
                        )
                        new_node = new_outputs[0].owner

                        for r, new_r in zip(node.outputs, new_outputs):
                            prof["nb_call_replace"] += 1
                            fgraph.replace(
                                r, new_r, reason="inplace_elemwise_optimizer"
                            )
                        nb_change_no_validate += 1
                        prof["ndim"][candidate_out_var.ndim] += 1
                        if nb_change_no_validate >= check_each_change:
                            prof["nb_call_validate"] += 1
                            fgraph.validate()
                            chk = fgraph.checkpoint()
                            nb_change_no_validate = 0
                    except (ValueError, InconsistencyError) as e:
                        prof["nb_inconsistent"] += 1
                        if check_each_change != 1 and not raised_warning:
                            print(
                                (
                                    "Some inplace rewriting was not "
                                    "performed due to an unexpected error:"
                                ),
                                file=sys.stderr,
                            )
                            print(e, file=sys.stderr)
                            raised_warning = True
                        fgraph.revert(chk)
                        continue
                    candidate_inputs.remove(candidate_input)
                    node = new_node
                    baseline = inplace_pattern
                    break

        if nb_change_no_validate > 0:
            try:
                fgraph.validate()
            except Exception:
                if not raised_warning:
                    print(
                        (
                            "Some inplace rewriting was not "
                            "performed due to an unexpected error"
                        ),
                        file=sys.stderr,
                    )
                fgraph.revert(chk)
        return prof

    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
        print(
            f"{' ' * level}{self.__class__.__name__} ({self.op})",
            file=stream,
        )
# Module-level instance used by the rewrite database for `Elemwise` nodes.
inplace_elemwise_optimizer = InplaceElemwiseOptimizer(Elemwise)
compile.optdb.register(  # type: ignore
    "inplace_elemwise_opt",
    inplace_elemwise_optimizer,
    "inplace_opt",  # for historic reason
    "inplace_elemwise_optimizer",
    "fast_run",
    "inplace",
    position=75,
)
def apply_local_dimshuffle_lift(fgraph, var):
    """Recursively apply `local_dimshuffle_lift` to ``var``.

    Returns the rewritten variable, or ``var`` unchanged when it has no
    owner (it is a graph input/constant) or the rewrite does not fire.
    """
    owner = var.owner
    if owner is None:
        return var
    replacement = local_dimshuffle_lift.transform(fgraph, owner)
    return replacement[0] if replacement else var
def is_dimshuffle_useless(new_order, input):
    """
    Checks for two types of useless dimshuffles:

    1 - dimshuffle all dimensions in order.
    2 - dimshuffle a broadcastable dimension.

    """
    # A dimshuffle that changes the number of dimensions is never a no-op.
    if len(new_order) != input.type.ndim:
        return False

    # Broadcastable axes and the "x" marker are interchangeable: moving a
    # length-1 axis (or replacing it with a new broadcastable axis) does not
    # change the result.
    interchangeable = {
        axis
        for axis, is_broadcastable in enumerate(input.type.broadcastable)
        if is_broadcastable
    }
    interchangeable.add("x")

    for axis, source in enumerate(new_order):
        if source == axis:
            continue
        if axis in interchangeable and source in interchangeable:
            continue
        return False

    return True
@register_canonicalize
@register_specialize
@node_rewriter([DimShuffle])
def local_dimshuffle_lift(fgraph, node):
    """
    "Lifts" DimShuffle through Elemwise operations and merges
    consecutive DimShuffles. Basically, applies the following
    transformations on the whole graph:

    DimShuffle(Elemwise(x, y)) => Elemwise(DimShuffle(x), DimShuffle(y))
    DimShuffle(DimShuffle(x)) => DimShuffle(x)
    DimShuffle{0,1,...}(x) => x (when the dimshuffle do nothing)

    After this transform, clusters of Elemwise operations are
    void of DimShuffle operations.

    """
    op = node.op
    if not isinstance(op, DimShuffle):
        return False

    inp = node.inputs[0]
    inode = inp.owner
    new_order = op.new_order
    # Case 1: push the DimShuffle through an Elemwise whose output has no
    # other clients (otherwise we would duplicate the Elemwise computation).
    if inode and isinstance(inode.op, Elemwise) and (len(fgraph.clients[inp]) == 1):
        # Don't use make_node to have tag.test_value set.
        new_inputs = []
        for inp in inode.inputs:
            new_inp = op.__class__(inp.type.broadcastable, op.new_order)(inp)
            # Recurse so that the newly created DimShuffles are themselves
            # lifted/merged as far as possible.
            new_inputs.append(apply_local_dimshuffle_lift(fgraph, new_inp))
        copy_stack_trace(node.outputs[0], new_inputs)
        ret = inode.op(*new_inputs, return_list=True)
        return ret

    # Case 2: merge two consecutive DimShuffles by composing their orders.
    if inode and isinstance(inode.op, DimShuffle):
        # "x" entries stay "x"; other entries index into the inner order.
        new_order = [x == "x" and "x" or inode.op.new_order[x] for x in new_order]
        inp = inode.inputs[0]

    if is_dimshuffle_useless(new_order, inp):
        # The (possibly composed) shuffle is a no-op; drop it entirely.
        return [inp]
    elif inode and isinstance(inode.op, DimShuffle):
        # Replace the pair of DimShuffles with a single composed one.
        ret = op.__class__(inp.type.broadcastable, new_order)(inp)
        ret = apply_local_dimshuffle_lift(fgraph, ret)
        copy_stack_trace(node.outputs[0], ret)
        return [ret]
@register_canonicalize
@register_specialize
@node_rewriter([DimShuffle])
def local_useless_dimshuffle_makevector(fgraph, node):
    r"""Remove `DimShuffle`\s that drop one dimensional broadcastable `MakeVector`s.

    This rewrite is needed in order to clean up after
    `local_subtensor_remove_broadcastable_index`, which produces a
    not-so-intuitive canonical form for `x[0]` when `x.shape == (1,)`
    (i.e. one broadcastable dimension): i.e. `x.dimshuffle(())`.
    """

    # Only handle the scalar-producing case, i.e. ``x.dimshuffle(())``
    if node.op.new_order != ():
        return

    mv_out = node.inputs[0]
    mv_node = mv_out.owner

    if mv_node is None or not isinstance(mv_node.op, MakeVector):
        return
    if mv_out.broadcastable != (True,):
        return

    # A (1,)-shaped MakeVector has exactly one element; return it directly.
    assert len(mv_node.inputs) == 1
    return [mv_node.inputs[0]]
@register_canonicalize
@node_rewriter([Elemwise])
def local_upcast_elemwise_constant_inputs(fgraph, node):
    """This explicitly upcasts constant inputs to elemwise Ops, when
    those Ops do implicit upcasting anyway.

    Rationale: it helps merge things like (1-x) and (1.0 - x).
    """
    if len(node.outputs) > 1:
        return
    try:
        # `shape_i` is only available when a `ShapeFeature` is attached.
        shape_i = fgraph.shape_feature.shape_i
    except AttributeError:
        shape_i = None
    if isinstance(node.op, Elemwise):
        scalar_op = node.op.scalar_op
        # print "aa", scalar_op.output_types_preference
        if getattr(scalar_op, "output_types_preference", None) in (
            aes.upgrade_to_float,
            aes.upcast_out,
        ):
            # this is the kind of op that we can screw with the input
            # dtypes by upcasting explicitly
            output_dtype = node.outputs[0].type.dtype
            new_inputs = []
            for i in node.inputs:
                if i.type.dtype == output_dtype:
                    new_inputs.append(i)
                else:
                    try:
                        # works only for scalars
                        cval_i = get_scalar_constant_value(
                            i, only_process_constants=True
                        )
                        if all(i.broadcastable):
                            # Fully broadcastable: rebuild as a padded scalar
                            # constant of the output dtype.
                            new_inputs.append(
                                shape_padleft(cast(cval_i, output_dtype), i.ndim)
                            )
                        else:
                            if shape_i is None:
                                # Cannot reconstruct the shape; give up.
                                return
                            new_inputs.append(
                                alloc(
                                    cast(cval_i, output_dtype),
                                    *[shape_i(d)(i) for d in range(i.ndim)],
                                )
                            )
                            # print >> sys.stderr, "AAA",
                            # *[Shape_i(d)(i) for d in range(i.ndim)]
                    except NotScalarConstantError:
                        # for the case of a non-scalar
                        if isinstance(i, TensorConstant):
                            new_inputs.append(cast(i, output_dtype))
                        else:
                            new_inputs.append(i)
            if new_inputs != node.inputs:
                rval = [node.op(*new_inputs)]
                if not node.outputs[0].type.is_super(rval[0].type):
                    # This can happen for example when floatX=float32
                    # and we do the true division between and int64
                    # and a constant that will get typed as int8.
                    # As this is just to allow merging more case, if
                    # the upcast don't work, we can just skip it.
                    return

                # Copy over output stacktrace from before upcasting
                copy_stack_trace(node.outputs[0], rval)
                return rval
def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None):
    r"""Create a recursive function that fuses `Elemwise` `Op`\s.

    The basic idea is that we loop through an `Elemwise` node's inputs, find
    other `Elemwise` nodes, determine the scalars input types for all of the
    `Elemwise` `Op`\s, construct a new scalar `Op` using the scalar input types
    and each `Elemwise`'s scalar `Op`, and use the composite scalar `Op` in a
    new "fused" `Elemwise`.

    It's parameterized in order to work for `Elemwise` `Op`\s.

    Parameters
    ----------
    op_class : type
        `Elemwise` class (the one that we want to fuse)
    max_input_fct : callable
        A function that returns the maximum number of inputs that this `Elemwise`
        can take.
        On the CPU we limit to 32 input variables since that is the maximum
        NumPy support.
    maker: callable
        A function with the signature ``(node, *args)`` that constructs an
        `op_class` instance (e.g. ``op_class(*args)``).

    """
    if maker is None:

        def maker(node, scalar_op):
            return op_class(scalar_op)

    def local_fuse(fgraph, node):
        r"""Fuse `Elemwise` `Op`\s in a node.

        As part of specialization, we fuse two consecutive `Elemwise` `Op`\s of the
        same shape.

        For mixed dtype, we let the `Composite` `Op` do the cast. It lets the C
        compiler do the cast.

        The number of dimensions is validated at call time by Aesara itself.

        """
        # TODO: use broadcast flag?

        # TODO: don't do this rewrite as a `NodeRewriter`.
        # Analyze the graph in terms of elemwise subgraphs, and then
        # replace each subgraph with a Composite version.

        # TODO: use malloc and copy to transfer arguments that don't
        # fit within the parameter space of 256 bytes
        #
        # TODO: Merge with multiple output to merge when an inputs
        # have multiple clients. This can't be done with a `NodeRewriter`

        # TODO: Related: Support composites with multiple outputs

        # TODO: Use Composite to combine Elemwise and Reduce
        # operations.  We have to loop over the data anyway... might
        # as well sum it up while we're at it (this can be trickier
        # than i'm making it seound here. The data-traversal should be
        # done contiguously, and the summing-up might not be easy or
        # worthwhile if the summation axis doesn't line up with a
        # contiguous dimension)

        if type(node.op) is not op_class:
            return False

        if len(node.outputs) > 1:
            # We don't support fusion for nodes with multiple outputs.
            return

        inputs = []  # inputs of the new Elemwise op.
        s_inputs = []  # inputs of the new scalar op used by the Composite.
        # Inputs of the new scalar op that represents the current node.
        s_g = []

        # There is a hard limit of 256 bytes for the formal argument list to a
        # GPU kernel function.
        max_nb_input = max_input_fct(node)
        # The number of inputs to the new fused op if we do not fuse more
        # inputs.
        new_nb_input = len(node.inputs)
        # Did we fuse something?
        # Needed as we can fuse unary op that don't change the number of
        # inputs.
        # And there is a case where the inputs are the same as the current
        # node. That won't change the number of inputs of the new op.
        fused = False

        for i in node.inputs:
            scalar_node: Optional[Apply] = None
            # Will store inputs of the fused node that are not currently inputs
            # of the node we want to create (to avoid duplicating inputs).
            tmp_input = []
            # Same as tmp_input, but for scalars.
            tmp_scalar = []

            # We should not check the number of inputs here
            # As fusing op don't always change the number of input.
            # If a variable is used as multiple into to the same node,
            # we still want to fusion. So we take the set.
            if (
                i.owner
                and isinstance(i.owner.op, op_class)
                and len({n for n, idx in fgraph.clients[i]}) == 1
                and
                # Do not merge elemwise that don't have the same
                # broadcastable pattern to don't redo duplicate
                # computation due to broadcast.
                i.owner.outputs[0].broadcastable == node.outputs[0].broadcastable
            ):
                try:
                    tmp_s_input = []
                    # we should not put duplicate input into s_inputs and inputs
                    for ii in i.owner.inputs:
                        if ii in inputs:
                            tmp_s_input.append(s_inputs[inputs.index(ii)])
                        elif ii in tmp_input:
                            tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
                        else:
                            tmp = aes.get_scalar_type(ii.type.dtype).make_variable()
                            try:
                                tv = get_test_value(ii)
                                # Sometimes the original inputs have
                                # zero-valued shapes in some dimensions, which
                                # implies that this whole scalar thing doesn't
                                # make sense (i.e. we're asking for the scalar
                                # value of an entry in a zero-dimensional
                                # array).
                                # This will eventually lead to an error in the
                                # `compute_test_value` call below when/if
                                # `config.compute_test_value_opt` is enabled
                                # (for debugging, more or less)
                                tmp.tag.test_value = tv.item()
                            except (TestValueError, ValueError):
                                pass

                            tmp_s_input.append(tmp)
                            tmp_input.append(ii)
                            tmp_scalar.append(tmp_s_input[-1])

                    # Use the `Op.make_node` interface in case `Op.__call__`
                    # has been customized
                    scalar_node = i.owner.op.scalar_op.make_node(*tmp_s_input)

                    if config.compute_test_value_opt != "off":
                        # This is required because `Op.make_node` won't do it
                        compute_test_value(scalar_node)

                    # If the scalar_op doesn't have a C implementation, we skip
                    # its fusion to allow fusion of the other ops
                    i.owner.op.scalar_op.c_code(
                        scalar_node,
                        "test_presence_of_c_code",
                        ["x" for x in i.owner.inputs],
                        ["z" for z in i.owner.outputs],
                        {"fail": "%(fail)s"},
                    )
                except (NotImplementedError, MethodNotDefined):
                    warn(
                        (
                            "Rewrite warning: "
                            f"The Op {i.owner.op.scalar_op} does not provide a C implementation."
                            " As well as being potentially slow, this also disables "
                            "loop fusion."
                        )
                    )
                    scalar_node = None

            # Compute the number of inputs in case we fuse this input.
            # We subtract 1 because we replace the existing input with the new
            # inputs from `tmp_input`.
            new_nb_input_ = new_nb_input + len(tmp_input) - 1

            # If the new input is already an input of the current node, it was
            # already counted when `new_nb_input` was initialized to
            # len(node.inputs).
            # This can happen when a variable is used both by the Elemwise to
            # fuse and the current node.
            for x in tmp_input:
                if x in node.inputs:
                    new_nb_input_ -= 1

            if scalar_node and (new_nb_input_ <= max_nb_input):
                fused = True
                new_nb_input = new_nb_input_
                inputs.extend(tmp_input)
                s_inputs.extend(tmp_scalar)
                s_g.extend(scalar_node.outputs)
            else:
                # We must support the case where the same variable appears many
                # times within the inputs
                if inputs.count(i) == node.inputs.count(i):
                    s = s_inputs[inputs.index(i)]
                else:
                    s = aes.get_scalar_type(i.type.dtype).make_variable()
                    if config.compute_test_value_opt != "off":
                        try:
                            v = get_test_value(i)
                            # See the zero-dimensional test value situation
                            # described above.
                            s.tag.test_value = v.item()
                        except (TestValueError, ValueError):
                            pass

                    inputs.append(i)
                    s_inputs.append(s)
                s_g.append(s)

        if not fused:
            return False

        if new_nb_input != len(inputs) or len(s_inputs) != len(inputs):
            # TODO FIXME: This shouldn't be a generic `Exception`
            raise Exception(
                "Something has gone wrong with the elemwise fusion rewrite; skipping."
            )

        s_new_out = node.op.scalar_op(*s_g, return_list=True)
        try:
            s_new_out[0].owner.op.c_code(
                s_new_out[0].owner,
                "test_presence_of_c_code",
                ["x" for x in s_g],
                ["z" for x in s_new_out],
                {"fail": "%(fail)s"},
            )
        except (NotImplementedError, MethodNotDefined):
            name = str(s_new_out[0].owner.op)
            warn(
                (
                    "Rewrite warning: "
                    f"The Op {name} does not provide a C implementation."
                    " As well as being potentially slow, this also disables "
                    "loop fusion."
                )
            )
            return False

        # create the composite op.
        composite_op = aes.Composite(s_inputs, s_new_out)

        # create the new node.
        # Do not call make_node to have test_value
        new_node = maker(node, composite_op)(*inputs).owner

        assert len(new_node.outputs) == 1
        assert node.outputs[0].type.dtype == new_node.outputs[0].type.dtype

        if len(new_node.inputs) > max_nb_input:
            warn(
                "Loop fusion failed because the resulting node "
                "would exceed the kernel argument limit."
            )
            return False

        # we fuse as many that we can at the same time to make debug mode faster
        # debug mode will be faster as it won't test all intermediate step.
        while True:
            ret = local_fuse(fgraph, new_node)
            if ret is not False and ret is not None:
                assert len(ret) == len(new_node.outputs)
                assert len(ret) == 1
                new_node = ret[0].owner
            else:
                break

        return new_node.outputs

    return local_fuse
def elemwise_max_input_fct(node):
    """Return the maximum number of inputs allowed for an `Elemwise` node.

    `Elemwise.perform` uses NumPy ufuncs and they are limited to 31 inputs,
    so that bound applies when no C compiler is available; otherwise the C
    backend accepts up to 1024 inputs.
    """
    return 31 if not config.cxx else 1024


local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fct)
class FusionOptimizer(GraphRewriter):
    """Graph rewriter that simply runs node fusion operations.

    TODO: This is basically an `EquilibriumGraphRewriter`; we should just use that.

    """

    def __init__(self, node_rewriter):
        # `node_rewriter` is a callable ``(fgraph, node) -> new_outputs | None``
        # (e.g. `local_elemwise_fusion`).
        super().__init__()
        self.node_rewriter = node_rewriter

    def add_requirements(self, fgraph):
        fgraph.attach_feature(ReplaceValidate())

    def apply(self, fgraph):
        # Repeatedly sweep the graph (in reverse topological order) until a
        # full pass makes no replacement; returns a profiling tuple.
        did_something = True
        nb_iter = 0
        nb_replacement = 0
        nb_inconsistency_replace = 0
        time_toposort = 0
        if fgraph.profile:
            validate_before = fgraph.profile.validate_time
            callbacks_before = fgraph.execute_callbacks_times.copy()
            callback_before = fgraph.execute_callbacks_time
        while did_something:
            t0 = time.time()
            nodelist = list(fgraph.toposort())
            time_toposort += time.time() - t0
            nodelist.reverse()
            did_something = False
            for node in nodelist:
                # Don't try to fuse node that have already been fused.
                if node in fgraph.apply_nodes:
                    new_outputs = self.node_rewriter(fgraph, node)
                    if new_outputs:
                        assert len(new_outputs) == len(node.outputs)
                        try:
                            fgraph.replace_all_validate(
                                list(zip(node.outputs, new_outputs)),
                                reason=self.__class__.__name__,
                            )
                            did_something = True
                            nb_replacement += 1
                        except InconsistencyError:
                            nb_inconsistency_replace += 1
            nb_iter += 1

        if fgraph.profile:
            validate_time = fgraph.profile.validate_time - validate_before
            callback_time = fgraph.execute_callbacks_time - callback_before
            callbacks_time = {}
            for k, v in fgraph.execute_callbacks_times.items():
                if k in callbacks_before:
                    callbacks_time[k] = v - callbacks_before[k]
                else:
                    callbacks_time[k] = v
        else:
            validate_time = None
            callback_time = None
            callbacks_time = {}
        return (
            self,
            nb_iter,
            nb_replacement,
            nb_inconsistency_replace,
            validate_time,
            callback_time,
            callbacks_time,
            time_toposort,
        )

    @classmethod
    def print_profile(cls, stream, prof, level=0):
        # `prof` is the tuple returned by `apply`.
        blanc = "    " * level
        print(blanc, cls.__name__, file=stream)
        print(blanc, " nb_iter", prof[1], file=stream)
        print(blanc, " nb_replacement", prof[2], file=stream)
        print(blanc, " nb_inconsistency_replace", prof[3], file=stream)
        print(blanc, " validate_time", prof[4], file=stream)
        print(blanc, " callback_time", prof[5], file=stream)
        if prof[5] is not None and prof[5] > 1:
            print(blanc, " callbacks_time", file=stream)
            for i in sorted(prof[6].items(), key=lambda a: a[1])[::-1]:
                if i[1] > 0:
                    print(blanc, "  ", i)
        print(blanc, " time_toposort", prof[7], file=stream)
# Register the fusion pass in the global rewrite database. When elemwise
# fusion is enabled it runs inside a `SequenceDB` tagged "fast_run"; when
# disabled it is still registered (without "fast_run") so it can be enabled
# explicitly via its tags.
if config.tensor__local_elemwise_fusion:
    # Must be after gpu(48.5) and before AddDestroyHandler(49.5)
    fuse_seqopt = SequenceDB()
    fuse_seqopt.register(
        "composite_elemwise_fusion",
        FusionOptimizer(local_elemwise_fusion),
        "fast_run",
        "fusion",
        position=1,
    )
    compile.optdb.register(  # type: ignore
        "elemwise_fusion",
        fuse_seqopt,
        "fast_run",
        "fusion",
        "local_elemwise_fusion",
        "FusionOptimizer",
        position=49,
    )
else:
    compile.optdb.register(  # type: ignore
        "elemwise_fusion",
        FusionOptimizer(local_elemwise_fusion),
        "fusion",
        "local_elemwise_fusion",
        "FusionOptimizer",
        position=49,
    )
@register_canonicalize
@node_rewriter([Elemwise])
def local_useless_composite(fgraph, node):
    """For elemwise Composite that have multiple outputs, remove the
    outputs that are not used.

    """
    if not isinstance(node.op, Elemwise) or not isinstance(
        node.op.scalar_op, aes.Composite
    ):
        return

    comp = node.op.scalar_op
    # Indices of the node outputs that still have clients in the graph
    used_idx = [i for i, out in enumerate(node.outputs) if fgraph.clients[out]]
    if len(used_idx) == len(node.outputs):
        # Every output is used; nothing to prune.
        return

    # Rebuild a smaller Composite keeping only the used outputs.
    kept_outputs = [comp.outputs[i] for i in used_idx]
    new_comp = aes.Composite(inputs=comp.inputs, outputs=kept_outputs)
    new_outs = Elemwise(scalar_op=new_comp)(*node.inputs, return_list=True)
    return dict(zip([node.outputs[i] for i in used_idx], new_outs))
aesara/tensor/rewriting/extra_ops.py
0 → 100644
浏览文件 @
63f52536
import
aesara.scalar.basic
as
aes
from
aesara.graph.rewriting.basic
import
node_rewriter
from
aesara.tensor.basic
import
Alloc
,
as_tensor_variable
from
aesara.tensor.elemwise
import
Elemwise
from
aesara.tensor.extra_ops
import
BroadcastTo
,
Repeat
,
Unique
from
aesara.tensor.rewriting.basic
import
register_canonicalize
,
register_useless
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_scalar(fgraph, node):
    """Convert ``unique(x)`` to ``x`` when ``x`` is a scalar."""
    op = node.op
    if not isinstance(op, Unique):
        return False
    # Extra outputs change the contract; only the plain-values form applies.
    if op.return_index or op.return_inverse or op.return_counts:
        return False

    x = node.inputs[0]
    if x.ndim != 0:
        return False

    old_out = node.outputs[0]
    return [as_tensor_variable(x, ndim=old_out.ndim, dtype=old_out.dtype)]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_Alloc_lift(fgraph, node):
    """Convert ``unique(alloc(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    op = node.op
    if not isinstance(op, Unique):
        return False
    if (
        op.return_index
        or op.return_inverse
        or op.return_counts
        or op.axis is not None
    ):
        return False

    alloc_var = node.inputs[0]
    alloc_node = alloc_var.owner
    if alloc_node is None or not isinstance(alloc_node.op, Alloc):
        return False

    # `Alloc` only repeats values of its first input, so the set of unique
    # values is unchanged.
    alloced_var = alloc_node.inputs[0]
    new_unique = op.make_node(alloced_var).outputs[0]

    old_out = node.outputs[0]
    return [as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_BroadcastTo_lift(fgraph, node):
    """Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    op = node.op
    if not isinstance(op, Unique):
        return False
    if (
        op.return_index
        or op.return_inverse
        or op.return_counts
        or op.axis is not None
    ):
        return False

    bcast_var = node.inputs[0]
    bcast_node = bcast_var.owner
    if bcast_node is None or not isinstance(bcast_node.op, BroadcastTo):
        return False

    # Broadcasting only repeats values, so the set of uniques is unchanged.
    bcasted_var = bcast_node.inputs[0]
    new_unique = op.make_node(bcasted_var).outputs[0]

    old_out = node.outputs[0]
    return [as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_Repeat_lift(fgraph, node):
    """Convert ``unique(repeat(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    op = node.op
    if not isinstance(op, Unique):
        return False
    if (
        op.return_index
        or op.return_inverse
        or op.return_counts
        or op.axis is not None
    ):
        return False

    repeat_var = node.inputs[0]
    repeat_node = repeat_var.owner
    if repeat_node is None or not isinstance(repeat_node.op, Repeat):
        return False

    # Repeating only duplicates values, so the set of uniques is unchanged.
    repeated_var = repeat_node.inputs[0]
    new_unique = op.make_node(repeated_var).outputs[0]

    old_out = node.outputs[0]
    return [as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)]
@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_second(fgraph, node):
    """Convert ``unique(second(x, ...), axis=None)`` to ``second(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    op = node.op
    if not isinstance(op, Unique):
        return False
    if (
        op.return_index
        or op.return_inverse
        or op.return_counts
        or op.axis is not None
    ):
        return False

    second_var = node.inputs[0]
    second_node = second_var.owner
    if (
        second_node is None
        or not isinstance(second_node.op, Elemwise)
        or not isinstance(second_node.op.scalar_op, aes.Second)
    ):
        return False

    # `second(shape_template, value)` fills the template's shape with
    # `value`; only the value matters for the set of uniques.
    _, seconded_var = second_node.inputs
    new_unique = op.make_node(seconded_var).outputs[0]

    old_out = node.outputs[0]
    return [as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)]
@register_useless
@register_canonicalize
@node_rewriter([BroadcastTo])
def local_remove_scalar_BroadcastTo(fgraph, node):
    """Remove scalar-to-scalar `BroadcastTo` nodes, which are no-ops."""
    # The remaining inputs (after the value) encode the target shape.
    shape_inputs = node.inputs[1:]
    if shape_inputs:
        return

    target = node.inputs[0]
    # If this isn't true, the graph is invalid
    assert target.ndim == 0
    return [target]
aesara/tensor/rewriting/math.py
浏览文件 @
63f52536
...
...
@@ -72,10 +72,8 @@ from aesara.tensor.math import prod, reciprocal, sgn, sigmoid, softplus, sqr, sq
from
aesara.tensor.math
import
sum
as
at_sum
from
aesara.tensor.math
import
true_div
from
aesara.tensor.rewriting.basic
import
(
FusionOptimizer
,
broadcast_like
,
encompasses_broadcastable
,
fuse_seqopt
,
local_fill_sink
,
register_canonicalize
,
register_specialize
,
...
...
@@ -84,6 +82,7 @@ from aesara.tensor.rewriting.basic import (
register_uncanonicalize
,
register_useless
,
)
from
aesara.tensor.rewriting.elemwise
import
FusionOptimizer
,
fuse_seqopt
from
aesara.tensor.shape
import
Shape
,
Shape_i
from
aesara.tensor.subtensor
import
Subtensor
from
aesara.tensor.type
import
(
...
...
aesara/tensor/rewriting/shape.py
0 → 100644
浏览文件 @
63f52536
import
traceback
from
io
import
StringIO
from
typing
import
Optional
from
typing
import
cast
as
type_cast
from
warnings
import
warn
import
numpy
as
np
import
aesara
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Constant
,
Variable
,
ancestors
,
equal_computations
from
aesara.graph.features
import
AlreadyThere
,
Feature
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.rewriting.basic
import
(
GraphRewriter
,
RemovalNodeRewriter
,
check_chain
,
copy_stack_trace
,
node_rewriter
,
)
from
aesara.graph.utils
import
InconsistencyError
,
get_variable_trace_string
from
aesara.tensor.basic
import
(
MakeVector
,
as_tensor_variable
,
cast
,
constant
,
extract_constant
,
get_scalar_constant_value
,
stack
,
tensor_copy
,
)
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
from
aesara.tensor.exceptions
import
NotScalarConstantError
,
ShapeError
from
aesara.tensor.rewriting.basic
import
(
register_canonicalize
,
register_specialize
,
register_stabilize
,
register_useless
,
topo_constant_folding
,
)
from
aesara.tensor.shape
import
(
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
,
Unbroadcast
,
shape_i
,
specify_shape
,
unbroadcast
,
)
from
aesara.tensor.subtensor
import
Subtensor
,
get_idx_list
from
aesara.tensor.type
import
TensorType
,
discrete_dtypes
,
integer_dtypes
from
aesara.tensor.type_other
import
NoneConst
class ShapeFeature(Feature):
    r"""A `Feature` that tracks shape information in a graph.

    This `Feature` aids in the replacement of all `Shape`\s and `Subtensor`\s of `Shape`\s with
    `Shape_i` and `MakeVector` `Op`\s.

    This `Feature` and its associated rewrites have several goals:

    1. to "lift" `Shape`\s to as close to the inputs as possible,
    2. to infer the shape of every node in the graph in terms of the
       input shapes, and
    3. remove fill `Op`\s (e.g. `Second`) from the graph.

    Lifting shapes as close to the inputs as possible is important for
    canonicalization because it is very bad form to have to compute
    something just to know how big it will be.  Firstly, it is a waste
    of time to compute such outputs.  But it is important to get rid
    of these outputs as early as possible in the compilation process
    because the extra computations make it appear as if many internal
    graph nodes have multiple clients.  Many rewrites refuse to
    work on nodes with multiple clients.

    Lifting is done by using an `<Op>.infer_shape` function if one is
    present, or else using a conservative default.  An Op that
    supports shape-lifting should define a infer_shape(self, fgraph, node,
    input_shapes) function.  The argument input_shapes is a tuple of
    tuples... there is an interior tuple for each input to the node.
    The tuple has as many elements as dimensions.  The element in
    position i of tuple j represents the i'th shape component of the
    j'th input.  The function should return a tuple of tuples.  One
    output tuple for each node.output.  Again, the i'th element of the
    j'th output tuple represents the output[j].shape[i] of the
    function.  If an output is not a TensorType, then None should be
    returned instead of a tuple for that output.

    For example the infer_shape for a matrix-matrix product would accept
    input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),).

    Inferring the shape of internal nodes in the graph is important
    for doing size-driven rewrites.  If we know how big various
    intermediate results will be, we can estimate the cost of many Ops
    accurately, and generate c-code that is specific [e.g. unrolled]
    to particular sizes.

    In cases where you cannot figure out the shape, raise a ShapeError.

    Notes
    -----
    Right now there is only the ConvOp that could really take
    advantage of this shape inference, but it is worth it even
    just for the ConvOp.  All that's necessary to do shape
    inference is 1) to mark shared inputs as having a particular
    shape, either via a .tag or some similar hacking; and 2) to
    add an optional In() argument to promise that inputs will
    have a certain shape (or even to have certain shapes in
    certain dimensions).

    We can't automatically infer the shape of shared variables as they can
    change of shape during the execution by default.

    To use this shape information in rewrites, use the
    ``shape_of`` dictionary.

    For example:

    .. code-block:: python

        try:
            shape_of = fgraph.shape_feature.shape_of
        except AttributeError:
            # This can happen when the mode doesn't include the ShapeFeature.
            return

        shape_of_output_zero = shape_of[node.output[0]]

    The ``shape_of_output_zero`` symbol will contain a tuple, whose
    elements are either integers or symbolic integers.

    TODO: check to see if the symbols are necessarily
    non-constant... or are integer literals sometimes Aesara
    constants?? That would be confusing.

    """

    def get_node_infer_shape(self, node):
        # Delegate to the `Op`'s `infer_shape` when it exists, otherwise
        # fall back to the conservative per-output default.
        try:
            shape_infer = node.op.infer_shape
        except AttributeError:
            shape_infer = self.default_infer_shape

        try:
            o_shapes = shape_infer(
                self.fgraph, node, [self.shape_of[r] for r in node.inputs]
            )
        except ShapeError:
            # A `ShapeError` is the sanctioned way for an `Op` to say "I
            # can't compute this"; recover with the default inference.
            o_shapes = self.default_infer_shape(
                self.fgraph, node, [self.shape_of[r] for r in node.inputs]
            )
        except NotImplementedError as e:
            raise NotImplementedError(
                "Code called by infer_shape failed raising a "
                "NotImplementedError. Raising NotImplementedError to "
                "indicate that a shape cannot be computed is no longer "
                "supported, and one should now use ShapeError "
                f"instead. The original exception message is: {e}"
            ).with_traceback(e.__traceback__)
        except Exception as e:
            # Any other failure is either fatal or downgraded to a warning,
            # depending on `config.on_shape_error`.
            msg = (
                f"Failed to infer_shape from Op {node.op}.\nInput shapes: "
                f"{[self.shape_of[r] for r in node.inputs]}\nException encountered during infer_shape: "
                f"{type(e)}\nException message: {str(e)}\nTraceback: {traceback.format_exc()}"
            )
            if config.on_shape_error == "raise":
                raise Exception(msg).with_traceback(e.__traceback__)
            else:
                warn(msg)
            o_shapes = self.default_infer_shape(
                self.fgraph, node, [self.shape_of[r] for r in node.inputs]
            )

        return o_shapes

    def get_shape(self, var, idx):
        """Rewrites can call this to get a `Shape_i`.

        It is better to call this then use directly ``shape_of[var][idx]``
        as this method should update `shape_of` if needed.

        TODO: Up to now, we don't update it in all cases. Update in all cases.
        """
        r = self.shape_of[var][idx]
        # A `Shape_i` whose input is not in the graph is stale; re-infer
        # the shape from `var`'s owner instead of returning it as-is.
        if (
            r.owner
            and isinstance(r.owner.op, Shape_i)
            and r.owner.inputs[0] not in self.fgraph.variables
        ):
            assert var.owner
            node = var.owner
            # recur on inputs
            for i in node.inputs:
                if getattr(i.type, "ndim", None) > 0:
                    self.get_shape(i, 0)
            o_shapes = self.get_node_infer_shape(node)
            assert len(o_shapes) == len(node.outputs)

            # Only change the variables and dimensions that would introduce
            # extra computation
            for new_shps, out in zip(o_shapes, node.outputs):
                if not hasattr(out.type, "ndim"):
                    continue

                merged_shps = list(self.shape_of[out])
                changed = False
                for i in range(out.type.ndim):
                    n_r = merged_shps[i]
                    if (
                        n_r.owner
                        and isinstance(n_r.owner.op, Shape_i)
                        and n_r.owner.inputs[0] not in self.fgraph.variables
                    ):
                        changed = True
                        merged_shps[i] = new_shps[i]

                if changed:
                    self.set_shape(out, merged_shps, override=True)
            r = self.shape_of[var][idx]
        return r

    def shape_ir(self, i, r):
        """Return symbolic r.shape[i] for tensor variable r, int i."""
        # A statically-known dimension becomes an int64 constant directly.
        if hasattr(r.type, "shape") and r.type.shape[i] is not None:
            return constant(r.type.shape[i], dtype="int64")
        else:
            # Do not call make_node for test_value
            s = Shape_i(i)(r)
            try:
                # Fold to a constant when possible.
                s = get_scalar_constant_value(s)
            except NotScalarConstantError:
                pass
            return s

    def shape_tuple(self, r):
        """Return a tuple of symbolic shape vars for tensor variable r."""
        if not hasattr(r.type, "ndim"):
            # This happen for NoneConst.
            return None
        return tuple(self.shape_ir(i, r) for i in range(r.type.ndim))

    def default_infer_shape(self, fgraph, node, i_shapes):
        """Return a list of shape tuple or None for the outputs of node.

        This function is used for Ops that don't implement infer_shape.
        Ops that do implement infer_shape should use the i_shapes parameter,
        but this default implementation ignores it.

        """
        rval = []
        for r in node.outputs:
            try:
                rval.append(self.shape_tuple(r))
            except AttributeError:
                # Non-tensor outputs have no shape.
                rval.append(None)
        return rval

    def unpack(self, s_i, var):
        """Return a symbolic integer scalar for the shape element s_i.

        The s_i argument was produced by the infer_shape() of an Op subclass.

        var: the variable that correspond to s_i. This is just for
        error reporting.
        """
        assert s_i is not None

        if s_i == 1:
            # Reuse the shared int64 ``1`` so broadcastable dims compare equal.
            return self.lscalar_one
        if isinstance(s_i, float) and int(s_i) == s_i:
            s_i = int(s_i)
        if isinstance(s_i, (np.integer, int)) or (
            isinstance(s_i, np.ndarray) and s_i.ndim == 0
        ):
            # this shape is a constant
            if s_i < 0:
                msg = "There is a negative shape in the graph!"
                msg += get_variable_trace_string(var)
                # The rest of the pipeline don't handle correctly this
                # case.  So we have 2 choices, stop compilation or
                # consider the shape as unknown.  As we have more
                # chance to give the stack trace here then later, I
                # choose that options as it would give better error
                # message.
                raise AssertionError(msg)
            return constant(s_i, dtype="int64")
        if isinstance(s_i, (tuple, list)):
            # this dimension is the same as many of the inputs
            # which tells us that if one of the inputs is known,
            # the others all become known.
            # TODO: should be implemented in Elemwise, and Dot
            #
            # worst case, we loop over shape_of and replace things
            raise NotImplementedError(s_i)

        # s_i is x.shape[i] for some x, we change it to shape_of[x][i]
        if (
            s_i.owner
            and isinstance(s_i.owner.op, Subtensor)
            and s_i.owner.inputs[0].owner
            and isinstance(s_i.owner.inputs[0].owner.op, Shape)
        ):
            assert s_i.type.ndim == 0
            assert len(s_i.owner.op.idx_list) == 1

            # The current Subtensor always put constant index in the graph.
            # This was not True in the past. So call the Subtensor function
            # that will return the right index.
            idx = get_idx_list(s_i.owner.inputs, s_i.owner.op.idx_list)
            assert len(idx) == 1
            idx = idx[0]
            try:
                i = get_scalar_constant_value(idx)
            except NotScalarConstantError:
                pass
            else:
                # Executed only if no exception was raised
                x = s_i.owner.inputs[0].owner.inputs[0]
                # x should already have been imported, and should be in shape_of.
                s_i = self.shape_of[x][i]

        if s_i.type.dtype in integer_dtypes:
            if getattr(s_i.type, "ndim", 0):
                raise TypeError("Shape element must be scalar", s_i)
            return s_i
        else:
            raise TypeError(
                "Unsupported shape element", s_i, type(s_i), getattr(s_i, "type", None)
            )

    def set_shape(self, r, s, override=False):
        """Assign the shape `s` to previously un-shaped variable `r`.

        Parameters
        ----------
        r : a variable
        s : None or a tuple of symbolic integers
        override : If False, it mean r is a new object in the fgraph.
            If True, it mean r is already in the fgraph and we want to
            override its shape.

        """
        if not override:
            assert r not in self.shape_of, "r already in shape_of"
        if s is None:
            self.shape_of[r] = s
        else:
            if not isinstance(s, (tuple, list)):
                raise TypeError("shapes must be tuple/list", (r, s))

            if r.type.ndim != len(s):
                sio = StringIO()
                aesara.printing.debugprint(r, file=sio, print_type=True)
                raise AssertionError(
                    f"Something inferred a shape with {len(s)} dimensions "
                    f"for a variable with {int(r.type.ndim)} dimensions"
                    f" for the variable:\n{sio.getvalue()}"
                )

            shape_vars = []
            for i in range(r.type.ndim):
                # Prefer the statically-known dimension over the inferred one.
                if hasattr(r.type, "shape") and r.type.shape[i] is not None:
                    shape_vars.append(constant(r.type.shape[i], dtype="int64"))
                else:
                    shape_vars.append(self.unpack(s[i], r))
            # Sanity check: broadcastable dims must have length one.
            assert all(
                not hasattr(r.type, "broadcastable")
                or not r.type.broadcastable[i]
                or self.lscalar_one.equals(shape_vars[i])
                or self.lscalar_one.equals(extract_constant(shape_vars[i]))
                for i in range(r.type.ndim)
            )
            self.shape_of[r] = tuple(shape_vars)
            for sv in shape_vars:
                self.shape_of_reverse_index.setdefault(sv, set()).add(r)

    def update_shape(self, r, other_r):
        """Replace shape of r by shape of other_r.

        If, on some dimensions, the shape of other_r is not informative,
        keep the shape of r on those dimensions.

        """
        # other_r should already have a shape
        assert other_r in self.shape_of, ("other_r not in shape_of", other_r)
        other_shape = self.shape_of[other_r]

        # If other_shape has no information, call is pointless.
        if other_shape is None:
            return

        if r in self.shape_of:
            r_shape = self.shape_of[r]
        else:
            # If no info is known on r's shape, use other_shape
            self.set_shape(r, other_shape)
            return
        if (
            other_r.owner
            and r.owner
            and other_r.owner.inputs == r.owner.inputs
            and other_r.owner.op == r.owner.op
        ):
            # We are doing a merge, so the two shape graphs will be the
            # same.  This is only done so that we call `ancestors` less
            # frequently.
            return

        # Merge other_shape with r_shape, giving the priority to other_shape
        merged_shape = []
        for i, ps in enumerate(other_shape):
            if r_shape is None and other_shape:
                merged_shape.append(other_shape[i])
            elif (
                ps.owner
                and isinstance(getattr(ps.owner, "op", None), Shape_i)
                and ps.owner.op.i == i
                and ps.owner.inputs[0] in (r, other_r)
            ):
                # If other_shape[i] is uninformative, use r_shape[i].
                # For now, we consider 2 cases of uninformative other_shape[i]:
                #  - Shape_i(i)(other_r);
                #  - Shape_i(i)(r).
                merged_shape.append(r_shape[i])
            elif isinstance(r_shape[i], (Constant, int)):
                # We do this to call less often ancestors and make
                # sure we have the simplest shape possible.
                merged_shape.append(r_shape[i])
            elif isinstance(other_shape[i], (Constant, int)):
                # We do this to call less often ancestors and make
                # sure we have the simplest shape possible.
                merged_shape.append(other_shape[i])
            elif other_shape[i] == r_shape[i]:
                # This mean the shape is equivalent
                # We do not want to do the ancestor check in those cases
                merged_shape.append(r_shape[i])
            elif r_shape[i] in ancestors([other_shape[i]]):
                # Another case where we want to use r_shape[i] is when
                # other_shape[i] actually depends on r_shape[i]. In that case,
                # we do not want to substitute an expression with another that
                # is strictly more complex. Such a substitution could also lead
                # to cycles: if (in the future) r_shape[i] gets replaced by an
                # expression of other_shape[i], other_shape[i] may end up
                # depending on itself.
                merged_shape.append(r_shape[i])
            else:
                merged_shape.append(other_shape[i])
        assert all(
            (
                not hasattr(r.type, "broadcastable")
                or not r.type.broadcastable[i]
                and not other_r.type.broadcastable[i]
            )
            or self.lscalar_one.equals(merged_shape[i])
            or self.lscalar_one.equals(
                extract_constant(merged_shape[i], only_process_constants=True)
            )
            for i in range(r.type.ndim)
        )
        self.shape_of[r] = tuple(merged_shape)
        for sv in self.shape_of[r]:
            self.shape_of_reverse_index.setdefault(sv, set()).add(r)

    def set_shape_i(self, r, i, s_i):
        """Replace element i of shape_of[r] by s_i"""
        assert r in self.shape_of
        prev_shape = self.shape_of[r]
        # prev_shape is a tuple, so we cannot change it inplace,
        # so we build another one.
        new_shape = []
        for j, s_j in enumerate(prev_shape):
            if j == i:
                new_shape.append(self.unpack(s_i, r))
            else:
                new_shape.append(s_j)
        assert all(
            not hasattr(r.type, "broadcastable")
            or not r.type.broadcastable[idx]
            or self.lscalar_one.equals(new_shape[idx])
            or self.lscalar_one.equals(extract_constant(new_shape[idx]))
            for idx in range(r.type.ndim)
        )
        self.shape_of[r] = tuple(new_shape)
        for sv in self.shape_of[r]:
            self.shape_of_reverse_index.setdefault(sv, set()).add(r)

    def init_r(self, r):
        """Register r's shape in the shape_of dictionary."""
        if r not in self.shape_of:
            self.set_shape(r, self.shape_tuple(r))

    def make_vector_shape(self, r):
        # Pack the tracked per-dimension shapes into a single int64 vector.
        return as_tensor_variable(self.shape_of[r], ndim=1, dtype="int64")

    def on_attach(self, fgraph):
        if hasattr(fgraph, "shape_feature"):
            raise AlreadyThere("This FunctionGraph already has a ShapeFeature")

        if hasattr(self, "fgraph") and self.fgraph != fgraph:
            raise Exception("This ShapeFeature is already attached to a graph")

        self.fgraph = fgraph

        fgraph.shape_feature = self
        # Must be local to the object as otherwise we reuse the same
        # variable for multiple fgraph!
        self.lscalar_one = constant(1, dtype="int64")
        assert self.lscalar_one.type.dtype == "int64"

        self.fgraph = fgraph
        # Variable -> tuple(scalars) or None  (All tensor vars map to tuple)
        self.shape_of = {}
        # Variable ->
        self.scheduled = {}
        # shape var -> graph v
        self.shape_of_reverse_index = {}

        for node in fgraph.toposort():
            self.on_import(fgraph, node, reason="on_attach")

    def on_detach(self, fgraph):
        # Drop all tracked state and unlink from the graph.
        self.shape_of = {}
        self.scheduled = {}
        self.shape_of_reverse_index = {}
        self.fgraph = None
        del fgraph.shape_feature

    def on_import(self, fgraph, node, reason):
        if node.outputs[0] in self.shape_of:
            # this is a revert, not really an import
            for r in node.outputs + node.inputs:
                assert r in self.shape_of
            return

        for i, r in enumerate(node.inputs):
            # make sure we have shapes for the inputs
            self.init_r(r)

        o_shapes = self.get_node_infer_shape(node)

        # this is packed information
        # an element of o_shapes is either None or a tuple
        # elements of the tuple can be either strings, or ints
        if len(o_shapes) != len(node.outputs):
            raise Exception(
                (
                    f'The infer_shape method for the Op "{node.op}" returned a list '
                    f"with the wrong number of element: len(o_shapes) = {len(o_shapes)} "
                    f" != len(node.outputs) = {len(node.outputs)}"
                )
            )

        # Ensure shapes are in 'int64'. This is to make sure the assert
        # found in the `local_useless_subtensor` rewrite does not fail.
        for sh_idx, sh in enumerate(o_shapes):
            if sh is None:
                continue
            if not isinstance(sh, (list, tuple)):
                raise ValueError(
                    f"infer_shape of {node} didn't return a list of"
                    f" list. It returned '{o_shapes}'"
                )
            new_shape = []
            for i, d in enumerate(sh):
                # Note: we ignore any shape element that is not typed (i.e.,
                # does not have a 'dtype' attribute). This means there may
                # still remain int elements that are int32 on 32-bit platforms,
                # but this works with `local_useless_subtensor`, so for now we
                # keep it this way. See #266 for a better long-term fix.
                if getattr(d, "dtype", "int64") != "int64":
                    assert d.dtype in discrete_dtypes, (node, d.dtype)
                    assert str(d.dtype) != "uint64", node
                    new_shape += sh[len(new_shape) : i + 1]
                    if isinstance(d, Constant):
                        casted_d = constant(d.data, dtype="int64")
                    else:
                        casted_d = cast(d, "int64")
                    new_shape[i] = casted_d
            if new_shape:
                # We replace the shape with wrong dtype by the one with
                # 'int64'.
                new_shape += sh[len(new_shape) :]
                o_shapes[sh_idx] = tuple(new_shape)

        for r, s in zip(node.outputs, o_shapes):
            self.set_shape(r, s)

    def on_change_input(self, fgraph, node, i, r, new_r, reason):
        if new_r not in self.shape_of:
            # It happen that the fgraph didn't called on_import for some
            # new_r.  This happen when new_r don't have an
            # owner(i.e. it is a constant or an input of the graph)
            # update_shape suppose that r and new_r are in shape_of.
            self.init_r(new_r)

        # This tells us that r and new_r must have the same shape if
        # we didn't know that the shapes are related, now we do.
        self.update_shape(new_r, r)

        # change_input happens in two cases:
        # 1) we are trying to get rid of r, or
        # 2) we are putting things back after a failed transaction.

        # In case 1, if r has a shape_i client, we will want to
        # replace the shape_i of r with the shape of new_r.  Say that
        # r is *scheduled*.
        # At that point, node is no longer a client of r, but of new_r
        for (shpnode, idx) in fgraph.clients[r] + [(node, i)]:
            if isinstance(getattr(shpnode, "op", None), Shape_i):
                idx = shpnode.op.i
                repl = self.shape_of[new_r][idx]
                if repl.owner is shpnode:
                    # This mean the replacement shape object is
                    # exactly the same as the current shape object. So
                    # no need for replacement.
                    continue
                if (
                    repl.owner
                    and repl.owner.inputs[0] is shpnode.inputs[0]
                    and isinstance(repl.owner.op, Shape_i)
                    and repl.owner.op.i == shpnode.op.i
                ):
                    # The replacement is a shape_i of the same
                    # input. So no need to do this equivalent
                    # replacement.
                    continue

                if shpnode.outputs[0] in ancestors([repl]):
                    raise InconsistencyError(
                        "This substitution would insert a cycle in the graph:"
                        f"node: {node}, i: {i}, r: {r}, new_r: {new_r}"
                    )

                self.scheduled[shpnode] = new_r
        # In case 2, if r is a variable that we've scheduled for shape update,
        # then we should cancel it.
        unscheduled = [k for k, v in self.scheduled.items() if v == r]
        for k in unscheduled:
            del self.scheduled[k]

        # In either case, r could be in shape_of.values(), that is, r itself
        # is the shape of  something. In that case, we want to update
        # the value in shape_of, to keep it up-to-date.
        for v in self.shape_of_reverse_index.get(r, []):
            # The reverse index is only approximate. It is not updated on
            # deletion of variables, or on change_input so it might be the
            # case that there are a few extra `v`'s in it that no longer have
            # a shape of r or possibly have been deleted from shape_of
            # entirely. The important thing is that it permits to recall
            # all variables with r in their shape.
            for ii, svi in enumerate(self.shape_of.get(v, [])):
                if svi == r:
                    self.set_shape_i(v, ii, new_r)
            self.shape_of_reverse_index[r] = set()

    def same_shape(
        self,
        x: Variable,
        y: Variable,
        dim_x: Optional[int] = None,
        dim_y: Optional[int] = None,
    ) -> bool:
        """Return ``True`` if `x` and `y` have the same shape.

        Parameters
        ==========
        x
            The `Variable` for which its shape is to be compared with `y`'s shape.
        y
            The `Variable` for which its shape is to be compared with `x`'s shape.
        dim_x
            If non ``None``, compare only the dimension of `x` equal to
            `dim_x`.
        dim_y
            If non ``None``, compare only the dimension of `y` equal to
            `dim_y`.

        """
        sx = self.shape_of[x]
        sy = self.shape_of[y]

        if sx is None or sy is None:
            return False

        if dim_x is not None:
            sx = [sx[dim_x]]

        if dim_y is not None:
            sy = [sy[dim_y]]

        if len(sx) != len(sy):
            return False

        # Canonicalize the graphs so that comparisons are reasonable
        # TODO FIXME: This should *not* need to be performed manually here.
        # Instead, the shape information in `self.shape_of` should be operated
        # upon alongside all the other elements in a `FunctionGraph` (e.g. as
        # if `self.shape_of.values()` were additional outputs).
        shapes_fg = FunctionGraph(
            outputs=sx + sy,
            # features=[self],
            clone=True,
            # copy_inputs=False,
        )
        from aesara.graph.rewriting.utils import rewrite_graph

        canon_shapes_fg = type_cast(
            FunctionGraph,
            rewrite_graph(shapes_fg, custom_rewrite=topo_constant_folding),
        )
        canon_shapes = canon_shapes_fg.outputs

        sx = canon_shapes[: len(sx)]
        sy = canon_shapes[len(sx) :]

        for dx, dy in zip(sx, sy):
            if not equal_computations([dx], [dy]):
                return False

        return True

    def clone(self):
        return type(self)()
class ShapeOptimizer(GraphRewriter):
    """Rewriter that adds `ShapeFeature` as a feature."""

    def add_requirements(self, fgraph):
        # Attaching the feature is the entire effect of this rewriter.
        fgraph.attach_feature(ShapeFeature())

    def apply(self, fgraph):
        # Nothing to rewrite; the work happens via the attached feature.
        pass
class UnShapeOptimizer(GraphRewriter):
    """Rewriter that removes `ShapeFeature` as a feature."""

    def apply(self, fgraph):
        # Iterate over a snapshot of `_features`: `remove_feature` mutates
        # the underlying list, and removing items from a list while
        # iterating it directly can skip elements.
        for feature in list(fgraph._features):
            if isinstance(feature, ShapeFeature):
                fgraph.remove_feature(feature)
# Register it after merge1 optimization at 0. We don't want to track
# the shape of merged node.
aesara
.
compile
.
mode
.
optdb
.
register
(
"ShapeOpt"
,
ShapeOptimizer
(),
"fast_run"
,
"fast_compile"
,
position
=
0.1
)
# Not enabled by default for now. Some crossentropy opt use the
# shape_feature. They are at step 2.01. uncanonicalize is at step
# 3. After it goes to 48.5 that move to the gpu. So 10 seems reasonable.
aesara
.
compile
.
mode
.
optdb
.
register
(
"UnShapeOpt"
,
UnShapeOptimizer
(),
position
=
10
)
def local_reshape_chain(op):
    """Build a rewriter that collapses two consecutive applications of `op`.

    Intended for `Reshape`-like `Op`\\s where only the outermost shape
    argument matters.
    """

    @node_rewriter([op])
    def f(fgraph, node):
        """
        Reshape(Reshape(shape1),shape2) -> Reshape(shape2)

        """
        if not check_chain(node, op, op):
            return False

        # TODO: this can permit a failing program to run by eliminating
        # the lower reshape
        rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])

        # Copy over stacktrace from previous output node, as any error
        # in new computational graph would have been caused by last op
        # in the old computational graph.
        copy_stack_trace(node.outputs, rval)

        # It might happen that the desired output of this node has a
        # broadcastable pattern that does not match that of 'rval'. This is
        # when originally, we were able to figure out that one of the
        # dimensions of the reshape is one, but some other transformation
        # replaced the shape by one for which this cannot be guessed.
        # We should try to figure out why we lost the information about this
        # constant value... but in the meantime, better not apply this
        # rewrite.
        if rval.broadcastable == node.outputs[0].broadcastable:
            return [rval]
        else:
            return False

    return f


register_canonicalize(local_reshape_chain(Reshape), name="local_reshape_chain")
@register_useless
@register_canonicalize
@register_stabilize
@node_rewriter([Reshape])
def local_useless_reshape(fgraph, node):
    """
    Remove two kinds of useless reshape.

    Remove Reshape when both the input and output have a single dimension.
    Remove Reshape when reshaping to the shape of the input.

    """
    op = node.op
    if not isinstance(op, Reshape):
        return False

    inp = node.inputs[0]
    output = node.outputs[0]
    output_shape = node.inputs[1]

    # Only an ndim-preserving reshape can be useless.
    if inp.ndim != output.ndim:
        return False

    # Simple case: both input and output have a single dimension.
    # This could hide errors if the user provides inconsistent shapes.
    if (
        inp.ndim == 1
        and output.ndim == 1
        and inp.broadcastable == output.broadcastable
    ):
        return [inp]

    # Second case: all the shapes match the input shape
    # Match Reshape(x, x.shape)
    if output_shape.owner and isinstance(output_shape.owner.op, Shape):
        shape_input = output_shape.owner.inputs[0]
        if shape_input == inp:
            return [inp]

    # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
    # broadcastable and constant dimensions
    if output_shape.owner and isinstance(output_shape.owner.op, MakeVector):
        output_shape_is = output_shape.owner.inputs

        shape_feature = getattr(fgraph, "shape_feature", None)

        nb_m1 = 0  # number of -1 ("infer this dim") entries seen
        shape_match = [False] * inp.ndim
        for dim in range(inp.ndim):
            outshp_i = output_shape_is[dim]
            # Match Shape_i{dim}(input)
            if (
                outshp_i.owner
                and isinstance(outshp_i.owner.op, Shape_i)
                and outshp_i.owner.op.i == dim
                and outshp_i.owner.inputs[0] == inp
            ):
                shape_match[dim] = True
                continue

            # Match Shape(input)[dim]
            if (
                outshp_i.owner
                and isinstance(outshp_i.owner.op, Subtensor)
                and len(outshp_i.owner.inputs) == 2
                and extract_constant(outshp_i.owner.inputs[1]) == dim
            ):
                subtensor_inp = outshp_i.owner.inputs[0]
                if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape):
                    shape_input_i = subtensor_inp.owner.inputs[0]
                    if shape_input_i == inp:
                        shape_match[dim] = True
                        continue

            # Match 1 if input.broadcastable[dim] is True
            cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
            if inp.broadcastable[dim] and cst_outshp_i == 1:
                shape_match[dim] = True
                continue

            # Match -1
            if cst_outshp_i == -1:
                shape_match[dim] = True
                nb_m1 += 1
                continue

            # Match shape_of[input][dim] or its constant equivalent
            if shape_feature:
                inpshp_i = shape_feature.get_shape(inp, dim)
                if inpshp_i == outshp_i or (
                    extract_constant(inpshp_i, only_process_constants=1)
                    == extract_constant(outshp_i, only_process_constants=1)
                ):
                    shape_match[dim] = True
                    continue

        # At most one -1 is allowed; with two or more, the inferred
        # dimensions are no longer guaranteed to equal the input's.
        if all(shape_match) and nb_m1 <= 1:
            return [inp]

        # TODO later: if all the shapes except one match, we may want to
        # consider it useless as well, like we do in the 1-dim case.
        return False
@register_canonicalize
@node_rewriter([Reshape])
def local_reshape_to_dimshuffle(fgraph, node):
    """
    Broadcastable dimensions in Reshape are replaced with dimshuffle.

    The goal is to avoid using reshape to add or remove broadcastable
    dimensions, but use dimshuffle instead, so dimshuffles can cancel out
    or be removed later on.

    For example:
        - reshape(x, (1, n)) --> dimshuffle{x,0}(reshape(x, (n,))
        - reshape(x, (1, m, 1, n, 1, 1))
          --> dimshuffle{x,0,x,1,x,x}(reshape(x, (m, n)))
    """
    op = node.op
    if not isinstance(op, Reshape):
        return False

    inp = node.inputs[0]
    output = node.outputs[0]
    output_shape = node.inputs[1]

    dimshuffle_new_order = []
    new_output_shape = []
    index = 0  # index over the output of the new reshape
    for i in range(output.ndim):
        # Since output_shape is a symbolic vector, we trust extract_constant
        # to go through however it is formed to see if its i-th element is 1.
        # We need only_process_constants=False for that.
        dim = extract_constant(
            output_shape[i], only_process_constants=False, elemwise=False
        )
        if dim == 1:
            # Length-1 dims are re-added afterwards via the 'x' entries of
            # the DimShuffle instead of being produced by the reshape.
            dimshuffle_new_order.append("x")
        else:
            dimshuffle_new_order.append(index)
            new_output_shape.append(dim)
            index = index + 1
    # Only rewrite if at least one length-1 dimension was stripped.
    if index != output.ndim:
        inner = op.__class__(len(new_output_shape))(inp, new_output_shape)
        copy_stack_trace(output, inner)
        new_node = [
            DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)
        ]
        copy_stack_trace(output, new_node)
        return new_node
@register_canonicalize
@register_stabilize
@node_rewriter([Reshape])
def local_reshape_lift(fgraph, node):
    """
    Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x))

    Notes
    -----
    This rewrite is needed by `log1msigm_to_softplus` in order to get applied
    when there is a reshape.

    """
    # Only lift through single-input (unary) Elemwise ops.
    if (
        isinstance(node.op, Reshape)
        and node.inputs[0].owner
        and isinstance(node.inputs[0].owner.op, Elemwise)
        and len(node.inputs[0].owner.inputs) == 1
    ):
        r = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
        # Copy stacktrace from previous Reshape op, as an error in new
        # Reshape op could only have been caused by old one.
        copy_stack_trace(node.outputs, r)

        e = node.inputs[0].owner.op(r)
        # Copy stacktrace from both previous Reshape and UnaryElemwise op
        # because an error in new cg could have been caused by either ops.
        copy_stack_trace(node.outputs + node.inputs, e)
        return [e]
# `tensor_copy` is a no-op at the graph level, so it can simply be removed
# during canonicalization.
register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy")
@register_useless
@register_canonicalize
@node_rewriter([SpecifyShape])
def local_merge_consecutive_specify_shape(fgraph, node):
    """Replace ``specify_shape(specify_shape(x, s1), s2)`` with ``specify_shape(x, s3)``,
    where s3 is the union of specified dimensions in s1 and s2, with preference given to s2.
    """
    if not isinstance(node.op, SpecifyShape):
        return False

    inner = node.inputs[0]
    inner_apply = inner.owner
    if inner_apply is None or not isinstance(inner_apply.op, SpecifyShape):
        return False

    base_var, *merged_shape = inner_apply.inputs
    # An outer shape entry wins wherever it is actually specified
    # (i.e. not `NoneConst`).
    for dim, outer_sh in enumerate(node.inputs[1:]):
        if not NoneConst.equals(outer_sh):
            merged_shape[dim] = outer_sh

    # TODO: We could make sure that the overlapping shapes of the two `SpecifyShape`s are
    # the same.

    return [specify_shape(base_var, merged_shape)]
@register_useless
@register_canonicalize
@node_rewriter([Shape])
def local_Shape_of_SpecifyShape(fgraph, node):
    """Replace ``specify_shape(x, s).shape`` with ``s``."""
    if not isinstance(node.op, Shape):
        return False

    shape_input = node.inputs[0]
    if not isinstance(getattr(shape_input.owner, "op", None), SpecifyShape):
        return False

    x, *specified = shape_input.owner.inputs

    # Fill each unspecified (`NoneConst`) entry with the symbolic length of
    # the corresponding dimension via `shape_i`.
    specified = [
        shape_i(x, axis, fgraph) if NoneConst.equals(dim) else dim
        for axis, dim in enumerate(specified)
    ]

    return [stack(specified).astype(np.int64)]
@register_useless
@register_canonicalize
@node_rewriter([Shape_i])
def local_Shape_i_of_broadcastable(fgraph, node):
    """Replace ``shape_i(x, i)`` with ``1`` when ``x.broadcastable[i]`` is ``True``."""
    if not isinstance(node.op, Shape_i):
        return False

    x = node.inputs[0]
    if not isinstance(x.type, TensorType):
        return False

    if not x.broadcastable[node.op.i]:
        return None

    # A broadcastable dimension always has length one.
    return [as_tensor_variable(1, dtype=np.int64)]
@register_specialize
@register_canonicalize
@node_rewriter([Shape])
def local_shape_to_shape_i(fgraph, node):
    """Replace a `Shape` node with the `ShapeFeature`'s per-dimension shape vector."""
    if not isinstance(node.op, Shape):
        return None
    if not hasattr(fgraph, "shape_feature"):
        # No `ShapeFeature` attached: there is no tracked shape to use.
        return None

    replacement = fgraph.shape_feature.make_vector_shape(node.inputs[0])
    # We need to copy over the stack trace from the input to the output.
    copy_stack_trace(node.outputs[0], replacement)
    return [replacement]
@register_specialize
@register_canonicalize
@node_rewriter([Shape_i])
def local_track_shape_i(fgraph, node):
    """Replace a scheduled `Shape_i` node with its `ShapeFeature` replacement."""
    if not isinstance(node.op, Shape_i):
        return False
    try:
        shape_feature = fgraph.shape_feature
    except AttributeError:
        # No `ShapeFeature` attached: there is no shape bookkeeping to consult.
        return False
    if node not in shape_feature.scheduled:
        return False
    # Don't unschedule node as it could be reinserted in the
    # fgraph as we don't change it in the shapefeature internal
    # structure.
    replacement = shape_feature.scheduled[node]
    return [shape_feature.shape_of[replacement][node.op.i]]
@register_canonicalize
@node_rewriter([Reshape])
def local_useless_dimshuffle_in_reshape(fgraph, node):
    """
    Removes useless DimShuffle operation inside Reshape:

      reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
      reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
      reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
      reshape(col.dimshuffle(0), shp) => reshape(col, shp)

    """
    op = node.op
    if not isinstance(op, Reshape):
        return False
    dimshuffled = node.inputs[0]
    if not (
        dimshuffled.owner is not None
        and isinstance(dimshuffled.owner.op, DimShuffle)
    ):
        return False

    new_order = dimshuffled.owner.op.new_order
    inp = dimshuffled.owner.inputs[0]
    broadcastables = dimshuffled.broadcastable

    # Source positions of the non-broadcastable dimensions, in the order the
    # `DimShuffle` emits them.
    nonbroadcast_order = [
        pos for pos, bcast in zip(new_order, broadcastables) if not bcast
    ]
    # The `DimShuffle` is redundant for the `Reshape` iff it keeps the
    # non-broadcastable dimensions in their original relative order.
    if any(
        earlier > later
        for earlier, later in zip(nonbroadcast_order, nonbroadcast_order[1:])
    ):
        return None

    new_reshape = op.__class__(node.outputs[0].ndim)(inp, node.inputs[1])
    copy_stack_trace(node.outputs[0], new_reshape)
    return [new_reshape]
@register_useless
@register_canonicalize
@register_specialize
@node_rewriter([Unbroadcast])
def local_useless_unbroadcast(fgraph, node):
    """Remove `Unbroadcast` if it does not actually change the broadcasting pattern.

    TODO: Implement equivalent rewrite for SpecifyShape
    """
    if not isinstance(node.op, Unbroadcast):
        return None

    x = node.inputs[0]
    if x.broadcastable == node.outputs[0].broadcastable:
        # No broadcastable flag was modified, so the op is a complete no-op.
        # No need to copy over a stack trace, because `x` should already
        # have one.
        return [x]

    # Keep only the axes whose flags actually modify something.
    useful_axes = tuple(ax for ax in node.op.axes if x.type.shape[ax] == 1)
    if useful_axes == node.op.axes:
        # All flags are useful; nothing to simplify.
        return None

    r = unbroadcast(x, *useful_axes)
    # Copy over the stack trace from the previous output.
    copy_stack_trace(node.outputs, r)
    return [r]
@register_canonicalize
@register_specialize
@node_rewriter([Unbroadcast])
def local_unbroadcast_lift(fgraph, node):
    """
    Lifts `Unbroadcast` through unary Elemwise operations,
    and merges consecutive `Unbroadcast`s.

    Unbroadcast(Elemwise(x)) => Elemwise(Unbroadcast(x))
    Unbroadcast(Unbroadcast(x)) => Unbroadcast(x)

    TODO: Implement equivalent Elemwise lift for SpecifyShape
    """
    op = node.op
    if not isinstance(op, Unbroadcast):
        return False

    inp = node.inputs[0]
    inode = inp.owner
    if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1:
        # Only lift when the `Elemwise` output feeds this node alone;
        # otherwise the `Elemwise` would have to be duplicated.
        if len(fgraph.clients.get(inp, ())) == 1:
            unbroadcasted = unbroadcast(inode.inputs[0], *op.axes)
            copy_stack_trace(node.outputs, unbroadcasted)

            rval = inode.op.make_node(unbroadcasted).outputs

            # Copy over stacktrace from previous output (after unbroadcasting)
            # and input (after elemwise operation) to new output, because an
            # error in the new graph could have been caused by either of the
            # two ops.
            copy_stack_trace(node.outputs + node.inputs, rval)
            return rval

    if inode and isinstance(inode.op, Unbroadcast):
        # Merge axis of each unbroadcast
        axis = tuple(set(inode.op.axes).union(set(op.axes)))
        iinput = inode.inputs[0]
        rval = [unbroadcast(iinput, *axis)]
        # Copy over stacktrace from previous output (after second unbroadcasting)
        # and from previous input (after first unbroadcasting) because an error in
        # the new graph could have been caused by either of the two Unbroadcast ops.
        copy_stack_trace(node.outputs + node.inputs, rval)
        return rval
aesara/tensor/utils.py
浏览文件 @
63f52536
...
...
@@ -63,7 +63,9 @@ def shape_of_variables(fgraph, input_shapes):
"""
if
not
hasattr
(
fgraph
,
"shape_feature"
):
fgraph
.
attach_feature
(
aesara
.
tensor
.
rewriting
.
basic
.
ShapeFeature
())
from
aesara.tensor.rewriting.shape
import
ShapeFeature
fgraph
.
attach_feature
(
ShapeFeature
())
input_dims
=
[
dimension
...
...
tests/compile/test_builders.py
浏览文件 @
63f52536
...
...
@@ -21,7 +21,7 @@ from aesara.tensor.math import round as at_round
from
aesara.tensor.math
import
sigmoid
from
aesara.tensor.math
import
sum
as
at_sum
from
aesara.tensor.random.utils
import
RandomStream
from
aesara.tensor.rewriting.
basic
import
ShapeOptimizer
from
aesara.tensor.rewriting.
shape
import
ShapeOptimizer
from
aesara.tensor.shape
import
specify_shape
from
aesara.tensor.type
import
TensorType
,
matrices
,
matrix
,
scalar
,
vector
,
vectors
from
tests
import
unittest_tools
...
...
tests/tensor/random/test_basic.py
浏览文件 @
63f52536
...
...
@@ -55,7 +55,7 @@ from aesara.tensor.random.basic import (
wald
,
weibull
,
)
from
aesara.tensor.rewriting.
basic
import
ShapeFeature
from
aesara.tensor.rewriting.
shape
import
ShapeFeature
from
aesara.tensor.type
import
iscalar
,
scalar
,
tensor
from
tests.unittest_tools
import
create_aesara_param
...
...
tests/tensor/rewriting/test_basic.py
浏览文件 @
63f52536
import
contextlib
import
copy
import
numpy
as
np
...
...
@@ -10,20 +9,15 @@ import aesara.tensor as at
from
aesara
import
shared
from
aesara.compile
import
optdb
from
aesara.compile.function
import
function
from
aesara.compile.mode
import
OPT_NONE
,
Mode
,
get_default_mode
,
get_mode
from
aesara.compile.mode
import
get_default_mode
,
get_mode
from
aesara.compile.ops
import
DeepCopyOp
,
deep_copy_op
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Apply
,
Constant
,
Variable
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.op
import
Op
from
aesara.graph.rewriting.basic
import
check_stack_trace
,
node_rewriter
,
out2in
from
aesara.graph.rewriting.basic
import
check_stack_trace
,
out2in
from
aesara.graph.rewriting.db
import
RewriteDatabaseQuery
from
aesara.graph.rewriting.utils
import
rewrite_graph
from
aesara.graph.type
import
Type
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.printing
import
pprint
from
aesara.raise_op
import
Assert
,
CheckAndRaise
from
aesara.scalar.basic
import
Composite
from
aesara.tensor.basic
import
(
Alloc
,
Join
,
...
...
@@ -31,21 +25,15 @@ from aesara.tensor.basic import (
ScalarFromTensor
,
Split
,
TensorFromScalar
,
alloc
,
as_tensor_variable
,
join
,
second
,
tile
,
)
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
from
aesara.tensor.extra_ops
import
BroadcastTo
,
Repeat
,
Unique
,
repeat
,
unique
from
aesara.tensor.math
import
(
add
,
bitwise_and
,
bitwise_or
,
bitwise_xor
,
cos
,
cosh
,
dot
,
eq
,
exp
,
...
...
@@ -53,46 +41,32 @@ from aesara.tensor.math import (
ge
,
gt
,
int_div
,
invert
,
iround
,
le
,
log
,
log2
,
log10
,
lt
,
maximum
,
minimum
,
mul
,
neg
,
neq
,
)
from
aesara.tensor.math
import
pow
as
at_pow
from
aesara.tensor.math
import
reciprocal
from
aesara.tensor.math
import
round
as
at_round
from
aesara.tensor.math
import
sin
,
sinh
,
softplus
,
sqr
,
sqrt
,
sub
from
aesara.tensor.math
import
softplus
,
sqrt
,
sub
from
aesara.tensor.math
import
sum
as
at_sum
from
aesara.tensor.math
import
t
an
,
tanh
,
true_div
,
xor
from
aesara.tensor.math
import
t
rue_div
from
aesara.tensor.rewriting.basic
import
(
ShapeFeature
,
assert_op
,
local_alloc_sink_dimshuffle
,
local_dimshuffle_lift
,
local_merge_alloc
,
local_reshape_to_dimshuffle
,
local_useless_alloc
,
local_useless_dimshuffle_in_reshape
,
local_useless_elemwise
,
local_useless_reshape
,
register_specialize
,
)
from
aesara.tensor.rewriting.math
import
local_lift_transpose_through_dot
from
aesara.tensor.rewriting.shape
import
ShapeFeature
from
aesara.tensor.shape
import
(
Reshape
,
Shape_i
,
SpecifyShape
,
Unbroadcast
,
reshape
,
shape
,
specify_shape
,
unbroadcast
,
)
...
...
@@ -102,17 +76,14 @@ from aesara.tensor.subtensor import (
advanced_inc_subtensor
,
advanced_inc_subtensor1
,
inc_subtensor
,
set_subtensor
,
)
from
aesara.tensor.type
import
(
TensorType
,
dmatrices
,
dmatrix
,
dscalar
,
dvector
,
fmatrix
,
fscalar
,
fvector
,
imatrices
,
iscalar
,
iscalars
,
...
...
@@ -129,7 +100,6 @@ from aesara.tensor.type import (
tensor4
,
values_eq_approx_remove_nan
,
vector
,
vectors
,
)
from
tests
import
unittest_tools
as
utt
...
...
@@ -139,8 +109,6 @@ if rewrite_mode == "FAST_COMPILE":
rewrite_mode
=
"FAST_RUN"
rewrite_mode
=
get_mode
(
rewrite_mode
)
dimshuffle_lift
=
out2in
(
local_dimshuffle_lift
)
_stabilize_rewrites
=
RewriteDatabaseQuery
(
include
=
[
"fast_run"
])
_stabilize_rewrites
.
position_cutoff
=
1.51
_stabilize_rewrites
=
optdb
.
query
(
_stabilize_rewrites
)
...
...
@@ -153,10 +121,6 @@ _fast_run_rewrites = RewriteDatabaseQuery(include=["fast_run"])
_fast_run_rewrites
=
optdb
.
query
(
_fast_run_rewrites
)
def ds(x, y):
    """Shorthand: apply a `DimShuffle` with order `y` to variable `x`."""
    shuffle = DimShuffle(x.type.broadcastable, y)
    return shuffle(x)
def
rewrite
(
g
,
level
=
"fast_run"
):
if
level
==
"fast_run"
:
_fast_run_rewrites
.
rewrite
(
g
)
...
...
@@ -169,1124 +133,6 @@ def rewrite(g, level="fast_run"):
return
g
def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)):
    """Create three named float64 test variables with the given shape specs."""
    x, y, z = (
        TensorType(shape=bc, dtype="float64")(name)
        for bc, name in ((xbc, "x"), (ybc, "y"), (zbc, "z"))
    )
    return x, y, z
class TestDimshuffleLift:
    """Tests for the `local_dimshuffle_lift` rewrite (merging, eliminating,
    and lifting `DimShuffle` ops through `Elemwise` nodes)."""

    def test_double_transpose(self):
        # Two successive transposes cancel out entirely.
        x, y, z = inputs()
        e = ds(ds(x, (1, 0)), (1, 0))
        g = FunctionGraph([x], [e])
        assert (
            str(g) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))"
        )
        dimshuffle_lift.rewrite(g)
        assert str(g) == "FunctionGraph(x)"
        # no need to check_stack_trace as graph is supposed to be empty

    def test_merge2(self):
        # Two stacked DimShuffles fuse into a single equivalent one.
        x, y, z = inputs()
        e = ds(ds(x, (1, "x", 0)), (2, 0, "x", 1))
        g = FunctionGraph([x], [e])
        assert (
            str(g)
            == "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))"
        ), str(g)
        dimshuffle_lift.rewrite(g)
        assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x))", str(g)
        # Check stacktrace was copied over correctly after rewrite was applied
        assert check_stack_trace(g, ops_to_check="all")

    def test_elim3(self):
        # A chain of three DimShuffles that composes to the identity is removed.
        x, y, z = inputs()
        e = ds(ds(ds(x, (0, "x", 1)), (2, 0, "x", 1)), (1, 0))
        g = FunctionGraph([x], [e])
        assert str(g) == (
            "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
            "(InplaceDimShuffle{0,x,1}(x))))"
        ), str(g)
        dimshuffle_lift.rewrite(g)
        assert str(g) == "FunctionGraph(x)", str(g)
        # no need to check_stack_trace as graph is supposed to be empty

    def test_lift(self):
        # DimShuffles implied by broadcasting are lifted above the adds.
        x, y, z = inputs([False] * 1, [False] * 2, [False] * 3)
        e = x + y + z
        g = FunctionGraph([x, y, z], [e])
        # It does not really matter if the DimShuffles are inplace
        # or not.
        init_str_g_inplace = (
            "FunctionGraph(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0,1}"
            "(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0}(x), y)), z))"
        )
        init_str_g_noinplace = (
            "FunctionGraph(Elemwise{add,no_inplace}(DimShuffle{x,0,1}"
            "(Elemwise{add,no_inplace}(DimShuffle{x,0}(x), y)), z))"
        )
        assert str(g) in (init_str_g_inplace, init_str_g_noinplace), str(g)

        rewrite_str_g_inplace = (
            "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
            "(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z))"
        )
        rewrite_str_g_noinplace = (
            "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
            "(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))"
        )
        dimshuffle_lift.rewrite(g)
        assert str(g) in (rewrite_str_g_inplace, rewrite_str_g_noinplace), str(g)
        # Check stacktrace was copied over correctly after rewrite was applied
        assert check_stack_trace(g, ops_to_check="all")

    def test_recursive_lift(self):
        # A single transform call lifts the outer transpose through the
        # whole expression recursively.
        v = vector(dtype="float64")
        m = matrix(dtype="float64")
        out = ((v + 42) * (m + 84)).T
        g = FunctionGraph([v, m], [out])
        init_str_g = (
            "FunctionGraph(InplaceDimShuffle{1,0}(Elemwise{mul,no_inplace}"
            "(InplaceDimShuffle{x,0}(Elemwise{add,no_inplace}"
            "(<TensorType(float64, (None,))>, "
            "InplaceDimShuffle{x}(TensorConstant{42}))), "
            "Elemwise{add,no_inplace}"
            "(<TensorType(float64, (None, None))>, "
            "InplaceDimShuffle{x,x}(TensorConstant{84})))))"
        )
        assert str(g) == init_str_g
        new_out = local_dimshuffle_lift.transform(g, g.outputs[0].owner)[0]
        new_g = FunctionGraph(g.inputs, [new_out])
        rewrite_str_g = (
            "FunctionGraph(Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
            "(InplaceDimShuffle{0,x}(<TensorType(float64, (None,))>), "
            "InplaceDimShuffle{x,x}(TensorConstant{42})), "
            "Elemwise{add,no_inplace}(InplaceDimShuffle{1,0}"
            "(<TensorType(float64, (None, None))>), "
            "InplaceDimShuffle{x,x}(TensorConstant{84}))))"
        )
        assert str(new_g) == rewrite_str_g
        # Check stacktrace was copied over correctly after rewrite was applied
        assert check_stack_trace(new_g, ops_to_check="all")

    def test_useless_dimshuffle(self):
        # An identity-order DimShuffle is dropped.
        x, _, _ = inputs()
        e = ds(x, (0, 1))
        g = FunctionGraph([x], [e])
        assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1}(x))"
        dimshuffle_lift.rewrite(g)
        assert str(g) == "FunctionGraph(x)"
        # Check stacktrace was copied over correctly after rewrite was applied
        assert hasattr(g.outputs[0].tag, "trace")

    def test_dimshuffle_on_broadcastable(self):
        # Only DimShuffles that touch non-broadcastable dims (or add dims)
        # are kept.
        x, y, z = inputs([False, True], [True, False, True], [False, False, True])
        u = at.constant(1)
        ds_x = ds(x, (0, "x"))  # useless
        ds_y = ds(y, (2, 1, 0))  # useless
        ds_z = ds(z, (2, 1, 0))  # useful
        ds_u = ds(u, ("x"))  # useful
        g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u])
        assert (
            str(g)
            == "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
        )
        dimshuffle_lift.rewrite(g)
        assert (
            str(g)
            == "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
        )
        # Check stacktrace was copied over correctly after rewrite was applied
        assert hasattr(g.outputs[0].tag, "trace")
def test_local_useless_dimshuffle_in_reshape():
    """`Reshape(DimShuffle(x), shp)` drops the `DimShuffle` when it only
    adds/removes broadcastable dims, but not when it reorders real dims."""
    vec = TensorType(shape=(False,), dtype="float64")("vector")
    mat = TensorType(shape=(False, False), dtype="float64")("mat")
    row = TensorType(shape=(True, False), dtype="float64")("row")
    col = TensorType(shape=(False, True), dtype="float64")("col")

    reshape_dimshuffle_vector = reshape(vec.dimshuffle("x", 0), vec.shape)
    reshape_dimshuffle_mat = reshape(mat.dimshuffle("x", 0, "x", 1), mat.shape)
    reshape_dimshuffle_row = reshape(row.dimshuffle(1, "x"), row.shape)
    reshape_dimshuffle_col = reshape(col.dimshuffle(0), col.shape)

    g = FunctionGraph(
        [vec, mat, row, col],
        [
            reshape_dimshuffle_vector,
            reshape_dimshuffle_mat,
            reshape_dimshuffle_row,
            reshape_dimshuffle_col,
        ],
    )

    assert str(g) == (
        "FunctionGraph(Reshape{1}(InplaceDimShuffle{x,0}(vector), Shape(vector)), "
        "Reshape{2}(InplaceDimShuffle{x,0,x,1}(mat), Shape(mat)), "
        "Reshape{2}(InplaceDimShuffle{1,x}(row), Shape(row)), "
        "Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))"
    )
    useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape)
    useless_dimshuffle_in_reshape.rewrite(g)
    assert str(g) == (
        "FunctionGraph(Reshape{1}(vector, Shape(vector)), "
        "Reshape{2}(mat, Shape(mat)), "
        "Reshape{2}(row, Shape(row)), "
        "Reshape{2}(col, Shape(col)))"
    )

    # Check stacktrace was copied over correctly after rewrite was applied
    assert check_stack_trace(g, ops_to_check="all")

    # Check that the rewrite does not get applied when the order
    # of dimensions has changed.
    reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)
    h = FunctionGraph([mat], [reshape_dimshuffle_mat2])
    str_h = str(h)
    useless_dimshuffle_in_reshape.rewrite(h)
    assert str(h) == str_h
class
TestFusion
:
rewrites
=
RewriteDatabaseQuery
(
include
=
[
"local_elemwise_fusion"
,
"composite_elemwise_fusion"
,
"canonicalize"
,
"inplace"
,
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
],
)
mode
=
Mode
(
get_default_mode
()
.
linker
,
rewrites
)
_shared
=
staticmethod
(
shared
)
topo_exclude
=
()
def
my_init
(
dtype
=
"float64"
,
num
=
0
):
return
np
.
zeros
((
5
,
5
),
dtype
=
dtype
)
+
num
fw
,
fx
,
fy
,
fz
=
[
tensor
(
dtype
=
"float32"
,
shape
=
[
False
]
*
2
,
name
=
n
)
for
n
in
"wxyz"
]
dw
,
dx
,
dy
,
dz
=
[
tensor
(
dtype
=
"float64"
,
shape
=
[
False
]
*
2
,
name
=
n
)
for
n
in
"wxyz"
]
ix
,
iy
,
iz
=
[
tensor
(
dtype
=
"int32"
,
shape
=
[
False
]
*
2
,
name
=
n
)
for
n
in
"xyz"
]
fv
=
fvector
(
"v"
)
fs
=
fscalar
(
"s"
)
fwv
=
my_init
(
"float32"
,
1
)
fxv
=
my_init
(
"float32"
,
2
)
fyv
=
my_init
(
"float32"
,
3
)
fzv
=
my_init
(
"float32"
,
4
)
fvv
=
_asarray
(
np
.
random
.
random
(
5
),
dtype
=
"float32"
)
fsv
=
np
.
asarray
(
np
.
random
.
random
(),
dtype
=
"float32"
)
dwv
=
my_init
(
"float64"
,
5
)
ixv
=
_asarray
(
my_init
(
num
=
60
),
dtype
=
"int32"
)
iyv
=
_asarray
(
my_init
(
num
=
70
),
dtype
=
"int32"
)
izv
=
_asarray
(
my_init
(
num
=
70
),
dtype
=
"int32"
)
fwx
=
fw
+
fx
ftanx
=
tan
(
fx
)
@pytest.mark.parametrize
(
"case"
,
[
(
fx
+
fy
+
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
+
fzv
,
"float32"
,
),
# 0
(
fx
*
fy
*
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
*
fyv
*
fzv
,
"float32"
,
),
# 1
(
fx
+
fy
*
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
*
fzv
,
"float32"
,
),
# 2
(
fx
*
fy
+
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
*
fyv
+
fzv
,
"float32"
,
),
# 3
(
fw
+
fx
+
fy
+
fz
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
,
"float32"
,
),
(
(
fw
+
fx
)
+
(
fy
+
fz
),
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
,
"float32"
,
),
# 5
(
((
fw
+
fx
)
+
fy
)
+
fz
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
,
"float32"
,
),
(
(
fw
+
(
fx
+
fy
))
+
fz
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
,
"float32"
,
),
(
(
fw
+
(
fx
+
fy
)
+
fz
),
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
,
"float32"
,
),
(
fw
+
(
fx
+
(
fy
+
fz
)),
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
,
"float32"
,
),
(
(
fw
+
fx
)
+
(
fy
+
fz
),
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
,
"float32"
,
),
# 10
(
fw
*
fx
*
fy
*
fz
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
*
fxv
*
fyv
*
fzv
,
"float32"
,
),
(
fw
+
fx
*
fy
*
fz
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
*
fyv
*
fzv
,
"float32"
,
),
(
fx
+
fy
*
fz
*
fx
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
*
fzv
*
fxv
,
"float32"
,
),
(
fx
*
fy
+
fz
+
fy
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
*
fyv
+
fzv
+
fyv
,
"float32"
,
),
(
fx
*
fy
*
fz
*
fw
+
fx
+
fy
+
fz
+
fw
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fxv
*
fyv
*
fzv
*
fwv
+
fxv
+
fyv
+
fzv
+
fwv
,
"float32"
,
),
# 15
# test with constant
(
(
fw
+
fx
)
+
(
fy
+
fz
)
+
2.0
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
+
2
,
"float32"
,
),
(
((
fw
+
fx
)
+
2.0
+
fy
)
+
fz
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
+
2
,
"float32"
,
),
(
(
fw
+
(
fx
+
2.0
+
fy
))
+
fz
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
+
2
,
"float32"
,
),
(
(
fw
+
(
fx
+
fy
)
+
2
+
fz
),
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
+
2
,
"float32"
,
),
(
fw
+
(
fx
+
(
fy
+
fz
)
+
2.0
),
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
+
2
,
"float32"
,
),
# 20
(
2
+
(
fw
+
fx
)
+
(
fy
+
fz
),
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
+
2
,
"float32"
,
),
# mix float32 and float64
(
2
+
(
dw
+
fx
)
+
(
fy
+
fz
),
(
dw
,
fx
,
fy
,
fz
),
(
dwv
,
fxv
,
fyv
,
fzv
),
1
,
dwv
+
fxv
+
fyv
+
fzv
+
2
,
"float64"
,
),
(
2
+
(
fw
+
dw
)
+
(
fy
+
fz
),
(
fw
,
dw
,
fy
,
fz
),
(
fwv
,
dwv
,
fyv
,
fzv
),
1
,
fwv
+
dwv
+
fyv
+
fzv
+
2
,
"float64"
,
),
(
2
+
(
fw
+
fx
)
+
(
dw
+
fz
),
(
fw
,
fx
,
dw
,
fz
),
(
fwv
,
fxv
,
dwv
,
fzv
),
1
,
fwv
+
fxv
+
dwv
+
fzv
+
2
,
"float64"
,
),
(
2
+
(
fw
+
fx
)
+
(
fy
+
dw
),
(
fw
,
fx
,
fy
,
dw
),
(
fwv
,
fxv
,
fyv
,
dwv
),
1
,
fwv
+
fxv
+
fyv
+
dwv
+
2
,
"float64"
,
),
# 25
# test when their is other op then elemwise.
(
(
fwx
.
sum
())
+
(
fwx
)
+
(
fy
+
fz
),
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
4
,
(
fwv
+
fxv
)
.
sum
()
+
fwv
+
fxv
+
fyv
+
fzv
,
"float32"
,
),
# test other elemwise op
(
fx
+
fy
+
cos
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
+
np
.
cos
(
fzv
),
"float32"
,
),
(
fx
+
fy
+
cosh
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
+
np
.
cosh
(
fzv
),
"float32"
,
),
(
fx
+
fy
+
abs
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
+
np
.
absolute
(
fzv
),
"float32"
,
),
(
ix
+
iy
+
abs
(
iz
),
(
ix
,
iy
,
iz
),
(
ixv
,
iyv
,
izv
),
1
,
ixv
+
iyv
+
np
.
absolute
(
izv
),
"int32"
,
),
# 30
(
fx
+
fy
+
log
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
+
np
.
log
(
fzv
),
"float32"
,
),
(
fx
+
fy
+
log2
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
+
np
.
log2
(
fzv
),
"float32"
,
),
(
fx
+
fy
+
log10
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
+
np
.
log10
(
fzv
),
"float32"
,
),
(
fx
+
fy
**
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
**
fzv
,
"float32"
,
),
# pow
(
fx
+
fy
+
exp
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
+
np
.
exp
(
fzv
),
"float32"
,
),
# 35
(
fx
-
fy
-
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
-
fzv
,
"float32"
,
),
(
fx
-
(
fy
/
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
(
fyv
/
fzv
),
"float32"
,
),
(
fx
-
true_div
(
fy
,
2
),
(
fx
,
fy
),
(
fxv
,
fyv
),
1
,
fxv
-
(
fyv
/
2
),
"float32"
,
),
(
fx
-
true_div
(
fy
,
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
(
fyv
/
fzv
),
"float32"
,
),
(
fx
-
int_div
(
ix
*
100
,
iy
*
1000
),
(
fx
,
ix
,
iy
),
(
fxv
,
ixv
,
iyv
),
1
,
fxv
-
((
ixv
*
100
)
//
(
iyv
*
1000
)),
{
"custom"
:
"float64"
,
"numpy + floatX"
:
config
.
floatX
,
"numpy"
:
"float64"
,
},
),
# 40
(
fx
-
(
fy
/
2
),
(
fx
,
fy
),
(
fxv
,
fyv
),
1
,
fxv
-
(
fyv
/
2
),
"float32"
),
(
fx
-
(
fy
%
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
(
fyv
%
fzv
),
"float32"
,
),
(
fx
-
(
fy
>
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
(
fyv
>
fzv
),
"float32"
,
),
(
fx
-
(
fy
>=
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
(
fyv
>=
fzv
),
"float32"
,
),
(
fx
-
(
fy
<
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
(
fyv
<
fzv
),
"float32"
,
),
# 45
(
fx
-
(
fy
<=
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
(
fyv
<=
fzv
),
"float32"
,
),
(
fx
-
eq
(
fy
,
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
(
fyv
==
fzv
),
"float32"
,
),
(
fx
-
neq
(
fy
,
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
(
fyv
!=
fzv
),
"float32"
,
),
(
fx
-
fy
+
tan
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
+
np
.
tan
(
fzv
),
"float32"
,
),
(
fx
-
fy
+
tanh
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
+
np
.
tanh
(
fzv
),
"float32"
,
),
# 50
(
fx
-
fy
+
sin
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
+
np
.
sin
(
fzv
),
"float32"
,
),
(
fx
-
fy
+
sinh
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
+
np
.
sinh
(
fzv
),
"float32"
,
),
(
fx
-
fy
+
sqr
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
+
(
fzv
*
fzv
),
"float32"
,
),
(
fx
-
fy
+
sqrt
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
+
np
.
sqrt
(
fzv
),
"float32"
,
),
(
fx
-
fy
+
reciprocal
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
+
(
1
/
fzv
),
"float32"
,
),
# 55
(
fx
-
fy
+
neg
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
+
(
-
fzv
),
"float32"
,
),
(
fx
-
fy
+
at_round
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
+
np
.
round
(
fzv
),
"float32"
,
),
(
ix
-
iy
+
iround
(
fz
),
(
ix
,
iy
,
fz
),
(
ixv
,
iyv
,
fzv
),
1
,
ixv
-
iyv
+
np
.
round
(
fzv
),
"int64"
,
),
# Bit op
(
fx
-
bitwise_or
(
iy
,
iz
),
(
fx
,
iy
,
iz
),
(
fxv
,
iyv
,
izv
),
1
,
fxv
-
(
iyv
|
izv
),
{
"custom"
:
"float64"
,
"numpy + floatX"
:
config
.
floatX
,
"numpy"
:
"float64"
,
},
),
(
fx
-
xor
(
iy
,
iz
),
(
fx
,
iy
,
iz
),
(
fxv
,
iyv
,
izv
),
1
,
fxv
-
(
iyv
^
izv
),
{
"custom"
:
"float64"
,
"numpy + floatX"
:
config
.
floatX
,
"numpy"
:
"float64"
,
},
),
# 60
(
fx
-
bitwise_and
(
iy
,
iz
),
(
fx
,
iy
,
iz
),
(
fxv
,
iyv
,
izv
),
1
,
fxv
-
(
iyv
&
izv
),
{
"custom"
:
"float64"
,
"numpy + floatX"
:
config
.
floatX
,
"numpy"
:
"float64"
,
},
),
(
fx
-
invert
(
iy
),
(
fx
,
iy
),
(
fxv
,
iyv
),
1
,
fxv
-
(
~
iyv
),
{
"custom"
:
"float64"
,
"numpy + floatX"
:
config
.
floatX
,
"numpy"
:
"float64"
,
},
),
(
fx
-
at
.
cast
(
fy
,
dtype
=
"float64"
),
(
fx
,
fy
),
(
fxv
,
fyv
),
1
,
fxv
-
np
.
asarray
(
fyv
,
"float64"
),
"float64"
,
),
(
at_pow
(
fx
*
fy
+
fz
,
fx
*
fy
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
np
.
power
(
fxv
*
fyv
+
fzv
,
fxv
*
fyv
),
"float32"
,
),
(
fv
+
fy
**
fz
,
(
fv
,
fy
,
fz
),
(
fvv
,
fyv
,
fzv
),
2
,
fvv
+
fyv
**
fzv
,
"float32"
,
),
# fused with a dimshuffle #65
(
fv
-
fy
+
tanh
(
fz
),
(
fv
,
fy
,
fz
),
(
fvv
,
fyv
,
fzv
),
2
,
fvv
-
fyv
+
np
.
tanh
(
fzv
),
"float32"
,
),
# fused with a dimshuffle
# Cases where the same input is reused many times.
(
mul
(
fx
,
fx
,
fx
,
fx
),
(
fx
,),
(
fxv
,),
1
,
fxv
*
fxv
*
fxv
*
fxv
,
"float32"
,
),
(
mul
(
fx
,
ftanx
,
ftanx
),
(
fx
,),
(
fxv
,),
1
,
fxv
*
np
.
tan
(
fxv
)
*
np
.
tan
(
fxv
),
"float32"
,
),
(
mul
(
fx
,
ftanx
,
ftanx
,
fx
),
(
fx
,),
(
fxv
,),
1
,
fxv
*
np
.
tan
(
fxv
)
*
np
.
tan
(
fxv
)
*
fxv
,
"float32"
,
),
(
mul
(
ftanx
,
ftanx
,
fx
+
fy
),
(
fx
,
fy
),
(
fxv
,
fyv
),
1
,
np
.
tan
(
fxv
)
*
np
.
tan
(
fxv
)
*
(
fxv
+
fyv
),
"float32"
,
),
# 70
# Cases with different broadcast pattern. They should not
# be merged as this would duplicate computation
# The graph should have 2 elemwise and 1 dimshuffle
(
fx
*
sin
(
fs
),
(
fx
,
fs
),
(
fxv
,
fsv
),
3
,
fxv
*
np
.
sin
(
fsv
),
"float32"
,
),
],
)
def
test_elemwise_fusion
(
self
,
case
,
nb_repeat
=
1
,
assert_len_topo
=
True
):
"""Verify that `Elemwise` fusion works."""
g
,
sym_inputs
,
val_inputs
,
nb_elemwise
,
answer
,
out_dtype
=
case
if
isinstance
(
out_dtype
,
dict
):
out_dtype
=
out_dtype
[
config
.
cast_policy
]
if
self
.
_shared
is
None
:
f
=
function
(
list
(
sym_inputs
),
g
,
mode
=
self
.
mode
)
for
x
in
range
(
nb_repeat
):
out
=
f
(
*
val_inputs
)
else
:
out
=
self
.
_shared
(
np
.
zeros
((
5
,
5
),
dtype
=
out_dtype
),
"out"
)
assert
out
.
dtype
==
g
.
dtype
f
=
function
(
sym_inputs
,
[],
updates
=
[(
out
,
g
)],
mode
=
self
.
mode
)
for
x
in
range
(
nb_repeat
):
f
(
*
val_inputs
)
out
=
out
.
get_value
()
atol
=
1e-8
if
out_dtype
==
"float32"
:
atol
=
1e-6
assert
np
.
allclose
(
out
,
answer
*
nb_repeat
,
atol
=
atol
)
topo
=
f
.
maker
.
fgraph
.
toposort
()
topo_
=
[
n
for
n
in
topo
if
not
isinstance
(
n
.
op
,
self
.
topo_exclude
)]
if
assert_len_topo
:
assert
len
(
topo_
)
==
nb_elemwise
if
nb_elemwise
==
1
:
# if no variable appears multiple times in the
# input of g,
# check that the number of input to the Composite
# Elemwise is ok
if
len
(
set
(
g
.
owner
.
inputs
))
==
len
(
g
.
owner
.
inputs
):
expected_len_sym_inputs
=
sum
(
not
isinstance
(
x
,
Constant
)
for
x
in
topo_
[
0
]
.
inputs
)
assert
expected_len_sym_inputs
==
len
(
sym_inputs
)
assert
out_dtype
==
out
.
dtype
def
test_fusion_35_inputs
(
self
):
r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit."""
inpts
=
vectors
([
"i
%
i"
%
i
for
i
in
range
(
35
)])
# Make an elemwise graph looking like:
# sin(i34 + sin(i33 + sin(... i1 + sin(i0) ...)))
out
=
sin
(
inpts
[
0
])
for
idx
in
range
(
1
,
35
):
out
=
sin
(
inpts
[
idx
]
+
out
)
with
config
.
change_flags
(
cxx
=
""
):
f
=
function
(
inpts
,
out
,
mode
=
self
.
mode
)
# Make sure they all weren't fused
composite_nodes
=
[
node
for
node
in
f
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
getattr
(
node
.
op
,
"scalar_op"
,
None
),
aes
.
basic
.
Composite
)
]
assert
not
any
(
len
(
node
.
inputs
)
>
31
for
node
in
composite_nodes
)
    @pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
    def test_big_fusion(self):
        """Fuse a very large graph and confirm C code can still be generated."""
        # In the past, pickle of Composite generated in that case
        # crashed with max recursion limit. So we were not able to
        # generate C code in that case.
        factors = []
        sd = dscalar()
        means = dvector()

        cst_05 = at.constant(0.5)
        cst_m05 = at.constant(-0.5)
        cst_2 = at.constant(2)
        cst_m2 = at.constant(-2)
        ones = at.constant(np.ones(10))
        # Use a smaller graph under DebugMode, which is much slower
        n = 85
        if config.mode in ["DebugMode", "DEBUG_MODE"]:
            n = 10

        # Build `n` Gaussian log-density terms sharing `sd` and `means`
        for i in range(n):
            f = cst_m05 * sd**cst_m2 * (ones - means[i]) ** cst_2 + cst_05 * log(
                cst_05 * (sd**cst_m2) / np.pi
            )
            factors.append(at_sum(f))

        logp = add(*factors)

        vars = [sd, means]

        # Make sure that C compilation is used
        mode = Mode("cvm", self.rewrites)
        dlogp = function(vars, [aesara.grad(logp, v) for v in vars], mode=mode)

        # Make sure something was fused
        assert any(
            isinstance(getattr(node.op, "scalar_op", None), aes.basic.Composite)
            for node in dlogp.maker.fgraph.toposort()
        )
    def test_add_mul_fusion_inplace(self):
        """Check that fused add/mul `Elemwise` nodes can still be made in-place."""
        rewrites = RewriteDatabaseQuery(
            include=[
                "local_elemwise_fusion",
                "composite_elemwise_fusion",
                "canonicalize",
                "inplace",
            ],
            exclude=["cxx_only", "BlasOpt"],
        )

        mode = Mode(self.mode.linker, rewrites)

        x, y, z = dmatrices("xyz")
        out = dot(x, y) + x + y + z
        f = function([x, y, z], out, mode=mode)
        topo = [n for n in f.maker.fgraph.toposort()]
        # Only the `dot` and one fused elemwise should remain
        assert len(topo) == 2
        # The fused node should have been made in-place
        assert topo[-1].op.inplace_pattern

        new_out = f.maker.fgraph.outputs[0]
        assert isinstance(new_out.owner.op, Elemwise)
        assert isinstance(new_out.owner.op.scalar_op, aes.basic.Add)
        assert len(new_out.owner.inputs) == 4

        # TODO: Do we really need to do this?
        _ = f(
            np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5))
        )
    @pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
    def test_no_c_code(self):
        r"""Make sure we avoid fusions for `Op`\s without C code implementations."""

        # This custom `Op` has no `c_code` method
        class NoCCodeOp(aes.basic.UnaryScalarOp):
            def impl(self, x):
                return x * 2

        no_c_code_op = Elemwise(NoCCodeOp(aes.basic.upgrade_to_float))

        mode = Mode(linker="cvm")
        mode._optimizer = mode._optimizer.including(
            "local_elemwise_fusion",
            "composite_elemwise_fusion",
            "canonicalize",
            "inplace",
        )

        x = vector()
        out = x * no_c_code_op(x + 1)
        f = function([x], out, mode=mode)

        # Nothing should have been fused into a `Composite`
        assert not any(
            isinstance(getattr(n.op, "scalar_op"), aes.basic.Composite)
            for n in f.maker.fgraph.toposort()
        )
    @pytest.mark.parametrize("test_value", [np.c_[[1.0]], np.c_[[]]])
    def test_test_values(self, test_value):
        """Make sure that `local_elemwise_fusion_op` uses test values correctly when they have zero dimensions.

        The test values we're talking about are the ones used when C implementations
        are checked.

        """
        rewrites = RewriteDatabaseQuery(
            include=[
                "local_elemwise_fusion",
                "composite_elemwise_fusion",
                "canonicalize",
            ],
            exclude=["cxx_only", "BlasOpt"],
        )

        mode = Mode(self.mode.linker, rewrites)

        x, y, z = dmatrices("xyz")

        x.tag.test_value = test_value
        y.tag.test_value = test_value
        z.tag.test_value = test_value

        # An empty test value should make compilation raise
        if test_value.size == 0:
            cm = pytest.raises(ValueError)
        else:
            cm = contextlib.suppress()

        with config.change_flags(
            compute_test_value="raise", compute_test_value_opt="raise"
        ):
            out = x * y + z
            with cm:
                f = function([x, y, z], out, mode=mode)

        if test_value.size != 0:
            # Confirm that the fusion happened
            assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite)
            assert len(f.maker.fgraph.toposort()) == 1

            x_c, y_c, z_c = f.maker.fgraph.outputs[0].owner.inputs
            # The fused output's test value must equal 1.0 * 1.0 + 1.0
            assert np.array_equal(
                f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]]
            )
class TimesN(aes.basic.UnaryScalarOp):
    """A scalar `Op` that multiplies its input by a fixed integer `n`.

    Used in test TestCompositeCodegen

    Must be outside of the class, otherwise, the c cache code can't
    pickle this class and this cause stuff printing during test.
    """

    def __eq__(self, other):
        # Two `TimesN` are equal only if their multipliers match too
        return super().__eq__(other) and self.n == other.n

    def __hash__(self):
        return super().__hash__() ^ hash(self.n)

    def __init__(self, n, *args, **kwargs):
        self.n = n
        aes.basic.UnaryScalarOp.__init__(self, *args, **kwargs)

    def impl(self, x):
        return x * self.n

    def c_support_code_apply(self, node, nodename):
        # Emit a per-apply helper function so multiple instances don't clash
        n = str(self.n)
        return (
            """
        float %(nodename)s_timesn(float x) { return x * %(n)s; }
        """
            % locals()
        )

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        return f"{z} = {name}_timesn({x});"
class TestCompositeCodegen:
    """
    Test The Composite Ops code generation in a case where there is multiple
    scalar ops with support code.
    """

    def setup_method(self):
        upgrade_to_float = aes.basic.upgrade_to_float

        self.scal_times_2 = TimesN(2, upgrade_to_float, name="times_2")
        self.times_2 = Elemwise(self.scal_times_2, name="times_2")

        self.scal_times_3 = TimesN(3, upgrade_to_float, name="times_3")
        self.times_3 = Elemwise(self.scal_times_3, name="times_3")

        self.x = fvector()

    def test_nested_composite(self):
        """Two nested ops with C support code should fuse into one node."""
        y = self.times_2(self.x)
        z = self.times_3(y)
        f = function([self.x], z)
        if config.mode != "FAST_COMPILE":
            assert len(f.maker.fgraph.toposort()) == 1
        fval = f([1, 2, 3])
        assert np.all(fval == [6, 12, 18])

    def test_local_useless_composite(self):
        """Unused `Composite` outputs should be pruned by the rewrite."""
        x = aes.float32()
        c = aes.Composite([x], [x + 1, x - 1])
        X = matrix()
        o = Elemwise(scalar_op=c)(X)
        mode = get_default_mode().including("local_useless_composite")

        # Keep only the first output; the second should be dropped
        f = function([X], o[0], mode=mode)
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 1
        assert len(topo[0].outputs) == 1
        utt.assert_allclose(f([[1.0]]), [[2.0]])

        # Keep only the second output; the first should be dropped
        f = function([X], o[1], mode=mode)
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 1
        assert len(topo[0].outputs) == 1
        utt.assert_allclose(f([[1.0]]), [[0.0]])
def
test_local_useless_slice
():
# test a simple matrix
x
=
matrix
(
"x"
)
...
...
@@ -1616,191 +462,6 @@ class TestLocalUselessIncSubtensorAlloc:
assert
check_stack_trace
(
f2
,
ops_to_check
=
"last"
)
class TestShapeRewriter:
    """Tests for the `ShapeFeature`-based shape rewrites."""

    def test_basic(self):
        # Computing only the shape of `v + m` should not compute the `add`
        mode = config.mode
        if mode == "FAST_COMPILE":
            mode = "FAST_RUN"
        v = vector()
        m = matrix()
        f = function([v, m], (v + m).shape, mode=mode)
        for node in f.maker.fgraph.toposort():
            assert node.op != add

    def test_constant(self):
        # A broadcast dimension's shape is a constant; only a deep-copy remains
        mode = config.mode
        if mode == "FAST_COMPILE":
            mode = "FAST_RUN"

        v = vector()

        f = function([v], v.dimshuffle("x", "x", 0).shape[1], mode=mode)
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 1
        assert topo[0].op == deep_copy_op

    @staticmethod
    def max_pool_c01b(c01b, pool_shp, pool_stride, img_shp):
        """
        Like max_pool but with input using axes ('c', 0, 1, 'b')
          (Alex Krizhevsky format)

        pool_shp, pool_stride and img_shp are int that represent
        the same shp in x and y.
        """
        mx = None

        # Compute index in pooled space of last needed pool
        # (needed = each input pixel must appear in at least one pool)
        def last_pool(im_shp, p_shp, p_strd):
            rval = int(np.ceil(float(im_shp - p_shp) / p_strd))
            assert p_strd * rval + p_shp >= im_shp
            assert p_strd * (rval - 1) + p_shp < im_shp
            return rval

        # Compute starting row of the last pool
        last_pool_r = last_pool(img_shp, pool_shp, pool_stride) * pool_stride
        # Compute number of rows needed in img for all indexes to work out
        required_r = last_pool_r + pool_shp

        last_pool_c = last_pool(img_shp, pool_shp, pool_stride) * pool_stride
        required_c = last_pool_c + pool_shp

        # Pad with -inf so the padding never wins the max
        wide_infinity = at.alloc(
            -np.inf, c01b.shape[0], required_r, required_c, c01b.shape[3]
        )

        c01b = set_subtensor(wide_infinity[:, 0:img_shp, 0:img_shp, :], c01b)

        for row_within_pool in range(pool_shp):
            row_stop = last_pool_r + row_within_pool + 1
            for col_within_pool in range(pool_shp):
                col_stop = last_pool_c + col_within_pool + 1
                cur = c01b[
                    :,
                    row_within_pool:row_stop:pool_stride,
                    col_within_pool:col_stop:pool_stride,
                    :,
                ]
                if mx is None:
                    mx = cur
                else:
                    mx = maximum(mx, cur)
        return mx

    def test_broadcasted_dims(self):
        # This test a case that caused a crash during rewriting
        shp = (1, 1, 1, 1)
        rng = np.random.default_rng(utt.fetch_seed())
        a = shared(rng.random(shp).astype(config.floatX))
        out = self.max_pool_c01b(a, 1, 1, 1)

        # max_pool_c01b use -inf and this will trigger DebugMode error.
        mode = copy.copy(get_default_mode())
        mode.check_isfinite = False
        f = function([], out, mode=mode)
        f()

    def test_constant_merge(self):
        # This test the error in gh-1122 that is a caused by the
        # combination of merge rewriter and ShapeFeature.
        x = at.constant([0, 0])
        y = x[1:]
        x1 = x - at.join(0, y, y)
        x1.eval()

    def test_local_track_shape_i(self):
        """Shape inference should track through ops lacking `infer_shape` once a rewrite substitutes an op that has it."""

        class IdentityNoShape(Op):
            """Op that does not infer the output shape from the input one"""

            def make_node(self, x):
                x = as_tensor_variable(x)
                return Apply(self, [x], [x.type()])

            def perform(self, node, inp, out_):
                (x,) = inp
                (out,) = out_
                out[0] = x.copy()

            # def infer_shape(self, fgraph, node, (xshp,)):
            # return [tuple([self.shape_i(i)(r) for i in range(r.ndim)])]

        identity_noshape = IdentityNoShape()

        class IdentityShape(Op):
            """Op that does infer the output shape from the input one"""

            def make_node(self, x):
                x = as_tensor_variable(x)
                return Apply(self, [x], [x.type()])

            def perform(self, node, inp, out_):
                (x,) = inp
                (out,) = out_
                out[0] = x.copy()

            def infer_shape(self, fgraph, node, xshp_):
                # Could also just return.
                (xshp,) = xshp_
                return (xshp,)

        identity_shape = IdentityShape()

        @node_rewriter([IdentityNoShape])
        def local_identity_noshape_to_identity_shape(fgraph, node):
            """Transform the first `Op` into the second."""
            if isinstance(node.op, IdentityNoShape):
                return [identity_shape(node.inputs[0])]

        mode = get_default_mode().including("ShapeOpt", "specialize")
        rng = np.random.default_rng(utt.fetch_seed())
        x = tensor3("x")
        ins_x = identity_noshape(x)

        # Without the rewrite
        f = function([x], ins_x.shape, mode=mode)
        xval = rng.standard_normal((3, 4, 7)).astype(config.floatX)
        assert np.all(f(xval) == [3, 4, 7])
        f_ops = [node.op for node in f.maker.fgraph.toposort()]
        assert len(f_ops) == 5
        assert identity_noshape in f_ops
        assert identity_shape not in f_ops

        # Register the rewrite
        register_specialize(local_identity_noshape_to_identity_shape)

        mode = get_default_mode().including("ShapeOpt", "specialize")
        # The `identity_shape` `Op` should not be needed anymore to compute
        # the shape
        g = function([x], ins_x.shape, mode=mode)
        xval = rng.standard_normal((6, 1, 2)).astype(config.floatX)
        assert np.all(g(xval) == [6, 1, 2])
        g_ops = [node.op for node in g.maker.fgraph.toposort()]
        assert len(g_ops) == 4
        assert identity_noshape not in g_ops
        assert identity_shape not in g_ops

        # Test multiple applications of an `Op` without an `Op.infer_shape`
        ins_x3 = identity_noshape(identity_noshape(identity_noshape(x)))
        h = function([x], ins_x3.shape, mode=mode)
        xval = rng.standard_normal((6, 1, 2)).astype(config.floatX)
        assert np.all(h(xval) == [6, 1, 2])
        h_ops = [node.op for node in h.maker.fgraph.toposort()]
        assert len(h_ops) == 4
        assert identity_noshape not in h_ops
        assert identity_shape not in h_ops

    def test_no_shapeopt(self):
        """Test that a basic example works even when `ShapeOpt` is excluded."""
        X = matrix()
        expr = X.shape[0]

        mode = get_default_mode().excluding("ShapeOpt")
        f = function([X], expr, mode=mode)
        # FIXME: This is not a good test.
        f([[1, 2], [2, 3]])
class
TestUselessCheckAndRaise
:
def
test_basic
(
self
):
mode
=
get_default_mode
()
.
including
(
...
...
@@ -2739,136 +1400,6 @@ def test_local_flatten_lift(i):
assert
isinstance
(
topo
[
-
1
]
.
op
,
Elemwise
)
class TestReshape:
    """Tests for merging of consecutive `Reshape` ops."""

    def setup_method(self):
        self.mode = rewrite_mode
        self.op = Reshape

    def test_local_reshape(self):
        # Two chained reshapes should collapse into a single `Reshape`
        a = fmatrix()
        b = self.op(3)(a, [2, 3, 4])
        c = self.op(1)(b, [24])
        f = function([a], c, mode=self.mode)
        topo = f.maker.fgraph.toposort()
        assert sum(isinstance(node.op, self.op) for node in topo) == 1

        # Check stack trace
        assert check_stack_trace(f, ops_to_check=[self.op])
class TestLocalUselessReshape:
    """Tests for `local_useless_reshape`, which removes no-op `Reshape`s."""

    def setup_method(self):
        self.rng = np.random.default_rng(utt.fetch_seed())

    def test_0(self):
        # `mgrid` produces reshapes that should be removed
        mode = get_default_mode().including("local_useless_reshape")
        i = iscalar("i")
        m = at.mgrid[0:i,]
        f = function([i], m, mode=mode)
        topo = f.maker.fgraph.toposort()
        assert not any(isinstance(n.op, Reshape) for n in topo)

    def test_1(self):
        # Reshaping to the input's own shape is a no-op
        x = matrix("x")
        r = x.reshape(x.shape)

        m0 = get_default_mode()
        m1 = m0.including("local_useless_reshape")
        f1 = function([x], r, mode=m1)
        topo = f1.maker.fgraph.toposort()
        assert not any(isinstance(n.op, Reshape) for n in topo)

        # The rewrite should also work without the `ShapeOpt` feature
        m2 = m1.excluding("ShapeOpt")
        f2 = function([x], r, mode=m2)
        topo = f2.maker.fgraph.toposort()
        assert not any(isinstance(n.op, Reshape) for n in topo)

        # We do not need tests checking that stack traces are copied over,
        # because local_useless_reshape only removes nodes from the graph

    def test_2(self):
        # Reshaping to `Shape_i` of each of the input's own dims is a no-op
        x = matrix("x")
        r = x.reshape([Shape_i(i)(x) for i in range(x.ndim)])

        m0 = get_default_mode()
        m1 = m0.including("local_useless_reshape")
        f1 = function([x], r, mode=m1)
        topo = f1.maker.fgraph.toposort()
        assert not any(isinstance(n.op, Reshape) for n in topo)

        m2 = m1.excluding("ShapeOpt")
        f2 = function([x], r, mode=m2)
        topo = f2.maker.fgraph.toposort()
        assert not any(isinstance(n.op, Reshape) for n in topo)

    def test_m1(self):
        # Reshaping with a trailing -1 that resolves to the same shape is a no-op
        x = matrix("x")
        r = x.reshape((x.shape[0], -1))

        m0 = get_default_mode()
        m1 = m0.including("local_useless_reshape")
        f1 = function([x], r, mode=m1)
        topo = f1.maker.fgraph.toposort()
        assert not any(isinstance(n.op, Reshape) for n in topo)

        m2 = m1.excluding("ShapeOpt")
        f2 = function([x], r, mode=m2)
        topo = f2.maker.fgraph.toposort()
        assert not any(isinstance(n.op, Reshape) for n in topo)
class TestLocalReshapeToDimshuffle:
    """Tests for `local_reshape_to_dimshuffle`, which turns pure broadcast-adding reshapes into `DimShuffle`s."""

    def setup_method(self):
        self.rng = np.random.default_rng(utt.fetch_seed())

    def test_1(self):
        reshape_lift = out2in(local_reshape_to_dimshuffle)
        useless_reshape = out2in(local_useless_reshape)
        x = shared(self.rng.standard_normal((4,)))
        y = shared(self.rng.standard_normal((5, 6)))
        reshape_x = reshape(x, (1, 4))
        reshape_y = reshape(y, (1, 5, 1, 6, 1, 1))

        g = FunctionGraph([x, y], [reshape_x, reshape_y])

        # Graph before the rewrites
        assert str(g) == (
            "FunctionGraph(Reshape{2}"
            "(<TensorType(float64, (None,))>, "
            "TensorConstant{[1 4]}), "
            "Reshape{6}"
            "(<TensorType(float64, (None, None))>, "
            "TensorConstant{[1 5 1 6 1 1]}))"
        )

        reshape_lift.rewrite(g)
        useless_reshape.rewrite(g)

        # The broadcast dims became `DimShuffle`s; only the reshape that
        # actually changes the layout of `y` remains
        assert str(g) == (
            "FunctionGraph(InplaceDimShuffle{x,0}"
            "(<TensorType(float64, (None,))>), "
            "InplaceDimShuffle{x,0,x,1,x,x}"
            "(Reshape{2}(<TensorType(float64, (None, None))>, "
            "TensorConstant{[5 6]})))"
        )

        # Check stacktrace was copied over correctly after the rewrite was applied
        assert check_stack_trace(g, ops_to_check=(DimShuffle, Reshape))
def test_local_reshape_lift():
    """`local_reshape_lift` should move an elemwise above the `Reshape`."""
    x = tensor4()
    out = exp(x).reshape([x.size])
    assert out.ndim == 1
    mode = get_default_mode()
    mode = mode.including("local_reshape_lift")
    f = function([x], out, mode=mode)
    f(np.random.random((5, 4, 3, 2)).astype(config.floatX))
    topo = f.maker.fgraph.toposort()
    # The reshape now precedes the elemwise
    assert isinstance(topo[-2].op, Reshape)
    assert isinstance(topo[-1].op, Elemwise)
    assert check_stack_trace(f, ops_to_check="last")
class
TestLiftTransposeThroughDot
:
def
simple_rewrite
(
self
,
g
):
out2in
(
local_useless_elemwise
)
.
rewrite
(
g
)
...
...
@@ -2918,160 +1449,6 @@ def test_local_upcast_elemwise_constant_inputs():
function
([
v
],
true_div
(
v
,
2
))
class TestShapeI(utt.InferShapeTester):
    """Tests for the `Shape_i` op: evaluation and shape inference."""

    def setup_method(self):
        super().setup_method()

    def test_perform(self):
        rng = np.random.default_rng(utt.fetch_seed())

        advec = vector()
        advec_val = rng.random((3)).astype(config.floatX)
        f = function([advec], Shape_i(0)(advec))
        out = f(advec_val)
        utt.assert_allclose(out, advec_val.shape[0])

        admat = matrix()
        admat_val = rng.random((4, 3)).astype(config.floatX)
        # Check both axes of the matrix
        for i in range(2):
            f = function([admat], Shape_i(i)(admat))
            out = f(admat_val)
            utt.assert_allclose(out, admat_val.shape[i])

    def test_infer_shape(self):
        admat = matrix()
        admat_val = np.random.random((3, 4)).astype(config.floatX)
        self._compile_and_check([admat], [Shape_i(0)(admat)], [admat_val], Shape_i)

        self._compile_and_check([admat], [Shape_i(1)(admat)], [admat_val], Shape_i)
class TestSameShape:
    """Tests for `ShapeFeature.same_shape`."""

    def test_scalar(self):
        x = scalar()
        cst = at.constant(1)
        o = x + cst
        fgraph = FunctionGraph([x], [o], clone=False)
        shape_feature = ShapeFeature()
        fgraph.attach_feature(shape_feature)
        assert shape_feature.same_shape(x, o)

    def test_vector(self):
        x = vector()
        cst = at.constant(1)
        o = x + cst
        fgraph = FunctionGraph([x], [o], clone=False)
        shape_feature = ShapeFeature()
        fgraph.attach_feature(shape_feature)
        assert shape_feature.same_shape(x, o)

    def test_no_static_shapes(self):
        x = vector()
        y = vector()
        o = x + y
        fgraph = FunctionGraph([x, y], [o], clone=False)
        shape_feature = ShapeFeature()
        fgraph.attach_feature(shape_feature)
        # We no longer assume that `x` has the same shape as `y` simply because
        # neither has static shape information.  Instead, when there is no
        # static shape information is available, we assume that `x` and/or `y`
        # could have shapes `(1,)` and/or `(n,)`, where `n != 1`, or any
        # combination of the two.
        assert not shape_feature.same_shape(x, o)
        # The following case isn't implemented
        assert not shape_feature.same_shape(y, o)

    @pytest.mark.parametrize(
        "y_dim_0",
        [2, pytest.param(None, marks=pytest.mark.xfail(reason="Not implemented"))],
    )
    def test_vector_dim(self, y_dim_0):
        x = at.tensor(dtype="floatX", shape=(2, None))
        y = at.tensor(dtype="floatX", shape=(y_dim_0, None))
        o = x + y
        fgraph = FunctionGraph([x, y], [o], clone=False)
        shape_feature = ShapeFeature()
        fgraph.attach_feature(shape_feature)
        # Dim 0 is statically known to match; dim 1 is not
        assert shape_feature.same_shape(x, o, 0, 0)
        assert not shape_feature.same_shape(x, o, 1, 1)

    def test_vector_dim_err(self):
        # Out-of-range dimension indices must raise
        x = vector()
        y = vector()
        o = x + y
        fgraph = FunctionGraph([x, y], [o], clone=False)
        shape_feature = ShapeFeature()
        fgraph.attach_feature(shape_feature)
        with pytest.raises(IndexError):
            shape_feature.same_shape(x, o, 1, 0)
        with pytest.raises(IndexError):
            shape_feature.same_shape(x, o, 0, 1)
@pytest.mark.parametrize(
    "shape",
    [lscalar(), iscalar()],
)
def test_local_Shape_of_SpecifyShape(shape):
    """``specify_shape(x, s).shape`` should be replaced by ``s`` itself."""
    x = vector()
    s = specify_shape(x, shape).shape
    fgraph = FunctionGraph(outputs=[s], clone=False)
    _ = rewrite_graph(fgraph, clone=False)

    # `x` is no longer needed; the given shape variable is used directly
    assert x not in fgraph.variables
    assert shape in fgraph.variables
@pytest.mark.parametrize(
    "s1",
    [lscalar(), iscalar()],
)
def test_local_Shape_of_SpecifyShape_partial(s1):
    """A partially specified shape still lets `SpecifyShape` be removed from the shape graph."""
    x = matrix()
    s = specify_shape(x, (s1, None)).shape
    fgraph = FunctionGraph(outputs=[s], clone=False)
    assert any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes)

    _ = rewrite_graph(fgraph, clone=False)

    # `x` is still needed (for the unspecified dim), `s1` is used for dim 0,
    # and the `SpecifyShape` itself is gone
    assert x in fgraph.variables
    assert s1 in fgraph.variables
    assert not any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes)
def test_local_Shape_i_of_broadcastable():
    """`Shape_i` on a broadcastable dim should constant-fold to 1; non-`TensorType`s are left alone."""
    x = tensor(np.float64, [False, True])
    s = Shape_i(1)(x)
    fgraph = FunctionGraph(outputs=[s], clone=False)
    _ = rewrite_graph(fgraph, clone=False)

    # The broadcastable dim's length is statically 1
    assert x not in fgraph.variables
    assert fgraph.outputs[0].data == 1

    # A test for a non-`TensorType`
    class MyType(Type):
        ndim = 1

        def filter(self, *args, **kwargs):
            raise NotImplementedError()

        def __eq__(self, other):
            return isinstance(other, MyType) and other.thingy == self.thingy

    class MyVariable(Variable):
        pass

    x = MyVariable(MyType(), None, None)
    s = Shape_i(0)(x)
    fgraph = FunctionGraph(outputs=[s], clone=False)
    _ = rewrite_graph(fgraph, clone=False)

    # No rewrite should apply to the non-`TensorType` input
    assert fgraph.outputs[0] == s
def
test_assert_op_gradient
():
x
=
vector
(
"x"
)
assert_op
=
Assert
()
...
...
@@ -3183,283 +1560,6 @@ def test_local_useless_alloc():
assert
isinstance
(
topo
[
-
1
]
.
op
,
Alloc
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_scalar(return_index, return_counts, return_inverse):
    """`unique` of a scalar should reduce to a simple `DimShuffle` of the input."""
    x = dscalar()
    y = unique(
        x,
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=None,
    )

    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_rewritten_fg = rewrite_graph(
        y_fg, clone=False, include=["canonicalize", "local_Unique_scalar"]
    )
    y_rewritten = y_rewritten_fg.outputs[0]
    y_rewritten_start = y_rewritten

    # The `Unique` is gone; only a `DimShuffle` of `x` remains
    assert isinstance(y_rewritten_start.owner.op, DimShuffle)
    assert y_rewritten_start.owner.inputs[0] == x

    default_mode = get_default_mode()
    rewrite_mode = default_mode.excluding("local_Unique_scalar")
    y_fn = function([x], [y, y_rewritten], mode=rewrite_mode)

    x_val = np.array(-10.0, dtype=np.float64)
    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
    "x_val, axis, new_shape",
    [
        (np.array(-10, dtype=np.int64), None, ()),
        (np.array(-10, dtype=np.int64), None, (2, 3)),
        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_Alloc_lift(
    x_val, axis, new_shape, return_index, return_counts, return_inverse
):
    """``unique(alloc(x, ...))`` should be rewritten to ``unique(x)``."""
    x = as_tensor_variable(x_val).type()
    y = unique(
        alloc(x, *new_shape),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=axis,
    )

    if isinstance(y, list):
        y, *_ = y

    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_rewritten_fg = rewrite_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_Alloc_lift"],
        exclude=["local_Unique_scalar"],
    )
    y_rewritten = y_rewritten_fg.outputs[0]
    y_rewritten_start = y_rewritten

    assert isinstance(y_rewritten_start.owner.op, Unique)
    assert y_rewritten_start.owner.inputs[0] == x
    assert not any(isinstance(node.op, Alloc) for node in y_rewritten_fg.apply_nodes)

    default_mode = get_default_mode()
    # The rewrite has already been applied to `y_rewritten`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the rewritten result, `y_rewritten`.
    # The remaining exclusions simply allow us to perform the check below that
    # makes sure the original `Alloc` is present in our reference (sub)graph.
    rewrite_mode = default_mode.excluding(
        "local_useless_alloc", "local_alloc_sink_dimshuffle", "local_Unique_Alloc_lift"
    )
    y_fn = function([x], [y, y_rewritten], mode=rewrite_mode)

    # Make sure that the original `Alloc` is used to compute the reference `y`
    # result
    assert any(isinstance(node.op, Alloc) for node in y_fn.maker.fgraph.apply_nodes)

    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
    "x_val, axis, new_shape",
    [
        (np.array(-10, dtype=np.int64), None, (2, 3)),
        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_BroadcastTo(
    x_val, axis, new_shape, return_index, return_counts, return_inverse
):
    """``unique(broadcast_to(x, ...))`` should be rewritten to ``unique(x)``."""
    x = as_tensor_variable(x_val).type()
    y = unique(
        BroadcastTo()(x, tuple(new_shape)),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=axis,
    )

    if isinstance(y, list):
        y, *_ = y

    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_rewritten_fg = rewrite_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_BroadcastTo_lift"],
        exclude=["local_Unique_scalar"],
    )
    y_rewritten = y_rewritten_fg.outputs[0]
    y_rewritten_start = y_rewritten

    assert isinstance(y_rewritten_start.owner.op, Unique)
    assert y_rewritten_start.owner.inputs[0] == x
    assert not any(
        isinstance(node.op, BroadcastTo) for node in y_rewritten_fg.apply_nodes
    )

    default_mode = get_default_mode()
    # The rewrite has already been applied to `y_rewritten`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the rewritten result, `y_rewritten`.
    rewrite_mode = default_mode.excluding("local_Unique_BroadcastTo_lift")
    y_fn = function([x], [y, y_rewritten], mode=rewrite_mode)

    # Make sure that the original `BroadcastTo` is used to compute the
    # reference `y` result
    assert any(
        isinstance(node.op, BroadcastTo) for node in y_fn.maker.fgraph.apply_nodes
    )

    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
    "x_val, unique_axis, repeats, repeat_axis",
    [
        (np.array([[-10, -3], [-10, 2]], dtype=np.int64), None, (1, 2), 0),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_Repeat(
    x_val,
    unique_axis,
    repeats,
    repeat_axis,
    return_index,
    return_counts,
    return_inverse,
):
    """``unique(repeat(x, ...))`` should be rewritten to ``unique(x)``."""
    x = as_tensor_variable(x_val).type()
    y = unique(
        repeat(x, tuple(repeats), axis=repeat_axis),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=unique_axis,
    )

    if isinstance(y, list):
        y, *_ = y

    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_rewritten_fg = rewrite_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_Repeat_lift"],
        exclude=["local_Unique_scalar"],
    )
    y_rewritten = y_rewritten_fg.outputs[0]
    y_rewritten_start = y_rewritten

    assert isinstance(y_rewritten_start.owner.op, Unique)
    assert y_rewritten_start.owner.inputs[0] == x
    assert not any(isinstance(node.op, Repeat) for node in y_rewritten_fg.apply_nodes)

    default_mode = get_default_mode()
    # The rewrite has already been applied to `y_rewritten`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the rewritten result, `y_rewritten`.
    rewrite_mode = default_mode.excluding("local_Unique_Repeat_lift")
    y_fn = function([x], [y, y_rewritten], mode=rewrite_mode)

    # Make sure that the original `Repeat` is used to compute the
    # reference `y` result
    assert any(isinstance(node.op, Repeat) for node in y_fn.maker.fgraph.apply_nodes)

    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
    "x_val, unique_axis, new_shape",
    [
        (np.array(-10, dtype=np.int64), None, ()),
        (np.array(-10, dtype=np.int64), None, (2, 3)),
        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_second(
    x_val, unique_axis, new_shape, return_index, return_counts, return_inverse
):
    """``unique(second(a, x))`` should be rewritten to ``unique(x)``."""
    x = as_tensor_variable(x_val).type()
    a = np.zeros(tuple(new_shape), dtype=x.dtype)
    y = unique(
        second(a, x),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=unique_axis,
    )

    if isinstance(y, list):
        y, *_ = y

    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_rewritten_fg = rewrite_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_second_lift"],
        exclude=["local_Unique_scalar", "topo_constant_folding"],
    )
    y_rewritten = y_rewritten_fg.outputs[0]
    y_rewritten_start = y_rewritten

    assert isinstance(y_rewritten_start.owner.op, Unique)

    y_rewritten_start = y_rewritten_start.owner.inputs[0]

    # The rewrite may leave a `DimShuffle` between `Unique` and `x`
    if y_rewritten_start.owner and isinstance(y_rewritten_start.owner.op, DimShuffle):
        y_rewritten_start = y_rewritten_start.owner.inputs[0]

    assert y_rewritten_start == x
    assert not any(
        isinstance(node.op.scalar_op, aes.Second)
        for node in y_rewritten_fg.apply_nodes
        if isinstance(node.op, Elemwise)
    )

    # The rewrite has already been applied to `y_rewritten`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the rewritten result, `y_rewritten`.
    y_fn = function([x], [y, y_rewritten], mode=Mode(optimizer=OPT_NONE))

    # Make sure that the original `Second` is used to compute the
    # reference `y` result
    assert any(
        isinstance(node.op.scalar_op, aes.Second)
        for node in y_fn.maker.fgraph.apply_nodes
        if isinstance(node.op, Elemwise)
    )

    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
def
test_local_merge_consecutive_specify_shape
():
x
=
matrix
()
s
=
at
.
as_tensor
([
iscalar
(),
iscalar
()])
...
...
@@ -3501,64 +1601,6 @@ def test_printing():
assert
pprint
(
v
)
==
"[a, b]"
def test_local_remove_scalar_BroadcastTo():
    """`local_remove_scalar_BroadcastTo` should strip a no-op scalar broadcast."""
    inp = dscalar()
    broadcasted = BroadcastTo()(inp, ())

    # Sanity check: the graph really starts with a `BroadcastTo`
    assert isinstance(broadcasted.owner.op, BroadcastTo)

    rewritten = rewrite_graph(
        broadcasted,
        clone=False,
        include=["canonicalize", "local_remove_scalar_BroadcastTo"],
    )

    # The broadcast was removed entirely, leaving the bare input
    assert rewritten is inp
def test_local_useless_dimshuffle_makevector():
    """A ``dimshuffle(())`` of a length-one `MakeVector` should reduce to its input."""
    scalar_inp = scalar()
    vec = MakeVector(config.floatX)(scalar_inp)
    shuffled = vec.dimshuffle(())

    fg = FunctionGraph(outputs=[shuffled], copy_inputs=False)

    rewritten_fg = rewrite_graph(
        fg,
        clone=False,
        include=["canonicalize", "local_useless_dimshuffle_makevector"],
    )

    # Both the `MakeVector` and the `DimShuffle` were removed
    assert rewritten_fg.outputs[0] == scalar_inp
def test_Shape_i_canonicalize():
    """Make sure the canonicalizations work together to produce the correct graphs for shapes in a single dimension.

    In other words, ``shape(x)[i]`` should result in a simple ``Shape_i(0)(x)``
    and nothing else.  The rewrites `local_shape_to_shape_i`,
    `local_subtensor_remove_broadcastable_index`, and
    `local_useless_dimshuffle_makevector` need to work together to accomplish
    this, and we confirm that here.
    """
    x = vector()
    y = shape(x)[0]

    y_fg = FunctionGraph(outputs=[y], copy_inputs=False, features=[ShapeFeature()])

    y_rewritten_fg = rewrite_graph(
        y_fg,
        clone=False,
        include=[
            "canonicalize",
        ],
    )

    y_rewritten = y_rewritten_fg.outputs[0]

    # A single `Shape_i(0)` applied directly to `x`
    assert isinstance(y_rewritten.owner.op, Shape_i)
    assert y_rewritten.owner.op.i == 0
    assert y_rewritten.owner.inputs[0] == x
class
TestLocalElemwiseAlloc
:
"""
...
...
@@ -3847,3 +1889,6 @@ def test_deprecations():
"""Make sure we can import from deprecated modules."""
with
pytest
.
deprecated_call
():
from
aesara.tensor.basic_opt
import
register_useless
# noqa: F401 F811
with
pytest
.
deprecated_call
():
from
aesara.tensor.rewriting.basic
import
ShapeFeature
# noqa: F401
tests/tensor/rewriting/test_elemwise.py
0 → 100644
浏览文件 @
63f52536
import
contextlib
import
numpy
as
np
import
pytest
import
aesara
import
aesara.scalar
as
aes
import
aesara.tensor
as
at
from
aesara
import
shared
from
aesara.compile.function
import
function
from
aesara.compile.mode
import
Mode
,
get_default_mode
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Constant
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.rewriting.basic
import
check_stack_trace
,
out2in
from
aesara.graph.rewriting.db
import
RewriteDatabaseQuery
from
aesara.graph.rewriting.utils
import
rewrite_graph
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.scalar.basic
import
Composite
from
aesara.tensor.basic
import
MakeVector
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
from
aesara.tensor.math
import
(
add
,
bitwise_and
,
bitwise_or
,
cos
,
cosh
,
dot
,
eq
,
exp
,
int_div
,
invert
,
iround
,
log
,
log2
,
log10
,
mul
,
neg
,
neq
,
)
from
aesara.tensor.math
import
pow
as
at_pow
from
aesara.tensor.math
import
reciprocal
from
aesara.tensor.math
import
round
as
at_round
from
aesara.tensor.math
import
sin
,
sinh
,
sqr
,
sqrt
from
aesara.tensor.math
import
sum
as
at_sum
from
aesara.tensor.math
import
tan
,
tanh
,
true_div
,
xor
from
aesara.tensor.rewriting.elemwise
import
local_dimshuffle_lift
from
aesara.tensor.rewriting.shape
import
local_useless_dimshuffle_in_reshape
from
aesara.tensor.shape
import
reshape
from
aesara.tensor.type
import
(
TensorType
,
dmatrices
,
dscalar
,
dvector
,
fscalar
,
fvector
,
matrix
,
scalar
,
tensor
,
vector
,
vectors
,
)
from
tests
import
unittest_tools
as
utt
dimshuffle_lift
=
out2in
(
local_dimshuffle_lift
)
def ds(x, y):
    # Shorthand: apply a `DimShuffle` with pattern `y` to variable `x`.
    return DimShuffle(x.type.broadcastable, y)(x)
def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)):
    # Build three float64 tensor variables with the given shape/broadcast
    # patterns, named "x", "y" and "z".
    x = TensorType(shape=xbc, dtype="float64")("x")
    y = TensorType(shape=ybc, dtype="float64")("y")
    z = TensorType(shape=zbc, dtype="float64")("z")
    return x, y, z
class TestDimshuffleLift:
    """Tests exercising the `local_dimshuffle_lift` rewrite (via `dimshuffle_lift`)."""

    def test_double_transpose(self):
        x, y, z = inputs()
        e = ds(ds(x, (1, 0)), (1, 0))
        g = FunctionGraph([x], [e])
        assert (
            str(g) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))"
        )
        dimshuffle_lift.rewrite(g)
        assert str(g) == "FunctionGraph(x)"
        # no need to check_stack_trace as graph is supposed to be empty

    def test_merge2(self):
        x, y, z = inputs()
        e = ds(ds(x, (1, "x", 0)), (2, 0, "x", 1))
        g = FunctionGraph([x], [e])
        assert (
            str(g)
            == "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))"
        ), str(g)
        dimshuffle_lift.rewrite(g)
        assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x))", str(g)
        # Check stacktrace was copied over correctly after rewrite was applied
        assert check_stack_trace(g, ops_to_check="all")

    def test_elim3(self):
        x, y, z = inputs()
        e = ds(ds(ds(x, (0, "x", 1)), (2, 0, "x", 1)), (1, 0))
        g = FunctionGraph([x], [e])
        assert str(g) == (
            "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
            "(InplaceDimShuffle{0,x,1}(x))))"
        ), str(g)
        dimshuffle_lift.rewrite(g)
        assert str(g) == "FunctionGraph(x)", str(g)
        # no need to check_stack_trace as graph is supposed to be empty

    def test_lift(self):
        x, y, z = inputs([False] * 1, [False] * 2, [False] * 3)
        e = x + y + z
        g = FunctionGraph([x, y, z], [e])

        # It does not really matter if the DimShuffles are inplace
        # or not.
        init_str_g_inplace = (
            "FunctionGraph(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0,1}"
            "(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0}(x), y)), z))"
        )
        init_str_g_noinplace = (
            "FunctionGraph(Elemwise{add,no_inplace}(DimShuffle{x,0,1}"
            "(Elemwise{add,no_inplace}(DimShuffle{x,0}(x), y)), z))"
        )
        assert str(g) in (init_str_g_inplace, init_str_g_noinplace), str(g)

        rewrite_str_g_inplace = (
            "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
            "(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z))"
        )
        rewrite_str_g_noinplace = (
            "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
            "(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))"
        )
        dimshuffle_lift.rewrite(g)
        assert str(g) in (rewrite_str_g_inplace, rewrite_str_g_noinplace), str(g)
        # Check stacktrace was copied over correctly after rewrite was applied
        assert check_stack_trace(g, ops_to_check="all")

    def test_recursive_lift(self):
        v = vector(dtype="float64")
        m = matrix(dtype="float64")
        out = ((v + 42) * (m + 84)).T
        g = FunctionGraph([v, m], [out])
        init_str_g = (
            "FunctionGraph(InplaceDimShuffle{1,0}(Elemwise{mul,no_inplace}"
            "(InplaceDimShuffle{x,0}(Elemwise{add,no_inplace}"
            "(<TensorType(float64, (None,))>, "
            "InplaceDimShuffle{x}(TensorConstant{42}))), "
            "Elemwise{add,no_inplace}"
            "(<TensorType(float64, (None, None))>, "
            "InplaceDimShuffle{x,x}(TensorConstant{84})))))"
        )
        assert str(g) == init_str_g
        new_out = local_dimshuffle_lift.transform(g, g.outputs[0].owner)[0]
        new_g = FunctionGraph(g.inputs, [new_out])
        rewrite_str_g = (
            "FunctionGraph(Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
            "(InplaceDimShuffle{0,x}(<TensorType(float64, (None,))>), "
            "InplaceDimShuffle{x,x}(TensorConstant{42})), "
            "Elemwise{add,no_inplace}(InplaceDimShuffle{1,0}"
            "(<TensorType(float64, (None, None))>), "
            "InplaceDimShuffle{x,x}(TensorConstant{84}))))"
        )
        assert str(new_g) == rewrite_str_g
        # Check stacktrace was copied over correctly after rewrite was applied
        assert check_stack_trace(new_g, ops_to_check="all")

    def test_useless_dimshuffle(self):
        x, _, _ = inputs()
        e = ds(x, (0, 1))
        g = FunctionGraph([x], [e])
        assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1}(x))"
        dimshuffle_lift.rewrite(g)
        assert str(g) == "FunctionGraph(x)"
        # Check stacktrace was copied over correctly after rewrite was applied
        assert hasattr(g.outputs[0].tag, "trace")

    def test_dimshuffle_on_broadcastable(self):
        x, y, z = inputs([False, True], [True, False, True], [False, False, True])
        u = at.constant(1)
        ds_x = ds(x, (0, "x"))  # useless
        ds_y = ds(y, (2, 1, 0))  # useless
        ds_z = ds(z, (2, 1, 0))  # useful
        ds_u = ds(u, ("x"))  # useful
        g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u])
        assert (
            str(g)
            == "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
        )
        dimshuffle_lift.rewrite(g)
        assert (
            str(g)
            == "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
        )
        # Check stacktrace was copied over correctly after rewrite was applied
        assert hasattr(g.outputs[0].tag, "trace")
def test_local_useless_dimshuffle_in_reshape():
    # Dimshuffles that only add/remove broadcastable axes before a `Reshape`
    # are redundant and should be removed by the rewrite.
    vec = TensorType(shape=(False,), dtype="float64")("vector")
    mat = TensorType(shape=(False, False), dtype="float64")("mat")
    row = TensorType(shape=(True, False), dtype="float64")("row")
    col = TensorType(shape=(False, True), dtype="float64")("col")

    reshape_dimshuffle_vector = reshape(vec.dimshuffle("x", 0), vec.shape)
    reshape_dimshuffle_mat = reshape(mat.dimshuffle("x", 0, "x", 1), mat.shape)
    reshape_dimshuffle_row = reshape(row.dimshuffle(1, "x"), row.shape)
    reshape_dimshuffle_col = reshape(col.dimshuffle(0), col.shape)

    g = FunctionGraph(
        [vec, mat, row, col],
        [
            reshape_dimshuffle_vector,
            reshape_dimshuffle_mat,
            reshape_dimshuffle_row,
            reshape_dimshuffle_col,
        ],
    )

    assert str(g) == (
        "FunctionGraph(Reshape{1}(InplaceDimShuffle{x,0}(vector), Shape(vector)), "
        "Reshape{2}(InplaceDimShuffle{x,0,x,1}(mat), Shape(mat)), "
        "Reshape{2}(InplaceDimShuffle{1,x}(row), Shape(row)), "
        "Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))"
    )

    useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape)
    useless_dimshuffle_in_reshape.rewrite(g)
    assert str(g) == (
        "FunctionGraph(Reshape{1}(vector, Shape(vector)), "
        "Reshape{2}(mat, Shape(mat)), "
        "Reshape{2}(row, Shape(row)), "
        "Reshape{2}(col, Shape(col)))"
    )

    # Check stacktrace was copied over correctly after rewrite was applied
    assert check_stack_trace(g, ops_to_check="all")

    # Check that the rewrite does not get applied when the order
    # of dimensions has changed.
    reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)
    h = FunctionGraph([mat], [reshape_dimshuffle_mat2])
    str_h = str(h)
    useless_dimshuffle_in_reshape.rewrite(h)
    assert str(h) == str_h
class TestFusion:
    # Rewrite set used by all fusion tests in this class; BLAS and
    # C++-only rewrites are excluded so the tests run without a compiler.
    rewrites = RewriteDatabaseQuery(
        include=[
            "local_elemwise_fusion",
            "composite_elemwise_fusion",
            "canonicalize",
            "inplace",
        ],
        exclude=["cxx_only", "BlasOpt"],
    )
    mode = Mode(get_default_mode().linker, rewrites)
    _shared = staticmethod(shared)
    topo_exclude = ()

    def my_init(dtype="float64", num=0):
        # Helper evaluated at class-body time to build 5x5 constant arrays.
        return np.zeros((5, 5), dtype=dtype) + num

    # Symbolic inputs of various dtypes used by the parametrized cases below.
    fw, fx, fy, fz = [
        tensor(dtype="float32", shape=[False] * 2, name=n) for n in "wxyz"
    ]
    dw, dx, dy, dz = [
        tensor(dtype="float64", shape=[False] * 2, name=n) for n in "wxyz"
    ]
    ix, iy, iz = [tensor(dtype="int32", shape=[False] * 2, name=n) for n in "xyz"]
    fv = fvector("v")
    fs = fscalar("s")

    # Concrete values matching the symbolic inputs above.
    fwv = my_init("float32", 1)
    fxv = my_init("float32", 2)
    fyv = my_init("float32", 3)
    fzv = my_init("float32", 4)
    fvv = _asarray(np.random.random(5), dtype="float32")
    fsv = np.asarray(np.random.random(), dtype="float32")
    dwv = my_init("float64", 5)
    ixv = _asarray(my_init(num=60), dtype="int32")
    iyv = _asarray(my_init(num=70), dtype="int32")
    izv = _asarray(my_init(num=70), dtype="int32")
    fwx = fw + fx
    ftanx = tan(fx)
@pytest.mark.parametrize
(
"case"
,
[
(
fx
+
fy
+
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
+
fzv
,
"float32"
,
),
# 0
(
fx
*
fy
*
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
*
fyv
*
fzv
,
"float32"
,
),
# 1
(
fx
+
fy
*
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
*
fzv
,
"float32"
,
),
# 2
(
fx
*
fy
+
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
*
fyv
+
fzv
,
"float32"
,
),
# 3
(
fw
+
fx
+
fy
+
fz
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
,
"float32"
,
),
(
(
fw
+
fx
)
+
(
fy
+
fz
),
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
,
"float32"
,
),
# 5
(
((
fw
+
fx
)
+
fy
)
+
fz
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
,
"float32"
,
),
(
(
fw
+
(
fx
+
fy
))
+
fz
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
,
"float32"
,
),
(
(
fw
+
(
fx
+
fy
)
+
fz
),
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
,
"float32"
,
),
(
fw
+
(
fx
+
(
fy
+
fz
)),
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
,
"float32"
,
),
(
(
fw
+
fx
)
+
(
fy
+
fz
),
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
,
"float32"
,
),
# 10
(
fw
*
fx
*
fy
*
fz
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
*
fxv
*
fyv
*
fzv
,
"float32"
,
),
(
fw
+
fx
*
fy
*
fz
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
*
fyv
*
fzv
,
"float32"
,
),
(
fx
+
fy
*
fz
*
fx
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
*
fzv
*
fxv
,
"float32"
,
),
(
fx
*
fy
+
fz
+
fy
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
*
fyv
+
fzv
+
fyv
,
"float32"
,
),
(
fx
*
fy
*
fz
*
fw
+
fx
+
fy
+
fz
+
fw
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fxv
*
fyv
*
fzv
*
fwv
+
fxv
+
fyv
+
fzv
+
fwv
,
"float32"
,
),
# 15
# test with constant
(
(
fw
+
fx
)
+
(
fy
+
fz
)
+
2.0
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
+
2
,
"float32"
,
),
(
((
fw
+
fx
)
+
2.0
+
fy
)
+
fz
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
+
2
,
"float32"
,
),
(
(
fw
+
(
fx
+
2.0
+
fy
))
+
fz
,
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
+
2
,
"float32"
,
),
(
(
fw
+
(
fx
+
fy
)
+
2
+
fz
),
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
+
2
,
"float32"
,
),
(
fw
+
(
fx
+
(
fy
+
fz
)
+
2.0
),
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
+
2
,
"float32"
,
),
# 20
(
2
+
(
fw
+
fx
)
+
(
fy
+
fz
),
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
1
,
fwv
+
fxv
+
fyv
+
fzv
+
2
,
"float32"
,
),
# mix float32 and float64
(
2
+
(
dw
+
fx
)
+
(
fy
+
fz
),
(
dw
,
fx
,
fy
,
fz
),
(
dwv
,
fxv
,
fyv
,
fzv
),
1
,
dwv
+
fxv
+
fyv
+
fzv
+
2
,
"float64"
,
),
(
2
+
(
fw
+
dw
)
+
(
fy
+
fz
),
(
fw
,
dw
,
fy
,
fz
),
(
fwv
,
dwv
,
fyv
,
fzv
),
1
,
fwv
+
dwv
+
fyv
+
fzv
+
2
,
"float64"
,
),
(
2
+
(
fw
+
fx
)
+
(
dw
+
fz
),
(
fw
,
fx
,
dw
,
fz
),
(
fwv
,
fxv
,
dwv
,
fzv
),
1
,
fwv
+
fxv
+
dwv
+
fzv
+
2
,
"float64"
,
),
(
2
+
(
fw
+
fx
)
+
(
fy
+
dw
),
(
fw
,
fx
,
fy
,
dw
),
(
fwv
,
fxv
,
fyv
,
dwv
),
1
,
fwv
+
fxv
+
fyv
+
dwv
+
2
,
"float64"
,
),
# 25
# test when their is other op then elemwise.
(
(
fwx
.
sum
())
+
(
fwx
)
+
(
fy
+
fz
),
(
fw
,
fx
,
fy
,
fz
),
(
fwv
,
fxv
,
fyv
,
fzv
),
4
,
(
fwv
+
fxv
)
.
sum
()
+
fwv
+
fxv
+
fyv
+
fzv
,
"float32"
,
),
# test other elemwise op
(
fx
+
fy
+
cos
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
+
np
.
cos
(
fzv
),
"float32"
,
),
(
fx
+
fy
+
cosh
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
+
np
.
cosh
(
fzv
),
"float32"
,
),
(
fx
+
fy
+
abs
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
+
np
.
absolute
(
fzv
),
"float32"
,
),
(
ix
+
iy
+
abs
(
iz
),
(
ix
,
iy
,
iz
),
(
ixv
,
iyv
,
izv
),
1
,
ixv
+
iyv
+
np
.
absolute
(
izv
),
"int32"
,
),
# 30
(
fx
+
fy
+
log
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
+
np
.
log
(
fzv
),
"float32"
,
),
(
fx
+
fy
+
log2
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
+
np
.
log2
(
fzv
),
"float32"
,
),
(
fx
+
fy
+
log10
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
+
np
.
log10
(
fzv
),
"float32"
,
),
(
fx
+
fy
**
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
**
fzv
,
"float32"
,
),
# pow
(
fx
+
fy
+
exp
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
+
fyv
+
np
.
exp
(
fzv
),
"float32"
,
),
# 35
(
fx
-
fy
-
fz
,
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
-
fzv
,
"float32"
,
),
(
fx
-
(
fy
/
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
(
fyv
/
fzv
),
"float32"
,
),
(
fx
-
true_div
(
fy
,
2
),
(
fx
,
fy
),
(
fxv
,
fyv
),
1
,
fxv
-
(
fyv
/
2
),
"float32"
,
),
(
fx
-
true_div
(
fy
,
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
(
fyv
/
fzv
),
"float32"
,
),
(
fx
-
int_div
(
ix
*
100
,
iy
*
1000
),
(
fx
,
ix
,
iy
),
(
fxv
,
ixv
,
iyv
),
1
,
fxv
-
((
ixv
*
100
)
//
(
iyv
*
1000
)),
{
"custom"
:
"float64"
,
"numpy + floatX"
:
config
.
floatX
,
"numpy"
:
"float64"
,
},
),
# 40
(
fx
-
(
fy
/
2
),
(
fx
,
fy
),
(
fxv
,
fyv
),
1
,
fxv
-
(
fyv
/
2
),
"float32"
),
(
fx
-
(
fy
%
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
(
fyv
%
fzv
),
"float32"
,
),
(
fx
-
(
fy
>
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
(
fyv
>
fzv
),
"float32"
,
),
(
fx
-
(
fy
>=
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
(
fyv
>=
fzv
),
"float32"
,
),
(
fx
-
(
fy
<
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
(
fyv
<
fzv
),
"float32"
,
),
# 45
(
fx
-
(
fy
<=
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
(
fyv
<=
fzv
),
"float32"
,
),
(
fx
-
eq
(
fy
,
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
(
fyv
==
fzv
),
"float32"
,
),
(
fx
-
neq
(
fy
,
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
(
fyv
!=
fzv
),
"float32"
,
),
(
fx
-
fy
+
tan
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
+
np
.
tan
(
fzv
),
"float32"
,
),
(
fx
-
fy
+
tanh
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
+
np
.
tanh
(
fzv
),
"float32"
,
),
# 50
(
fx
-
fy
+
sin
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
+
np
.
sin
(
fzv
),
"float32"
,
),
(
fx
-
fy
+
sinh
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
+
np
.
sinh
(
fzv
),
"float32"
,
),
(
fx
-
fy
+
sqr
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
+
(
fzv
*
fzv
),
"float32"
,
),
(
fx
-
fy
+
sqrt
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
+
np
.
sqrt
(
fzv
),
"float32"
,
),
(
fx
-
fy
+
reciprocal
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
+
(
1
/
fzv
),
"float32"
,
),
# 55
(
fx
-
fy
+
neg
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
+
(
-
fzv
),
"float32"
,
),
(
fx
-
fy
+
at_round
(
fz
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
fxv
-
fyv
+
np
.
round
(
fzv
),
"float32"
,
),
(
ix
-
iy
+
iround
(
fz
),
(
ix
,
iy
,
fz
),
(
ixv
,
iyv
,
fzv
),
1
,
ixv
-
iyv
+
np
.
round
(
fzv
),
"int64"
,
),
# Bit op
(
fx
-
bitwise_or
(
iy
,
iz
),
(
fx
,
iy
,
iz
),
(
fxv
,
iyv
,
izv
),
1
,
fxv
-
(
iyv
|
izv
),
{
"custom"
:
"float64"
,
"numpy + floatX"
:
config
.
floatX
,
"numpy"
:
"float64"
,
},
),
(
fx
-
xor
(
iy
,
iz
),
(
fx
,
iy
,
iz
),
(
fxv
,
iyv
,
izv
),
1
,
fxv
-
(
iyv
^
izv
),
{
"custom"
:
"float64"
,
"numpy + floatX"
:
config
.
floatX
,
"numpy"
:
"float64"
,
},
),
# 60
(
fx
-
bitwise_and
(
iy
,
iz
),
(
fx
,
iy
,
iz
),
(
fxv
,
iyv
,
izv
),
1
,
fxv
-
(
iyv
&
izv
),
{
"custom"
:
"float64"
,
"numpy + floatX"
:
config
.
floatX
,
"numpy"
:
"float64"
,
},
),
(
fx
-
invert
(
iy
),
(
fx
,
iy
),
(
fxv
,
iyv
),
1
,
fxv
-
(
~
iyv
),
{
"custom"
:
"float64"
,
"numpy + floatX"
:
config
.
floatX
,
"numpy"
:
"float64"
,
},
),
(
fx
-
at
.
cast
(
fy
,
dtype
=
"float64"
),
(
fx
,
fy
),
(
fxv
,
fyv
),
1
,
fxv
-
np
.
asarray
(
fyv
,
"float64"
),
"float64"
,
),
(
at_pow
(
fx
*
fy
+
fz
,
fx
*
fy
),
(
fx
,
fy
,
fz
),
(
fxv
,
fyv
,
fzv
),
1
,
np
.
power
(
fxv
*
fyv
+
fzv
,
fxv
*
fyv
),
"float32"
,
),
(
fv
+
fy
**
fz
,
(
fv
,
fy
,
fz
),
(
fvv
,
fyv
,
fzv
),
2
,
fvv
+
fyv
**
fzv
,
"float32"
,
),
# fused with a dimshuffle #65
(
fv
-
fy
+
tanh
(
fz
),
(
fv
,
fy
,
fz
),
(
fvv
,
fyv
,
fzv
),
2
,
fvv
-
fyv
+
np
.
tanh
(
fzv
),
"float32"
,
),
# fused with a dimshuffle
# Cases where the same input is reused many times.
(
mul
(
fx
,
fx
,
fx
,
fx
),
(
fx
,),
(
fxv
,),
1
,
fxv
*
fxv
*
fxv
*
fxv
,
"float32"
,
),
(
mul
(
fx
,
ftanx
,
ftanx
),
(
fx
,),
(
fxv
,),
1
,
fxv
*
np
.
tan
(
fxv
)
*
np
.
tan
(
fxv
),
"float32"
,
),
(
mul
(
fx
,
ftanx
,
ftanx
,
fx
),
(
fx
,),
(
fxv
,),
1
,
fxv
*
np
.
tan
(
fxv
)
*
np
.
tan
(
fxv
)
*
fxv
,
"float32"
,
),
(
mul
(
ftanx
,
ftanx
,
fx
+
fy
),
(
fx
,
fy
),
(
fxv
,
fyv
),
1
,
np
.
tan
(
fxv
)
*
np
.
tan
(
fxv
)
*
(
fxv
+
fyv
),
"float32"
,
),
# 70
# Cases with different broadcast pattern. They should not
# be merged as this would duplicate computation
# The graph should have 2 elemwise and 1 dimshuffle
(
fx
*
sin
(
fs
),
(
fx
,
fs
),
(
fxv
,
fsv
),
3
,
fxv
*
np
.
sin
(
fsv
),
"float32"
,
),
],
)
def
test_elemwise_fusion
(
self
,
case
,
nb_repeat
=
1
,
assert_len_topo
=
True
):
"""Verify that `Elemwise` fusion works."""
g
,
sym_inputs
,
val_inputs
,
nb_elemwise
,
answer
,
out_dtype
=
case
if
isinstance
(
out_dtype
,
dict
):
out_dtype
=
out_dtype
[
config
.
cast_policy
]
if
self
.
_shared
is
None
:
f
=
function
(
list
(
sym_inputs
),
g
,
mode
=
self
.
mode
)
for
x
in
range
(
nb_repeat
):
out
=
f
(
*
val_inputs
)
else
:
out
=
self
.
_shared
(
np
.
zeros
((
5
,
5
),
dtype
=
out_dtype
),
"out"
)
assert
out
.
dtype
==
g
.
dtype
f
=
function
(
sym_inputs
,
[],
updates
=
[(
out
,
g
)],
mode
=
self
.
mode
)
for
x
in
range
(
nb_repeat
):
f
(
*
val_inputs
)
out
=
out
.
get_value
()
atol
=
1e-8
if
out_dtype
==
"float32"
:
atol
=
1e-6
assert
np
.
allclose
(
out
,
answer
*
nb_repeat
,
atol
=
atol
)
topo
=
f
.
maker
.
fgraph
.
toposort
()
topo_
=
[
n
for
n
in
topo
if
not
isinstance
(
n
.
op
,
self
.
topo_exclude
)]
if
assert_len_topo
:
assert
len
(
topo_
)
==
nb_elemwise
if
nb_elemwise
==
1
:
# if no variable appears multiple times in the
# input of g,
# check that the number of input to the Composite
# Elemwise is ok
if
len
(
set
(
g
.
owner
.
inputs
))
==
len
(
g
.
owner
.
inputs
):
expected_len_sym_inputs
=
sum
(
not
isinstance
(
x
,
Constant
)
for
x
in
topo_
[
0
]
.
inputs
)
assert
expected_len_sym_inputs
==
len
(
sym_inputs
)
assert
out_dtype
==
out
.
dtype
def
test_fusion_35_inputs
(
self
):
r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit."""
inpts
=
vectors
([
"i
%
i"
%
i
for
i
in
range
(
35
)])
# Make an elemwise graph looking like:
# sin(i34 + sin(i33 + sin(... i1 + sin(i0) ...)))
out
=
sin
(
inpts
[
0
])
for
idx
in
range
(
1
,
35
):
out
=
sin
(
inpts
[
idx
]
+
out
)
with
config
.
change_flags
(
cxx
=
""
):
f
=
function
(
inpts
,
out
,
mode
=
self
.
mode
)
# Make sure they all weren't fused
composite_nodes
=
[
node
for
node
in
f
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
getattr
(
node
.
op
,
"scalar_op"
,
None
),
aes
.
basic
.
Composite
)
]
assert
not
any
(
len
(
node
.
inputs
)
>
31
for
node
in
composite_nodes
)
@pytest.mark.skipif
(
not
config
.
cxx
,
reason
=
"No cxx compiler"
)
def
test_big_fusion
(
self
):
# In the past, pickle of Composite generated in that case
# crashed with max recursion limit. So we were not able to
# generate C code in that case.
factors
=
[]
sd
=
dscalar
()
means
=
dvector
()
cst_05
=
at
.
constant
(
0.5
)
cst_m05
=
at
.
constant
(
-
0.5
)
cst_2
=
at
.
constant
(
2
)
cst_m2
=
at
.
constant
(
-
2
)
ones
=
at
.
constant
(
np
.
ones
(
10
))
n
=
85
if
config
.
mode
in
[
"DebugMode"
,
"DEBUG_MODE"
]:
n
=
10
for
i
in
range
(
n
):
f
=
cst_m05
*
sd
**
cst_m2
*
(
ones
-
means
[
i
])
**
cst_2
+
cst_05
*
log
(
cst_05
*
(
sd
**
cst_m2
)
/
np
.
pi
)
factors
.
append
(
at_sum
(
f
))
logp
=
add
(
*
factors
)
vars
=
[
sd
,
means
]
# Make sure that C compilation is used
mode
=
Mode
(
"cvm"
,
self
.
rewrites
)
dlogp
=
function
(
vars
,
[
aesara
.
grad
(
logp
,
v
)
for
v
in
vars
],
mode
=
mode
)
# Make sure something was fused
assert
any
(
isinstance
(
getattr
(
node
.
op
,
"scalar_op"
,
None
),
aes
.
basic
.
Composite
)
for
node
in
dlogp
.
maker
.
fgraph
.
toposort
()
)
def
test_add_mul_fusion_inplace
(
self
):
rewrites
=
RewriteDatabaseQuery
(
include
=
[
"local_elemwise_fusion"
,
"composite_elemwise_fusion"
,
"canonicalize"
,
"inplace"
,
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
],
)
mode
=
Mode
(
self
.
mode
.
linker
,
rewrites
)
x
,
y
,
z
=
dmatrices
(
"xyz"
)
out
=
dot
(
x
,
y
)
+
x
+
y
+
z
f
=
function
([
x
,
y
,
z
],
out
,
mode
=
mode
)
topo
=
[
n
for
n
in
f
.
maker
.
fgraph
.
toposort
()]
assert
len
(
topo
)
==
2
assert
topo
[
-
1
]
.
op
.
inplace_pattern
new_out
=
f
.
maker
.
fgraph
.
outputs
[
0
]
assert
isinstance
(
new_out
.
owner
.
op
,
Elemwise
)
assert
isinstance
(
new_out
.
owner
.
op
.
scalar_op
,
aes
.
basic
.
Add
)
assert
len
(
new_out
.
owner
.
inputs
)
==
4
# TODO: Do we really need to do this?
_
=
f
(
np
.
random
.
random
((
5
,
5
)),
np
.
random
.
random
((
5
,
5
)),
np
.
random
.
random
((
5
,
5
))
)
@pytest.mark.skipif
(
not
config
.
cxx
,
reason
=
"No cxx compiler"
)
def
test_no_c_code
(
self
):
r"""Make sure we avoid fusions for `Op`\s without C code implementations."""
# This custom `Op` has no `c_code` method
class
NoCCodeOp
(
aes
.
basic
.
UnaryScalarOp
):
def
impl
(
self
,
x
):
return
x
*
2
no_c_code_op
=
Elemwise
(
NoCCodeOp
(
aes
.
basic
.
upgrade_to_float
))
mode
=
Mode
(
linker
=
"cvm"
)
mode
.
_optimizer
=
mode
.
_optimizer
.
including
(
"local_elemwise_fusion"
,
"composite_elemwise_fusion"
,
"canonicalize"
,
"inplace"
,
)
x
=
vector
()
out
=
x
*
no_c_code_op
(
x
+
1
)
f
=
function
([
x
],
out
,
mode
=
mode
)
assert
not
any
(
isinstance
(
getattr
(
n
.
op
,
"scalar_op"
),
aes
.
basic
.
Composite
)
for
n
in
f
.
maker
.
fgraph
.
toposort
()
)
@pytest.mark.parametrize
(
"test_value"
,
[
np
.
c_
[[
1.0
]],
np
.
c_
[[]]])
def
test_test_values
(
self
,
test_value
):
"""Make sure that `local_elemwise_fusion_op` uses test values correctly when they have zero dimensions.
The test values we're talking about are the ones used when C implementations
are checked.
"""
rewrites
=
RewriteDatabaseQuery
(
include
=
[
"local_elemwise_fusion"
,
"composite_elemwise_fusion"
,
"canonicalize"
,
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
],
)
mode
=
Mode
(
self
.
mode
.
linker
,
rewrites
)
x
,
y
,
z
=
dmatrices
(
"xyz"
)
x
.
tag
.
test_value
=
test_value
y
.
tag
.
test_value
=
test_value
z
.
tag
.
test_value
=
test_value
if
test_value
.
size
==
0
:
cm
=
pytest
.
raises
(
ValueError
)
else
:
cm
=
contextlib
.
suppress
()
with
config
.
change_flags
(
compute_test_value
=
"raise"
,
compute_test_value_opt
=
"raise"
):
out
=
x
*
y
+
z
with
cm
:
f
=
function
([
x
,
y
,
z
],
out
,
mode
=
mode
)
if
test_value
.
size
!=
0
:
# Confirm that the fusion happened
assert
isinstance
(
f
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
.
scalar_op
,
Composite
)
assert
len
(
f
.
maker
.
fgraph
.
toposort
())
==
1
x_c
,
y_c
,
z_c
=
f
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
inputs
assert
np
.
array_equal
(
f
.
maker
.
fgraph
.
outputs
[
0
]
.
tag
.
test_value
,
np
.
c_
[[
2.0
]]
)
class TimesN(aes.basic.UnaryScalarOp):
    """
    Used in test TestCompositeCodegen

    Must be outside of the class, otherwise, the c cache code can't
    pickle this class and this cause stuff printing during test.
    """

    def __eq__(self, other):
        return super().__eq__(other) and self.n == other.n

    def __hash__(self):
        return super().__hash__() ^ hash(self.n)

    def __init__(self, n, *args, **kwargs):
        self.n = n
        aes.basic.UnaryScalarOp.__init__(self, *args, **kwargs)

    def impl(self, x):
        return x * self.n

    def c_support_code_apply(self, node, nodename):
        n = str(self.n)
        return (
            """
        float %(nodename)s_timesn(float x) { return x * %(n)s; }
        """
            % locals()
        )

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        return f"{z} = {name}_timesn({x});"
class TestCompositeCodegen:
    """
    Test The Composite Ops code generation in a case where there is multiple
    scalar ops with support code.
    """

    def setup_method(self):
        upgrade_to_float = aes.basic.upgrade_to_float

        self.scal_times_2 = TimesN(2, upgrade_to_float, name="times_2")
        self.times_2 = Elemwise(self.scal_times_2, name="times_2")

        self.scal_times_3 = TimesN(3, upgrade_to_float, name="times_3")
        self.times_3 = Elemwise(self.scal_times_3, name="times_3")

        self.x = fvector()

    def test_nested_composite(self):
        y = self.times_2(self.x)
        z = self.times_3(y)
        f = function([self.x], z)
        if config.mode != "FAST_COMPILE":
            assert len(f.maker.fgraph.toposort()) == 1
        fval = f([1, 2, 3])
        assert np.all(fval == [6, 12, 18])

    def test_local_useless_composite(self):
        x = aes.float32()
        c = aes.Composite([x], [x + 1, x - 1])
        X = matrix()
        o = Elemwise(scalar_op=c)(X)
        mode = get_default_mode().including("local_useless_composite")

        f = function([X], o[0], mode=mode)
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 1
        assert len(topo[0].outputs) == 1
        utt.assert_allclose(f([[1.0]]), [[2.0]])

        f = function([X], o[1], mode=mode)
        topo = f.maker.fgraph.toposort()
        assert len(topo) == 1
        assert len(topo[0].outputs) == 1
        utt.assert_allclose(f([[1.0]]), [[0.0]])
def test_local_useless_dimshuffle_makevector():
    # Dimshuffling a one-element `MakeVector` back to a scalar should be
    # rewritten to the original scalar input.
    a = scalar()
    x = MakeVector(config.floatX)(a)
    y = x.dimshuffle(())

    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)

    y_rewritten_fg = rewrite_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_useless_dimshuffle_makevector"],
    )

    assert y_rewritten_fg.outputs[0] == a
tests/tensor/rewriting/test_extra_ops.py
0 → 100644
浏览文件 @
63f52536
import
numpy
as
np
import
pytest
import
aesara.scalar
as
aes
from
aesara.compile.function
import
function
from
aesara.compile.mode
import
OPT_NONE
,
Mode
,
get_default_mode
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.rewriting.utils
import
rewrite_graph
from
aesara.tensor.basic
import
Alloc
,
alloc
,
as_tensor_variable
,
second
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
from
aesara.tensor.extra_ops
import
BroadcastTo
,
Repeat
,
Unique
,
repeat
,
unique
from
aesara.tensor.type
import
dscalar
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_scalar(return_index, return_counts, return_inverse):
    # `unique` of a scalar is the scalar itself (as a 1-d result), so the
    # rewrite should reduce the `Unique` to a simple `DimShuffle` of `x`.
    x = dscalar()
    y = unique(
        x,
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=None,
    )

    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_rewritten_fg = rewrite_graph(
        y_fg, clone=False, include=["canonicalize", "local_Unique_scalar"]
    )
    y_rewritten = y_rewritten_fg.outputs[0]
    y_rewritten_start = y_rewritten

    assert isinstance(y_rewritten_start.owner.op, DimShuffle)
    assert y_rewritten_start.owner.inputs[0] == x

    default_mode = get_default_mode()
    rewrite_mode = default_mode.excluding("local_Unique_scalar")
    y_fn = function([x], [y, y_rewritten], mode=rewrite_mode)

    x_val = np.array(-10.0, dtype=np.float64)
    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
    "x_val, axis, new_shape",
    [
        (np.array(-10, dtype=np.int64), None, ()),
        (np.array(-10, dtype=np.int64), None, (2, 3)),
        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_Alloc_lift(
    x_val, axis, new_shape, return_index, return_counts, return_inverse
):
    x = as_tensor_variable(x_val).type()
    y = unique(
        alloc(x, *new_shape),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=axis,
    )

    if isinstance(y, list):
        y, *_ = y

    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_rewritten_fg = rewrite_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_Alloc_lift"],
        exclude=["local_Unique_scalar"],
    )
    y_rewritten = y_rewritten_fg.outputs[0]
    y_rewritten_start = y_rewritten

    assert isinstance(y_rewritten_start.owner.op, Unique)
    assert y_rewritten_start.owner.inputs[0] == x
    assert not any(isinstance(node.op, Alloc) for node in y_rewritten_fg.apply_nodes)

    default_mode = get_default_mode()
    # The rewrite has already been applied to `y_rewritten`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the rewritten result, `y_rewritten`.
    # The remaining exclusions simply allow us to perform the check below that
    # makes sure the original `Alloc` is present in our reference (sub)graph.
    rewrite_mode = default_mode.excluding(
        "local_useless_alloc", "local_alloc_sink_dimshuffle", "local_Unique_Alloc_lift"
    )
    y_fn = function([x], [y, y_rewritten], mode=rewrite_mode)

    # Make sure that the original `Alloc` is used to compute the reference `y`
    # result
    assert any(isinstance(node.op, Alloc) for node in y_fn.maker.fgraph.apply_nodes)

    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
    "x_val, axis, new_shape",
    [
        (np.array(-10, dtype=np.int64), None, (2, 3)),
        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_BroadcastTo(
    x_val, axis, new_shape, return_index, return_counts, return_inverse
):
    """Check that ``unique(broadcast_to(x, ...))`` is rewritten to ``unique(x)``.

    The rewrite under test is `local_Unique_BroadcastTo_lift`: broadcasting
    cannot introduce new unique values, so the `BroadcastTo` can be removed
    from under `Unique`.
    """
    x = as_tensor_variable(x_val).type()
    y = unique(
        BroadcastTo()(x, tuple(new_shape)),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=axis,
    )

    if isinstance(y, list):
        y, *_ = y

    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_rewritten_fg = rewrite_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_BroadcastTo_lift"],
        exclude=["local_Unique_scalar"],
    )
    y_rewritten = y_rewritten_fg.outputs[0]
    y_rewritten_start = y_rewritten

    # After the rewrite, `Unique` should be applied directly to `x` and no
    # `BroadcastTo` node should remain anywhere in the graph.
    assert isinstance(y_rewritten_start.owner.op, Unique)
    assert y_rewritten_start.owner.inputs[0] == x
    assert not any(
        isinstance(node.op, BroadcastTo) for node in y_rewritten_fg.apply_nodes
    )

    default_mode = get_default_mode()
    # The rewrite has already been applied to `y_rewritten`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the rewritten result, `y_rewritten`.
    rewrite_mode = default_mode.excluding("local_Unique_BroadcastTo_lift")
    y_fn = function([x], [y, y_rewritten], mode=rewrite_mode)

    # Make sure that the original `BroadcastTo` is used to compute the
    # reference `y` result
    assert any(
        isinstance(node.op, BroadcastTo) for node in y_fn.maker.fgraph.apply_nodes
    )

    # The rewritten graph must compute the same values as the reference graph.
    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
    "x_val, unique_axis, repeats, repeat_axis",
    [
        (np.array([[-10, -3], [-10, 2]], dtype=np.int64), None, (1, 2), 0),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_Repeat(
    x_val,
    unique_axis,
    repeats,
    repeat_axis,
    return_index,
    return_counts,
    return_inverse,
):
    """Check that ``unique(repeat(x, ...))`` is rewritten to ``unique(x)``.

    The rewrite under test is `local_Unique_Repeat_lift`: repeating elements
    cannot introduce new unique values, so the `Repeat` can be removed from
    under `Unique`.
    """
    x = as_tensor_variable(x_val).type()
    y = unique(
        repeat(x, tuple(repeats), axis=repeat_axis),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=unique_axis,
    )

    if isinstance(y, list):
        y, *_ = y

    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_rewritten_fg = rewrite_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_Repeat_lift"],
        exclude=["local_Unique_scalar"],
    )
    y_rewritten = y_rewritten_fg.outputs[0]
    y_rewritten_start = y_rewritten

    # After the rewrite, `Unique` should be applied directly to `x` and no
    # `Repeat` node should remain anywhere in the graph.
    assert isinstance(y_rewritten_start.owner.op, Unique)
    assert y_rewritten_start.owner.inputs[0] == x
    assert not any(isinstance(node.op, Repeat) for node in y_rewritten_fg.apply_nodes)

    default_mode = get_default_mode()
    # The rewrite has already been applied to `y_rewritten`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the rewritten result, `y_rewritten`.
    rewrite_mode = default_mode.excluding("local_Unique_Repeat_lift")
    y_fn = function([x], [y, y_rewritten], mode=rewrite_mode)

    # Make sure that the original `Repeat` is used to compute the
    # reference `y` result
    assert any(isinstance(node.op, Repeat) for node in y_fn.maker.fgraph.apply_nodes)

    # The rewritten graph must compute the same values as the reference graph.
    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
    "x_val, unique_axis, new_shape",
    [
        (np.array(-10, dtype=np.int64), None, ()),
        (np.array(-10, dtype=np.int64), None, (2, 3)),
        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_second(
    x_val, unique_axis, new_shape, return_index, return_counts, return_inverse
):
    """Check that ``unique(second(a, x))`` is rewritten to ``unique(x)``.

    The rewrite under test is `local_Unique_second_lift`: ``second(a, x)``
    broadcasts `x` to the shape of `a`, which cannot introduce new unique
    values, so the `Second` `Elemwise` can be removed from under `Unique`.
    """
    x = as_tensor_variable(x_val).type()
    a = np.zeros(tuple(new_shape), dtype=x.dtype)
    y = unique(
        second(a, x),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=unique_axis,
    )

    if isinstance(y, list):
        y, *_ = y

    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_rewritten_fg = rewrite_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_second_lift"],
        exclude=["local_Unique_scalar", "topo_constant_folding"],
    )
    y_rewritten = y_rewritten_fg.outputs[0]
    y_rewritten_start = y_rewritten

    assert isinstance(y_rewritten_start.owner.op, Unique)

    y_rewritten_start = y_rewritten_start.owner.inputs[0]

    # The lift may leave a `DimShuffle` between `Unique` and `x`; skip it.
    if y_rewritten_start.owner and isinstance(y_rewritten_start.owner.op, DimShuffle):
        y_rewritten_start = y_rewritten_start.owner.inputs[0]

    # `Unique` should now be applied (possibly through a `DimShuffle`) to `x`
    # directly, and no `Second` `Elemwise` should remain in the graph.
    assert y_rewritten_start == x
    assert not any(
        isinstance(node.op.scalar_op, aes.Second)
        for node in y_rewritten_fg.apply_nodes
        if isinstance(node.op, Elemwise)
    )

    # The rewrite has already been applied to `y_rewritten`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the rewritten result, `y_rewritten`.
    y_fn = function([x], [y, y_rewritten], mode=Mode(optimizer=OPT_NONE))

    # Make sure that the original `Second` is used to compute the
    # reference `y` result
    assert any(
        isinstance(node.op.scalar_op, aes.Second)
        for node in y_fn.maker.fgraph.apply_nodes
        if isinstance(node.op, Elemwise)
    )

    # The rewritten graph must compute the same values as the reference graph.
    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
def test_local_remove_scalar_BroadcastTo():
    """Broadcasting a scalar to an empty shape should be removed entirely.

    `local_remove_scalar_BroadcastTo` must reduce ``BroadcastTo()(x, ())``
    back to ``x`` itself.
    """
    scalar_var = dscalar()
    broadcasted = BroadcastTo()(scalar_var, ())

    # Sanity check: the un-rewritten graph really contains a `BroadcastTo`.
    assert isinstance(broadcasted.owner.op, BroadcastTo)

    rewritten = rewrite_graph(
        broadcasted,
        clone=False,
        include=["canonicalize", "local_remove_scalar_BroadcastTo"],
    )

    # The rewrite must return the original scalar variable itself.
    assert rewritten is scalar_var
tests/tensor/rewriting/test_math.py
浏览文件 @
63f52536
...
...
@@ -79,7 +79,7 @@ from aesara.tensor.math import round as at_round
from
aesara.tensor.math
import
sgn
,
sigmoid
,
sin
,
sinh
,
softplus
,
sqr
,
sqrt
,
sub
from
aesara.tensor.math
import
sum
as
at_sum
from
aesara.tensor.math
import
tan
,
tanh
,
true_div
,
xor
from
aesara.tensor.rewriting.
basic
import
local_dimshuffle_lift
from
aesara.tensor.rewriting.
elemwise
import
local_dimshuffle_lift
from
aesara.tensor.rewriting.math
import
(
compute_mul
,
is_1pexp
,
...
...
tests/tensor/rewriting/test_shape.py
0 → 100644
浏览文件 @
63f52536
import
copy
import
numpy
as
np
import
pytest
import
aesara.tensor
as
at
from
aesara
import
shared
from
aesara.compile.function
import
function
from
aesara.compile.mode
import
get_default_mode
,
get_mode
from
aesara.compile.ops
import
deep_copy_op
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Apply
,
Variable
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.op
import
Op
from
aesara.graph.rewriting.basic
import
check_stack_trace
,
node_rewriter
,
out2in
from
aesara.graph.rewriting.utils
import
rewrite_graph
from
aesara.graph.type
import
Type
from
aesara.tensor.basic
import
as_tensor_variable
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
from
aesara.tensor.math
import
add
,
exp
,
maximum
from
aesara.tensor.rewriting.basic
import
register_specialize
from
aesara.tensor.rewriting.shape
import
(
ShapeFeature
,
local_reshape_to_dimshuffle
,
local_useless_reshape
,
)
from
aesara.tensor.shape
import
(
Reshape
,
Shape_i
,
SpecifyShape
,
reshape
,
shape
,
specify_shape
,
)
from
aesara.tensor.subtensor
import
set_subtensor
from
aesara.tensor.type
import
(
fmatrix
,
iscalar
,
lscalar
,
matrix
,
scalar
,
tensor
,
tensor3
,
tensor4
,
vector
,
)
from
tests
import
unittest_tools
as
utt
# Mode used by the rewrite tests in this module: FAST_COMPILE applies too few
# rewrites to be useful here, so substitute FAST_RUN in that case.
rewrite_mode = get_mode("FAST_RUN" if config.mode == "FAST_COMPILE" else config.mode)
class TestShapeRewriter:
    """Tests for the shape-related rewrites and the `ShapeFeature`."""

    def test_basic(self):
        """A graph's shape should be computable without evaluating the graph itself."""
        mode = config.mode
        if mode == "FAST_COMPILE":
            mode = "FAST_RUN"
        v = vector()
        m = matrix()
        f = function([v, m], (v + m).shape, mode=mode)
        # The `add` itself should have been removed: only shape computations
        # should remain in the compiled graph.
        for node in f.maker.fgraph.toposort():
            assert node.op != add

    def test_constant(self):
        """A shape along a broadcastable dimension should fold to a constant."""
        mode = config.mode
        if mode == "FAST_COMPILE":
            mode = "FAST_RUN"

        v = vector()

        f = function([v], v.dimshuffle("x", "x", 0).shape[1], mode=mode)
        topo = f.maker.fgraph.toposort()
        # The whole computation collapses to copying a constant.
        assert len(topo) == 1
        assert topo[0].op == deep_copy_op

    @staticmethod
    def max_pool_c01b(c01b, pool_shp, pool_stride, img_shp):
        """
        Like max_pool but with input using axes ('c', 0, 1, 'b')
        (Alex Krizhevsky format)

        pool_shp, pool_stride and img_shp are int that represent
        the same shp in x and y.
        """
        mx = None

        # Compute index in pooled space of last needed pool
        # (needed = each input pixel must appear in at least one pool)
        def last_pool(im_shp, p_shp, p_strd):
            rval = int(np.ceil(float(im_shp - p_shp) / p_strd))
            assert p_strd * rval + p_shp >= im_shp
            assert p_strd * (rval - 1) + p_shp < im_shp
            return rval

        # Compute starting row of the last pool
        last_pool_r = last_pool(img_shp, pool_shp, pool_stride) * pool_stride
        # Compute number of rows needed in img for all indexes to work out
        required_r = last_pool_r + pool_shp

        last_pool_c = last_pool(img_shp, pool_shp, pool_stride) * pool_stride
        required_c = last_pool_c + pool_shp

        # Pad the image with -inf so every pool window is fully defined.
        wide_infinity = at.alloc(
            -np.inf, c01b.shape[0], required_r, required_c, c01b.shape[3]
        )

        c01b = set_subtensor(wide_infinity[:, 0:img_shp, 0:img_shp, :], c01b)

        for row_within_pool in range(pool_shp):
            row_stop = last_pool_r + row_within_pool + 1
            for col_within_pool in range(pool_shp):
                col_stop = last_pool_c + col_within_pool + 1
                cur = c01b[
                    :,
                    row_within_pool:row_stop:pool_stride,
                    col_within_pool:col_stop:pool_stride,
                    :,
                ]
                if mx is None:
                    mx = cur
                else:
                    mx = maximum(mx, cur)
        return mx

    def test_broadcasted_dims(self):
        # This tests a case that caused a crash during rewriting
        shp = (1, 1, 1, 1)
        rng = np.random.default_rng(utt.fetch_seed())
        a = shared(rng.random(shp).astype(config.floatX))
        out = self.max_pool_c01b(a, 1, 1, 1)

        # max_pool_c01b use -inf and this will trigger DebugMode error.
        mode = copy.copy(get_default_mode())
        mode.check_isfinite = False
        f = function([], out, mode=mode)
        f()

    def test_constant_merge(self):
        # This tests the error in gh-1122 that is caused by the
        # combination of merge rewriter and ShapeFeature.
        x = at.constant([0, 0])
        y = x[1:]
        x1 = x - at.join(0, y, y)
        x1.eval()

    def test_local_track_shape_i(self):
        """Shapes of `Op`s without an `Op.infer_shape` should still be tracked.

        An `Op` lacking `infer_shape` is replaced (via a registered rewrite)
        by one that has it, after which the shape graph no longer needs either
        `Op` instance.
        """

        class IdentityNoShape(Op):
            """Op that does not infer the output shape from the input one"""

            def make_node(self, x):
                x = as_tensor_variable(x)
                return Apply(self, [x], [x.type()])

            def perform(self, node, inp, out_):
                (x,) = inp
                (out,) = out_
                out[0] = x.copy()

            # def infer_shape(self, fgraph, node, (xshp,)):
            # return [tuple([self.shape_i(i)(r) for i in range(r.ndim)])]

        identity_noshape = IdentityNoShape()

        class IdentityShape(Op):
            """Op that does infer the output shape from the input one"""

            def make_node(self, x):
                x = as_tensor_variable(x)
                return Apply(self, [x], [x.type()])

            def perform(self, node, inp, out_):
                (x,) = inp
                (out,) = out_
                out[0] = x.copy()

            def infer_shape(self, fgraph, node, xshp_):
                # Could also just return.
                (xshp,) = xshp_
                return (xshp,)

        identity_shape = IdentityShape()

        @node_rewriter([IdentityNoShape])
        def local_identity_noshape_to_identity_shape(fgraph, node):
            """Transform the first `Op` into the second."""
            if isinstance(node.op, IdentityNoShape):
                return [identity_shape(node.inputs[0])]

        mode = get_default_mode().including("ShapeOpt", "specialize")
        rng = np.random.default_rng(utt.fetch_seed())
        x = tensor3("x")
        ins_x = identity_noshape(x)

        # Without the rewrite
        f = function([x], ins_x.shape, mode=mode)
        xval = rng.standard_normal((3, 4, 7)).astype(config.floatX)
        assert np.all(f(xval) == [3, 4, 7])
        f_ops = [node.op for node in f.maker.fgraph.toposort()]
        # `identity_noshape` must still be evaluated just to get the shape.
        assert len(f_ops) == 5
        assert identity_noshape in f_ops
        assert identity_shape not in f_ops

        # Register the rewrite
        register_specialize(local_identity_noshape_to_identity_shape)

        mode = get_default_mode().including("ShapeOpt", "specialize")
        # The `identity_shape` op should not be needed anymore to compute
        # the shape
        g = function([x], ins_x.shape, mode=mode)
        xval = rng.standard_normal((6, 1, 2)).astype(config.floatX)
        assert np.all(g(xval) == [6, 1, 2])
        g_ops = [node.op for node in g.maker.fgraph.toposort()]
        assert len(g_ops) == 4
        assert identity_noshape not in g_ops
        assert identity_shape not in g_ops

        # Test multiple applications of an `Op` without an `Op.infer_shape`
        ins_x3 = identity_noshape(identity_noshape(identity_noshape(x)))
        h = function([x], ins_x3.shape, mode=mode)
        xval = rng.standard_normal((6, 1, 2)).astype(config.floatX)
        assert np.all(h(xval) == [6, 1, 2])
        h_ops = [node.op for node in h.maker.fgraph.toposort()]
        assert len(h_ops) == 4
        assert identity_noshape not in h_ops
        assert identity_shape not in h_ops

    def test_no_shapeopt(self):
        """Test that a basic example works even when `ShapeOpt` is excluded."""
        X = matrix()
        expr = X.shape[0]

        mode = get_default_mode().excluding("ShapeOpt")
        f = function([X], expr, mode=mode)
        # FIXME: This is not a good test.
        f([[1, 2], [2, 3]])
class TestReshape:
    """Tests for rewrites that combine `Reshape` `Op`s."""

    def setup_method(self):
        self.mode = rewrite_mode
        self.op = Reshape

    def test_local_reshape(self):
        """Consecutive `Reshape`s should be merged into a single `Reshape`."""
        a = fmatrix()
        b = self.op(3)(a, [2, 3, 4])
        c = self.op(1)(b, [24])
        f = function([a], c, mode=self.mode)
        topo = f.maker.fgraph.toposort()
        # Exactly one `Reshape` remains after rewriting.
        assert sum(isinstance(node.op, self.op) for node in topo) == 1

        # Check stack trace
        assert check_stack_trace(f, ops_to_check=[self.op])
class TestLocalUselessReshape:
    """Tests for `local_useless_reshape`: `Reshape`s that cannot change the
    shape of their input should be removed."""

    def setup_method(self):
        self.rng = np.random.default_rng(utt.fetch_seed())

    def test_0(self):
        """A `Reshape` produced by `mgrid` on a 1-D range should be removed."""
        mode = get_default_mode().including("local_useless_reshape")
        i = iscalar("i")
        m = at.mgrid[0:i,]
        f = function([i], m, mode=mode)
        topo = f.maker.fgraph.toposort()
        assert not any(isinstance(n.op, Reshape) for n in topo)

    def test_1(self):
        """``x.reshape(x.shape)`` is a no-op and should be removed."""
        x = matrix("x")
        r = x.reshape(x.shape)

        m0 = get_default_mode()
        m1 = m0.including("local_useless_reshape")
        f1 = function([x], r, mode=m1)
        topo = f1.maker.fgraph.toposort()
        assert not any(isinstance(n.op, Reshape) for n in topo)

        # The rewrite should also apply without the `ShapeFeature` installed.
        m2 = m1.excluding("ShapeOpt")
        f2 = function([x], r, mode=m2)
        topo = f2.maker.fgraph.toposort()
        assert not any(isinstance(n.op, Reshape) for n in topo)

        # We do not need tests checking that stack traces are copied over,
        # because local_useless_reshape only removes nodes from the graph

    def test_2(self):
        """Reshaping to the per-dimension `Shape_i`s of the input is a no-op."""
        x = matrix("x")
        r = x.reshape([Shape_i(i)(x) for i in range(x.ndim)])

        m0 = get_default_mode()
        m1 = m0.including("local_useless_reshape")
        f1 = function([x], r, mode=m1)
        topo = f1.maker.fgraph.toposort()
        assert not any(isinstance(n.op, Reshape) for n in topo)

        m2 = m1.excluding("ShapeOpt")
        f2 = function([x], r, mode=m2)
        topo = f2.maker.fgraph.toposort()
        assert not any(isinstance(n.op, Reshape) for n in topo)

    def test_m1(self):
        """``x.reshape((x.shape[0], -1))`` on a matrix is a no-op."""
        x = matrix("x")
        r = x.reshape((x.shape[0], -1))

        m0 = get_default_mode()
        m1 = m0.including("local_useless_reshape")
        f1 = function([x], r, mode=m1)
        topo = f1.maker.fgraph.toposort()
        assert not any(isinstance(n.op, Reshape) for n in topo)

        m2 = m1.excluding("ShapeOpt")
        f2 = function([x], r, mode=m2)
        topo = f2.maker.fgraph.toposort()
        assert not any(isinstance(n.op, Reshape) for n in topo)
class TestLocalReshapeToDimshuffle:
    """Tests for `local_reshape_to_dimshuffle`: `Reshape`s that only insert
    broadcastable dimensions should become `DimShuffle`s."""

    def setup_method(self):
        self.rng = np.random.default_rng(utt.fetch_seed())

    def test_1(self):
        reshape_lift = out2in(local_reshape_to_dimshuffle)
        useless_reshape = out2in(local_useless_reshape)
        x = shared(self.rng.standard_normal((4,)))
        y = shared(self.rng.standard_normal((5, 6)))
        reshape_x = reshape(x, (1, 4))
        reshape_y = reshape(y, (1, 5, 1, 6, 1, 1))

        g = FunctionGraph([x, y], [reshape_x, reshape_y])

        # Before rewriting, both outputs are plain `Reshape`s.
        assert str(g) == (
            "FunctionGraph(Reshape{2}"
            "(<TensorType(float64, (None,))>, "
            "TensorConstant{[1 4]}), "
            "Reshape{6}"
            "(<TensorType(float64, (None, None))>, "
            "TensorConstant{[1 5 1 6 1 1]}))"
        )

        reshape_lift.rewrite(g)
        useless_reshape.rewrite(g)

        # Afterwards, the dimension-inserting parts are `DimShuffle`s; only
        # the genuinely shape-changing `Reshape{2}` of `y` remains.
        assert str(g) == (
            "FunctionGraph(InplaceDimShuffle{x,0}"
            "(<TensorType(float64, (None,))>), "
            "InplaceDimShuffle{x,0,x,1,x,x}"
            "(Reshape{2}(<TensorType(float64, (None, None))>, "
            "TensorConstant{[5 6]})))"
        )

        # Check stacktrace was copied over correctly after the rewrite was applied
        assert check_stack_trace(g, ops_to_check=(DimShuffle, Reshape))
def test_local_reshape_lift():
    """`local_reshape_lift` should move an `Elemwise` above a `Reshape`.

    I.e. ``exp(x).reshape(...)`` becomes ``exp(x.reshape(...))``, so the
    `Reshape` ends up before the `Elemwise` in the topological order.
    """
    x = tensor4()
    out = exp(x).reshape([x.size])
    assert out.ndim == 1
    mode = get_default_mode()
    mode = mode.including("local_reshape_lift")
    f = function([x], out, mode=mode)
    f(np.random.random((5, 4, 3, 2)).astype(config.floatX))
    topo = f.maker.fgraph.toposort()
    # The `Reshape` now precedes the `Elemwise`.
    assert isinstance(topo[-2].op, Reshape)
    assert isinstance(topo[-1].op, Elemwise)
    assert check_stack_trace(f, ops_to_check="last")
class TestShapeI(utt.InferShapeTester):
    """Tests for the `Shape_i` `Op` itself (evaluation and shape inference)."""

    def setup_method(self):
        super().setup_method()

    def test_perform(self):
        """`Shape_i(i)` should evaluate to the i-th element of the input's shape."""
        rng = np.random.default_rng(utt.fetch_seed())

        advec = vector()
        advec_val = rng.random((3)).astype(config.floatX)
        f = function([advec], Shape_i(0)(advec))
        out = f(advec_val)
        utt.assert_allclose(out, advec_val.shape[0])

        admat = matrix()
        admat_val = rng.random((4, 3)).astype(config.floatX)
        for i in range(2):
            f = function([admat], Shape_i(i)(admat))
            out = f(admat_val)
            utt.assert_allclose(out, admat_val.shape[i])

    def test_infer_shape(self):
        admat = matrix()
        admat_val = np.random.random((3, 4)).astype(config.floatX)
        self._compile_and_check([admat], [Shape_i(0)(admat)], [admat_val], Shape_i)

        self._compile_and_check([admat], [Shape_i(1)(admat)], [admat_val], Shape_i)
class TestSameShape:
    """Tests for `ShapeFeature.same_shape`."""

    def test_scalar(self):
        """Adding a constant to a scalar preserves the (empty) shape."""
        x = scalar()
        cst = at.constant(1)
        o = x + cst
        fgraph = FunctionGraph([x], [o], clone=False)
        shape_feature = ShapeFeature()
        fgraph.attach_feature(shape_feature)
        assert shape_feature.same_shape(x, o)

    def test_vector(self):
        """Adding a (broadcasted) constant to a vector preserves its shape."""
        x = vector()
        cst = at.constant(1)
        o = x + cst
        fgraph = FunctionGraph([x], [o], clone=False)
        shape_feature = ShapeFeature()
        fgraph.attach_feature(shape_feature)
        assert shape_feature.same_shape(x, o)

    def test_no_static_shapes(self):
        x = vector()
        y = vector()
        o = x + y
        fgraph = FunctionGraph([x, y], [o], clone=False)
        shape_feature = ShapeFeature()
        fgraph.attach_feature(shape_feature)
        # We no longer assume that `x` has the same shape as `y` simply because
        # neither has static shape information. Instead, when no static shape
        # information is available, we assume that `x` and/or `y` could have
        # shapes `(1,)` and/or `(n,)`, where `n != 1`, or any combination of
        # the two.
        assert not shape_feature.same_shape(x, o)
        # The following case isn't implemented
        assert not shape_feature.same_shape(y, o)

    @pytest.mark.parametrize(
        "y_dim_0",
        [2, pytest.param(None, marks=pytest.mark.xfail(reason="Not implemented"))],
    )
    def test_vector_dim(self, y_dim_0):
        """`same_shape` comparisons on individual dimensions."""
        x = at.tensor(dtype="floatX", shape=(2, None))
        y = at.tensor(dtype="floatX", shape=(y_dim_0, None))
        o = x + y
        fgraph = FunctionGraph([x, y], [o], clone=False)
        shape_feature = ShapeFeature()
        fgraph.attach_feature(shape_feature)
        # Dimension 0 is statically known (2) for both; dimension 1 is not.
        assert shape_feature.same_shape(x, o, 0, 0)
        assert not shape_feature.same_shape(x, o, 1, 1)

    def test_vector_dim_err(self):
        """Out-of-range dimension indices should raise `IndexError`."""
        x = vector()
        y = vector()
        o = x + y
        fgraph = FunctionGraph([x, y], [o], clone=False)
        shape_feature = ShapeFeature()
        fgraph.attach_feature(shape_feature)
        with pytest.raises(IndexError):
            shape_feature.same_shape(x, o, 1, 0)
        with pytest.raises(IndexError):
            shape_feature.same_shape(x, o, 0, 1)
@pytest.mark.parametrize(
    "shape",
    [lscalar(), iscalar()],
)
def test_local_Shape_of_SpecifyShape(shape):
    """The shape of a `SpecifyShape` should reduce to the specified shape.

    After rewriting, the graph computing ``specify_shape(x, shape).shape``
    should reference the given ``shape`` variable directly and drop ``x``.
    """
    vec = vector()
    shape_graph = specify_shape(vec, shape).shape

    fgraph = FunctionGraph(outputs=[shape_graph], clone=False)
    rewrite_graph(fgraph, clone=False)

    # The input vector is no longer needed...
    assert vec not in fgraph.variables
    # ...because the specified shape is used directly.
    assert shape in fgraph.variables
@pytest.mark.parametrize(
    "s1",
    [lscalar(), iscalar()],
)
def test_local_Shape_of_SpecifyShape_partial(s1):
    """A partially specified `SpecifyShape` still has its shape graph rewritten.

    Only dimension 0 is specified (as ``s1``); the rewrite must keep both the
    input (needed for the unspecified dimension) and ``s1``, while removing
    the `SpecifyShape` node itself.
    """
    mat = matrix()
    shape_graph = specify_shape(mat, (s1, None)).shape

    fgraph = FunctionGraph(outputs=[shape_graph], clone=False)

    def contains_specify_shape():
        return any(isinstance(node.op, SpecifyShape) for node in fgraph.apply_nodes)

    # Sanity check: the un-rewritten graph contains a `SpecifyShape`.
    assert contains_specify_shape()

    rewrite_graph(fgraph, clone=False)

    assert mat in fgraph.variables
    assert s1 in fgraph.variables
    assert not contains_specify_shape()
def test_local_Shape_i_of_broadcastable():
    """`Shape_i` on a known-broadcastable dimension should fold to constant 1."""
    x = tensor(np.float64, [False, True])
    s = Shape_i(1)(x)

    fgraph = FunctionGraph(outputs=[s], clone=False)
    _ = rewrite_graph(fgraph, clone=False)

    # `x` is no longer needed: dimension 1 is broadcastable, so its length is
    # the constant 1.
    assert x not in fgraph.variables
    assert fgraph.outputs[0].data == 1

    # A test for a non-`TensorType`
    class MyType(Type):
        # Minimal stand-in type; only `ndim` and equality are relevant here.
        ndim = 1

        def filter(self, *args, **kwargs):
            raise NotImplementedError()

        def __eq__(self, other):
            return isinstance(other, MyType) and other.thingy == self.thingy

    class MyVariable(Variable):
        pass

    x = MyVariable(MyType(), None, None)
    s = Shape_i(0)(x)
    fgraph = FunctionGraph(outputs=[s], clone=False)
    _ = rewrite_graph(fgraph, clone=False)

    # For a non-`TensorType` variable, the rewrite must leave the graph alone.
    assert fgraph.outputs[0] == s
def test_Shape_i_canonicalize():
    """Make sure the canonicalizations work together to produce the correct graphs for shapes in a single dimension.

    In other words, ``shape(x)[i]`` should result in a simple ``Shape_i(0)(x)``
    and nothing else. The rewrites `local_shape_to_shape_i`,
    `local_subtensor_remove_broadcastable_index`, and
    `local_useless_dimshuffle_makevector` need to work together to accomplish
    this, and we confirm that here.
    """
    x = vector()
    y = shape(x)[0]

    y_fg = FunctionGraph(outputs=[y], copy_inputs=False, features=[ShapeFeature()])

    y_rewritten_fg = rewrite_graph(
        y_fg,
        clone=False,
        include=[
            "canonicalize",
        ],
    )

    y_rewritten = y_rewritten_fg.outputs[0]

    # The result is a single `Shape_i(0)` applied directly to `x`.
    assert isinstance(y_rewritten.owner.op, Shape_i)
    assert y_rewritten.owner.op.i == 0
    assert y_rewritten.owner.inputs[0] == x
tests/tensor/test_elemwise.py
浏览文件 @
63f52536
...
...
@@ -18,9 +18,9 @@ from aesara.link.c.basic import CLinker, OpWiseCLinker
from
aesara.tensor
import
as_tensor_variable
from
aesara.tensor.basic
import
second
from
aesara.tensor.elemwise
import
CAReduce
,
CAReduceDtype
,
DimShuffle
,
Elemwise
from
aesara.tensor.exceptions
import
ShapeError
from
aesara.tensor.math
import
all
as
at_all
from
aesara.tensor.math
import
any
as
at_any
from
aesara.tensor.rewriting.basic
import
ShapeError
from
aesara.tensor.type
import
(
TensorType
,
bmatrix
,
...
...
tests/tensor/test_shape.py
浏览文件 @
63f52536
...
...
@@ -12,7 +12,7 @@ from aesara.misc.safe_asarray import _asarray
from
aesara.tensor
import
as_tensor_variable
,
get_vector_length
,
row
from
aesara.tensor.basic
import
MakeVector
,
constant
from
aesara.tensor.elemwise
import
DimShuffle
,
Elemwise
from
aesara.tensor.rewriting.
basic
import
ShapeFeature
from
aesara.tensor.rewriting.
shape
import
ShapeFeature
from
aesara.tensor.shape
import
(
Reshape
,
Shape_i
,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论