testgroup / pytensor — Commits — 13ebc731

Commit 13ebc731, authored Sep 09, 2021 by Neel Iyer; committed Sep 21, 2021 by Brandon T. Willard.

Move subtensor rewrites from basic opt to subtensor opt

Parent: c42b56ab

Showing 2 changed files with 1384 additions and 1350 deletions:

    aesara/tensor/basic_opt.py        +3     -1347
    aesara/tensor/subtensor_opt.py    +1381  -3
aesara/tensor/basic_opt.py
...
@@ -30,7 +29,6 @@ from aesara.graph.op import get_test_value
 from aesara.graph.opt import (
     GlobalOptimizer,
     OpRemove,
-    TopoOptimizer,
     check_chain,
     copy_stack_trace,
     in2out,
...
@@ -46,7 +45,6 @@ from aesara.printing import pprint
 from aesara.tensor.basic import (
     Alloc,
     AllocEmpty,
-    ARange,
     Flatten,
     Join,
     MakeVector,
...
@@ -76,37 +74,11 @@ from aesara.tensor.basic import (
 from aesara.tensor.elemwise import DimShuffle, Elemwise
 from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
 from aesara.tensor.extra_ops import broadcast_shape
-from aesara.tensor.math import Dot, add
-from aesara.tensor.math import all as aet_all
-from aesara.tensor.math import (
-    and_,
-    ceil_intdiv,
-    dot,
-    eq,
-    ge,
-    gt,
-    le,
-    lt,
-    maximum,
-    minimum,
-    or_,
-)
+from aesara.tensor.math import eq
 from aesara.tensor.shape import Reshape, Shape, Shape_i, shape, shape_padleft
 from aesara.tensor.sort import TopKOp
-from aesara.tensor.subtensor import (
-    AdvancedIncSubtensor,
-    AdvancedIncSubtensor1,
-    AdvancedSubtensor1,
-    IncSubtensor,
-    Subtensor,
-    advanced_inc_subtensor1,
-    advanced_subtensor,
-    advanced_subtensor1,
-    as_index_constant,
-    get_canonical_form_slice,
-    get_idx_list,
-)
-from aesara.tensor.type import TensorType, discrete_dtypes, integer_dtypes, lscalar
+from aesara.tensor.subtensor import Subtensor, get_idx_list
+from aesara.tensor.type import discrete_dtypes, integer_dtypes, lscalar
 from aesara.tensor.var import TensorConstant
 from aesara.utils import NoDuplicateOptWarningFilter
...
@@ -1947,174 +1919,6 @@ def local_track_shape_i(fgraph, node):
        return [shape_feature.shape_of[replacement][node.op.i]]

The removed lines below are the subtensor rewrites this commit moves into aesara/tensor/subtensor_opt.py:

@register_specialize
@register_canonicalize
@local_optimizer([Subtensor])
def local_subtensor_inc_subtensor(fgraph, node):
    """
    Subtensor(SetSubtensor(x, y, idx), idx) -> y
    """
    if isinstance(node.op, Subtensor):
        x = node.inputs[0]
        if not x.owner or not isinstance(x.owner.op, IncSubtensor):
            return
        if not x.owner.op.set_instead_of_inc:
            return

        if x.owner.inputs[2:] == node.inputs[1:] and tuple(
            x.owner.op.idx_list
        ) == tuple(node.op.idx_list):
            out = node.outputs[0]
            y = x.owner.inputs[1]
            # If the dtypes differ, cast y into x.dtype
            if x.dtype != y.dtype:
                y = y.astype(x.dtype)
            if out.type == y.type:
                # if x[idx] and y have the same type, directly return y
                return [y]
            else:
                # The difference is related to broadcasting pattern
                assert out.broadcastable != y.broadcastable
                # We have to alloc y to the shape of x[idx]
                x_subtensor = node.op(x.owner.inputs[0], *x.owner.inputs[2:])
                return [alloc(y, *x_subtensor.shape)]
        else:
            return
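The pattern this collapses is "write y into a region, then read the same region back". A minimal sketch, not part of this commit, assuming the aesara API of this era (aet.set_subtensor builds an IncSubtensor node with set_instead_of_inc=True):

import aesara
import aesara.tensor as aet

x = aet.matrix("x")
y = aet.vector("y")

# Subtensor(SetSubtensor(x, y, 0), 0): write y into row 0, then read row 0 back.
z = aet.set_subtensor(x[0], y)[0]

# After this rewrite the compiled graph can return y directly (possibly cast
# to x.dtype), skipping both the write and the read.
f = aesara.function([x, y], z)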
@register_specialize
@register_canonicalize
@local_optimizer([Subtensor])
def local_subtensor_remove_broadcastable_index(fgraph, node):
    """
    Remove broadcastable dimension with index 0 or -1
    a[:,:,:,0] -> a.dimshuffle(0,1,2), when
        a.broadcastable = (False, False, False, True)
    a[0,:,-1,:] -> a.dimshuffle(1,3), when
        a.broadcastable = (True, False, True, False)
    """
    if isinstance(node.op, Subtensor):
        idx = node.op.idx_list
    else:
        return

    remove_dim = []
    node_inputs_idx = 1
    for dim, elem in enumerate(idx):
        if isinstance(elem, (aes.Scalar)):
            # The idx is a Scalar, ie a Type. This means the actual index
            # is contained in node.inputs[1]
            dim_index = node.inputs[node_inputs_idx]
            if type(dim_index) == aes.ScalarConstant:
                dim_index = dim_index.value
            if dim_index in [0, -1] and node.inputs[0].broadcastable[dim]:
                remove_dim.append(dim)
                node_inputs_idx += 1
            else:
                return
        elif isinstance(elem, slice):
            if elem != slice(None):
                return
        elif isinstance(elem, (int, np.integer)):
            if elem in [0, -1] and node.inputs[0].broadcastable[dim]:
                remove_dim.append(dim)
        else:
            raise TypeError("case not expected")

    if len(remove_dim) == 0:
        return
    else:
        all_dim = range(node.inputs[0].ndim)
        remain_dim = [x for x in all_dim if x not in remove_dim]
        return [node.inputs[0].dimshuffle(tuple(remain_dim))]
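In NumPy terms the transformation is the observation that indexing a known length-1 (broadcastable) axis at 0 or -1 is a pure dimension drop, which a DimShuffle can express without any indexing. A sketch:

import numpy as np

a = np.random.rand(4, 5, 1)  # last axis has length 1, i.e. "broadcastable"

# a[:, :, 0] selects the only element of the last axis: a dimension drop.
assert np.array_equal(a[:, :, 0], np.squeeze(a, axis=2))
assert np.array_equal(a[:, :, 0], a[:, :, -1])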
@register_specialize
@register_canonicalize("fast_compile_gpu")
@register_useless
@local_optimizer([Subtensor, AdvancedSubtensor1])
def local_subtensor_make_vector(fgraph, node):
    """
    Replace all subtensor(make_vector) like:
    [a,b,c][0] -> a
    [a,b,c][0:2] -> [a,b]

    Replace all AdvancedSubtensor1(make_vector) like:
    [a,b,c][[0,2]] -> [a,c]

    We can do this for constant indexes.
    """
    x = node.inputs[0]
    if not x.owner or x.owner.op != make_vector:
        return

    if isinstance(node.op, Subtensor):
        # This optimization needs ShapeOpt and fgraph.shape_feature
        try:
            (idx,) = node.op.idx_list
        except Exception:
            # 'how can you have multiple indexes into a shape?'
            raise

        if isinstance(idx, (aes.Scalar, TensorType)):
            # The idx is a Scalar, ie a Type. This means the actual index
            # is contained in node.inputs[1]
            old_idx, idx = idx, node.inputs[1]
            assert idx.type == old_idx
    elif isinstance(node.op, AdvancedSubtensor1):
        idx = node.inputs[1]
    else:
        return

    if isinstance(idx, (int, np.integer)):
        # We don't need to copy over any stack traces here
        return [x.owner.inputs[idx]]
    elif isinstance(idx, Variable):
        if idx.ndim == 0:
            # if it is a constant we can do something with it
            try:
                v = get_scalar_constant_value(idx, only_process_constants=True)
                if isinstance(v, np.integer):
                    # Python 2.4 wants to index only with Python integers
                    v = int(v)
                # We don't need to copy over any stack traces here
                try:
                    ret = [x.owner.inputs[v]]
                except IndexError:
                    raise NotScalarConstantError("Bad user graph!")
                return ret
            except NotScalarConstantError:
                pass
        elif idx.ndim == 1 and isinstance(idx, Constant):
            values = list(map(int, list(idx.value)))
            ret = make_vector(*[x.owner.inputs[v] for v in values])

            # Copy over stack trace from previous output to new output
            copy_stack_trace(node.outputs[0], ret)
            ret = patternbroadcast(ret, node.outputs[0].broadcastable)
            return [ret]
        else:
            raise TypeError("case not expected")
    elif isinstance(idx, slice):
        # it is a slice of ints and/or Variables
        # check subtensor to see if it can contain constant variables, and if
        # it can, then try to unpack them.
        try:
            const_slice = node.op.get_constant_idx(
                node.inputs, allow_partial=False
            )[0]
            ret = make_vector(*x.owner.inputs[const_slice])
            # Copy over stack trace from previous outputs to new output
            copy_stack_trace(node.outputs, ret)
            ret = patternbroadcast(ret, node.outputs[0].broadcastable)
            return [ret]
        except NotScalarConstantError:
            pass
    else:
        raise TypeError("case not expected")


# TODO: the other optimization for and, or, xor, le and ge see ticket #496.
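A sketch of what this buys, assuming (as in this era of the codebase) that aet.stack over scalars builds a MakeVector node; the variable names are illustrative only:

import aesara
import aesara.tensor as aet

a, b, c = aet.scalars("a", "b", "c")
v = aet.stack([a, b, c])  # MakeVector(a, b, c), under the stated assumption

# Each constant index can be resolved without materializing the vector:
out0 = v[1]    # -> b
out1 = v[0:2]  # -> MakeVector(a, b)
f = aesara.function([a, b, c], [out0, out1])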
...
@@ -2468,1154 +2272,6 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
    return rval

The remainder of this hunk removes the whole "Subtensor opts" section, likewise moved into aesara/tensor/subtensor_opt.py:

##################
# Subtensor opts #
##################


@register_useless
@register_canonicalize
@register_specialize
@local_optimizer([IncSubtensor])
def local_useless_inc_subtensor(fgraph, node):
    """
    Remove IncSubtensor, when we overwrite the full inputs with the
    new value.

    """
    if not isinstance(node.op, IncSubtensor):
        return

    if node.op.set_instead_of_inc is False:
        # This is an IncSubtensor, so the init value must be zeros
        try:
            c = get_scalar_constant_value(
                node.inputs[0], only_process_constants=True
            )
            if c != 0:
                return
        except NotScalarConstantError:
            return

    if (
        node.inputs[0].ndim != node.inputs[1].ndim
        or node.inputs[0].broadcastable != node.inputs[1].broadcastable
    ):
        # FB: I didn't check if this case can happen, but this opt
        # don't support it.
        return

    # We have a SetSubtensor or an IncSubtensor on zeros
    # If is this IncSubtensor useful?

    # Check that we keep all the original data.
    # Put the constant inputs in the slice.
    idx_cst = get_idx_list(node.inputs[1:], node.op.idx_list)
    if all(
        isinstance(e, slice)
        and e.start is None
        and e.stop is None
        and (
            e.step is None
            or extract_constant(e.step, only_process_constants=True) == -1
        )
        for e in idx_cst
    ):
        # IncSubtensor broadcast node.inputs[1] on node.inputs[0]
        # based on run time shapes, so we must check they are the same.
        if not hasattr(fgraph, "shape_feature"):
            return
        if not fgraph.shape_feature.same_shape(node.inputs[0], node.inputs[1]):
            return

        # There is no reverse, so we don't need a replacement.
        if all(e.step is None for e in node.op.idx_list):
            # They are the same shape, so we can remove this IncSubtensor
            return [node.inputs[1]]
        ret = Subtensor(node.op.idx_list)(*node.inputs[1:])
        # Copy over previous output stacktrace
        copy_stack_trace(node.outputs, ret)
        return [ret]
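The NumPy analogue of the removed pattern: a set over the full extent leaves no trace of the original buffer, so the graph can use the new value directly. A sketch:

import numpy as np

x = np.zeros((3, 4))
y = np.random.rand(3, 4)

x[:] = y               # SetSubtensor over the full input ...
assert (x == y).all()  # ... is just y, when the shapes match at run time
                       # (which the rewrite checks via fgraph.shape_feature).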
@register_canonicalize
@register_specialize
@local_optimizer([AdvancedIncSubtensor1])
def local_set_to_inc_subtensor(fgraph, node):
    r"""
    AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) ->
    AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False)

    TODO FIXME: Why doesn't this apply to all `*IncSubtensor*` `Op`\s? If it
    did this wouldn't need to also be included in the "specialize" pass.

    """
    if (
        isinstance(node.op, AdvancedIncSubtensor1)
        and node.op.set_instead_of_inc
        and node.inputs[1].owner
        and isinstance(node.inputs[1].owner.op, Elemwise)
        and isinstance(node.inputs[1].owner.op.scalar_op, aes.Add)
    ):
        addn = node.inputs[1].owner
        subn = None
        other = None

        if addn.inputs[0].owner and isinstance(
            addn.inputs[0].owner.op, AdvancedSubtensor1
        ):
            subn = addn.inputs[0].owner
            other = addn.inputs[1]
        elif addn.inputs[1].owner and isinstance(
            addn.inputs[1].owner.op, AdvancedSubtensor1
        ):
            subn = addn.inputs[1].owner
            other = addn.inputs[0]
        else:
            return
        if subn.inputs[1] != node.inputs[2] or subn.inputs[0] != node.inputs[0]:
            return
        ret = advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2])

        copy_stack_trace(node.outputs, ret)

        return [ret]
@register_useless
@register_canonicalize
@register_specialize
@local_optimizer([Subtensor])
def local_useless_slice(fgraph, node):
    """
    Remove Subtensor of the form X[0, :] -> X[0]
    """
    if isinstance(node.op, Subtensor):
        slices = get_idx_list(node.inputs, node.op.idx_list)
        last_slice = len(slices)
        for s in slices[::-1]:
            # check if slice and then check slice indices
            if (
                isinstance(s, slice)
                and s.start is None
                and s.stop is None
                and (
                    s.step is None
                    or extract_constant(s.step, only_process_constants=True) == 1
                )
            ):
                last_slice -= 1
            else:
                break
        # check if we removed something
        if last_slice < len(slices):
            subtens = Subtensor(slices[:last_slice])
            sl_ins = Subtensor.collapse(
                slices[:last_slice], lambda x: isinstance(x, Variable)
            )
            out = subtens(node.inputs[0], *sl_ins)
            # Copy over previous output stacktrace
            copy_stack_trace(node.outputs, out)
            return [out]
@register_canonicalize
@register_specialize
@local_optimizer([Subtensor, AdvancedSubtensor1])
def local_useless_subtensor(fgraph, node):
    """
    Remove Subtensor/AdvancedSubtensor1 if it takes the full input. In the
    AdvancedSubtensor1 case, the full input is taken when the indices are
    equivalent to `arange(0, input.shape[0], 1)` using either an explicit
    list/vector or the ARange op.

    """
    # This optimization needs ShapeOpt and fgraph.shape_feature
    if not hasattr(fgraph, "shape_feature"):
        return

    shape_of = fgraph.shape_feature.shape_of

    if isinstance(node.op, Subtensor):
        cdata = node.op.get_constant_idx(
            node.inputs, allow_partial=True, only_process_constants=True
        )
        for pos, idx in enumerate(cdata):
            if not isinstance(idx, slice):
                # If idx is not a slice, this means we remove this dimension
                # from the output, so the subtensor is not useless
                return False
            if idx.start is not None and idx.start != 0:
                # If the start of the slice is different from 0, or is a
                # variable, then we assume the subtensor is not useless
                return False
            if idx.step is not None and idx.step != 1:
                # If we are going backwards, or skipping elements, then this
                # is not a useless subtensor
                return False

        for pos, idx in enumerate(cdata):

            length_pos = shape_of[node.inputs[0]][pos]

            if isinstance(idx.stop, (int, np.integer)):
                length_pos_data = sys.maxsize
                try:
                    length_pos_data = get_scalar_constant_value(
                        length_pos, only_process_constants=True
                    )
                except NotScalarConstantError:
                    pass

                if idx.stop < length_pos_data:
                    return False
            elif isinstance(idx.stop, Variable):
                length_pos_shape_i = idx.stop
                # length_pos is a tensor variable, but length_pos_shape_i
                # is a scalar variable. We try to see if they represent
                # the same underlying variable.
                if length_pos_shape_i.owner and isinstance(
                    length_pos_shape_i.owner.op, ScalarFromTensor
                ):
                    length_pos_shape_i = length_pos_shape_i.owner.inputs[0]
                elif length_pos.owner and isinstance(
                    length_pos.owner.op, TensorFromScalar
                ):
                    length_pos = length_pos.owner.inputs[0]
                else:
                    # We did not find underlying variables of the same type
                    return False

                # The type can be different: int32 vs int64. length_pos
                # should always be int64 as that is what the shape
                # tracker keep. Subtensor accept any scalar int{8,16,32,64}
                # as index type.
                assert str(length_pos.type.dtype) == "int64"
                assert str(length_pos_shape_i.type.dtype) in [
                    "int8",
                    "int16",
                    "int32",
                    "int64",
                ]

                # length_pos_shape_i cannot be None
                if length_pos_shape_i != length_pos:
                    return False
            elif idx.stop is None:
                pass
            else:
                return False
    elif isinstance(node.op, AdvancedSubtensor1):
        # get length of the indexed tensor along the first axis
        try:
            length = get_scalar_constant_value(
                shape_of[node.inputs[0]][0], only_process_constants=True
            )
        except NotScalarConstantError:
            return False

        # get index (which must be a vector by definition)
        idx = node.inputs[1]

        # `idx` must be equivalent to [0,1,...,shape[0] - 1] to qualify for
        # this optimization
        if isinstance(idx, Constant):
            idx = idx.value
            if len(idx) != length:
                return False
            if np.any(idx != np.arange(length)):
                return False
        elif idx.owner is not None and isinstance(idx.owner.op, ARange):
            try:
                start, stop, step = map(
                    lambda x: get_scalar_constant_value(
                        x, only_process_constants=True
                    ),
                    idx.owner.inputs,
                )
            except NotScalarConstantError:
                return False

            if start != 0:
                return False
            if stop != length:
                return False
            if step != 1:
                return False
        else:
            return False
    else:
        return False

    # We don't need to copy over any stacktrace here,
    # because previous stacktrace should suffice.
    return [node.inputs[0]]
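Both branches encode identities that are easy to check numerically; a NumPy sketch:

import numpy as np

x = np.random.rand(5, 3)

# Subtensor branch: a slice that covers the whole dimension is a no-op.
assert np.array_equal(x[0:5], x)
assert np.array_equal(x[:, 0 : x.shape[1]], x)

# AdvancedSubtensor1 branch: indexing with arange(len(x)) is also a no-op.
assert np.array_equal(x[np.arange(5)], x)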
# fast_compile to allow opt subtensor(cast{float32}(make_vector))
@register_canonicalize("fast_compile")
@local_optimizer([Subtensor])
def local_subtensor_lift(fgraph, node):
    """
    unary(x)[idx] -> unary(x[idx])#any broadcast pattern.

    Handles the following unary ops:
    elemwise(x,...)[idx] -> elemwise(x[idx],...)
      when x,... are broadcasted scalar or not broadcasted at all
    rebroadcast(x)[idx] => rebroadcast(x[idx])

    """
    if isinstance(node.op, Subtensor):
        u = node.inputs[0]
        if not u.owner or len(fgraph.clients[u]) > 1:
            return False

        if isinstance(u.owner.op, Elemwise) and len(u.owner.inputs) == 1:
            idx = node.inputs[1:]
            x_idx = node.op(u.owner.inputs[0], *idx)
            # Copy over previous output stacktrace
            copy_stack_trace(node.outputs, x_idx)
            ret = u.owner.op(x_idx)
            # Copy over previous output stacktrace
            # and stacktrace from previous unary operation
            copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
            return [ret]

        if isinstance(u.owner.op, Elemwise):
            new_inputs = []
            if all([sum(i.type.broadcastable) == 0 for i in u.owner.inputs]):
                # There is no broadcastable in the inputs
                idx = node.inputs[1:]
                new_inputs = [node.op(i, *idx) for i in u.owner.inputs]
                # Copy over previous output stacktrace
                copy_stack_trace(node.outputs[0], new_inputs)

                ret = u.owner.op(*new_inputs)
                # Copy over previous output stacktrace
                # and stacktrace from previous unary operation
                copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
                return [ret]
            elif all(
                [sum(i.type.broadcastable) in [i.ndim, 0] for i in u.owner.inputs]
            ):
                # There is no broadcastable in the inputs or it is scalar
                idx = node.inputs[1:]
                new_inputs = []
                for i in u.owner.inputs:
                    if sum(i.type.broadcastable) == 0:
                        new_inputs.append(node.op(i, *idx))
                    else:
                        # If the subtensor remove some dims, we must
                        # lower the number of dimensions of this scalar.
                        if node.outputs[0].ndim == i.ndim:
                            new_inputs.append(i)
                        else:
                            new_inputs.append(
                                i.dimshuffle(["x"] * node.outputs[0].ndim)
                            )

                # Copy over previous output stacktrace
                copy_stack_trace(node.outputs[0], new_inputs)

                ret = u.owner.op(*new_inputs)
                # Copy over previous output stacktrace
                # and stacktrace from previous unary operation
                copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
                return [ret]

        if isinstance(u.owner.op, Rebroadcast):
            # make sure that Rebroadcast has only 1 input
            assert len(u.owner.inputs) == 1

            # Subtensor might reduce dim., adapt broadcast pattern accordingly
            new_axis = []

            # loop through indices being subtensor-ed
            # i indexes broadcastable pattern before subtensor
            # j indexes broadcastable pattern after subtensor
            j = 0
            for (i, x) in enumerate(node.op.idx_list):
                # if its not a slice, it will reduce the dimension, should
                # not appear in the broascastable dimensions
                if isinstance(x, slice):
                    new_axis += [(j, u.broadcastable[i])]
                    j += 1
            # now keep the broadcastable pattern of all
            # items not appearing in subtensor list
            for i in range(len(node.op.idx_list), len(u.broadcastable)):
                new_axis += [(j, u.broadcastable[i])]
                j += 1

            subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
            # Copy over previous output stacktrace
            copy_stack_trace(node.outputs[0], subt_x)

            rbcast_subt_x = Rebroadcast(*new_axis)(subt_x)
            # Copy over previous output stacktrace
            # and stacktrace from previous unary operation
            copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x)

            return [rbcast_subt_x]
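The elemwise case rests on the fact that elementwise ops commute with indexing, so it is usually cheaper to index first; in NumPy terms:

import numpy as np

x = np.random.rand(1000)

# exp(x)[idx] == exp(x[idx]); the right-hand side only evaluates the
# elementwise op on the selected entries.
assert np.allclose(np.exp(x)[10:20], np.exp(x[10:20]))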
def merge_two_slices(fgraph, slice1, len1, slice2, len2):
    """
    This function merges two slices into a single slice. The code works on
    the assumption that:

    a) slice1 is actually a slice and not an index, while slice2
       can be just an index.

    b) the two slices **have been applied consecutively** on the same
       tensor

    The output slice is **not** in canonical form, but actually just a slice
    that can be applied to a tensor to produce the same output as applying
    the two consecutive slices.
    ``len1`` is the length of the tensor **before** applying the first slice,
    while ``len2`` is the length **after** applying the first slice.
    """

    if not isinstance(slice1, slice):
        raise ValueError(
            (
                "First provided slice should actually be of type"
                "slice and not an index !"
            ),
            slice1,
        )

    sl1, reverse1 = get_canonical_form_slice(slice1, len1)
    sl2, reverse2 = get_canonical_form_slice(slice2, len2)

    if not isinstance(sl2, slice):
        if reverse1 is None:
            # The first slice is not in reverse, which makes things a lot
            # more clear.
            # In this case we need to take care only of the special cases:
            # len2 <=0    -> throw index error regardless of sl2
            # sl2 > len2  -> throw index error
            # sl2 < -len2 -> throw index error
            # To get a index error we simply use len1+1 to indicate we are
            # out of bounds, because passing this index through the formula
            # of getting the mixed slice is not guaranteed to result in an
            # index error. The **issue though** if that the error will
            # complain about accessing element len1+1 which is probably not
            # too intuitive for the user
            val = sl1.start + sl2 * sl1.step
            val = switch(le(len2, 0), len1 + 1, val)
            val = switch(ge(sl2, len2), len1 + 1, val)
            val = switch(lt(sl2, 0), -len1 - 1, val)
            if sl1.step:
                val = switch(eq(sl1.step, 0), len1 + 1, val)
            return val
        else:
            # We are in the more complex case when we do not actually know
            # if the first slice was in reverse or not.
            # in case it was not in reverse:
            p_val = sl1.start + sl2 * sl1.step
            # case it was in reverse we need to realize that we do not want
            # the k-th element from sl.start but the k-th element from
            # sl.stop backwards
            n_val = sl1.stop - 1 - sl2 * sl1.step
            # we need to pick either n_val or p_val and then follow same
            # steps as above for covering the index error cases
            val = switch(lt(reverse1, 0), n_val, p_val)
            val = switch(le(len2, 0), len1 + 1, val)
            val = switch(ge(sl2, len2), len1 + 1, val)
            val = switch(lt(sl2, 0), -len1 - 1, val)
            if sl1.step:
                val = switch(eq(sl1.step, 0), len1 + 1, val)
            return val
    else:
        # We are deleaing with two slices that need to be put together
        # according to the two steps we have 4 different combinations of
        # positive/negative. I will denote the case I'm looking at by
        # suffixes to the variables (nn,np,pn,pp):
        flen = sl2.stop - sl2.start
        p_step = sl1.step * sl2.step
        n_step = sl1.step * sl2.step * -1

        pp_start = minimum(sl1.start + sl2.start * sl1.step, sl1.stop)
        pp_stop = minimum(sl1.start + sl2.stop * sl1.step, sl1.stop)

        pn_stop = sl1.start + (sl2.start - 1) * sl1.step
        pn_stop = switch(
            and_(lt(pn_stop, 0), gt(flen, 0)),
            -len1 - 1,
            minimum(pn_stop, sl1.stop),
        )
        pn_start = sl1.start + (sl2.stop - 1) * sl1.step
        pn_start = minimum(pn_start, sl1.stop)
        pn_start = maximum(pn_start, 0)

        np_stop = sl1.stop - sl2.stop * sl1.step - 1
        np_stop = switch(
            and_(lt(np_stop, 0), gt(flen, 0)),
            -len1 - 1,
            maximum(sl1.start - 1, np_stop),
        )
        np_start = maximum(sl1.start, sl1.stop - sl2.start * sl1.step - 1)

        nn_start = maximum(sl1.start, (sl1.stop - 1) - (sl2.stop - 1) * sl1.step)
        nn_stop = maximum(sl1.start, sl1.stop - sl2.start * sl1.step)

        start = switch(
            lt(reverse2 * reverse1, 0),
            switch(lt(reverse1, 0), np_start, pn_start),
            switch(lt(reverse1, 0), nn_start, pp_start),
        )

        stop = switch(
            lt(reverse2 * reverse1, 0),
            switch(lt(reverse1, 0), np_stop, pn_stop),
            switch(lt(reverse1, 0), nn_stop, pp_stop),
        )

        step = switch(lt(reverse2 * reverse1, 0), n_step, p_step)
        start = switch(le(flen, 0), 0, start)
        stop = switch(le(flen, 0), 0, stop)

        return slice(start, stop, step)
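The symbolic bookkeeping above implements, with aesara switch graphs, the composition one can brute-force with Python slices. A small self-contained check of the underlying identity (not this commit's code; compose_slices is a hypothetical helper):

# Check that two consecutive slices select the same elements as one merged
# index set, using Python's slice.indices for canonical-form handling.
def compose_slices(s1, s2, len1):
    # Indices selected by applying s1 and then s2 to a sequence of length len1.
    first = list(range(*s1.indices(len1)))
    return first[s2]

xs = list(range(10))
s1, s2 = slice(1, 9, 2), slice(None, None, -1)
assert xs[s1][s2] == [xs[i] for i in compose_slices(s1, s2, len(xs))]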
@register_canonicalize
@register_specialize
@local_optimizer([Subtensor])
def local_subtensor_merge(fgraph, node):
    """
    Refactored optimization to deal with all cases of tensor merging.

    Given a subgraph of the form Subtensor(Subtensor(u)), the optimization
    expresses all slices in a canonical form, and then merges them together.

    """
    if isinstance(node.op, Subtensor):
        u = node.inputs[0]
        if u.owner and isinstance(u.owner.op, Subtensor):
            # We can merge :)
            # x actual tensor on which we are picking slices
            x = u.owner.inputs[0]
            # slices of the first applied subtensor
            slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
            slices2 = get_idx_list(node.inputs, node.op.idx_list)
            # Get the shapes of the vectors !
            try:
                # try not to introduce new shape into the graph
                xshape = fgraph.shape_feature.shape_of[x]
                ushape = fgraph.shape_feature.shape_of[u]
            except AttributeError:
                # Following the suggested use of shape_feature which should
                # consider the case when the compilation mode doesn't
                # include the ShapeFeature
                xshape = x.shape
                ushape = u.shape

            merged_slices = []
            pos_2 = 0
            pos_1 = 0
            while (pos_1 < len(slices1)) and (pos_2 < len(slices2)):
                slice1 = slices1[pos_1]
                if isinstance(slice1, slice):
                    merged_slices.append(
                        merge_two_slices(
                            fgraph,
                            slice1,
                            xshape[pos_1],
                            slices2[pos_2],
                            ushape[pos_2],
                        )
                    )
                    pos_2 += 1
                else:
                    merged_slices.append(slice1)
                pos_1 += 1

            if pos_2 < len(slices2):
                merged_slices += slices2[pos_2:]
            else:
                merged_slices += slices1[pos_1:]

            merged_slices = tuple(as_index_constant(s) for s in merged_slices)
            subtens = Subtensor(merged_slices)

            sl_ins = Subtensor.collapse(
                merged_slices, lambda x: isinstance(x, Variable)
            )
            # Do not call make_node for test_value
            out = subtens(x, *sl_ins)

            # Copy over previous output stacktrace
            # and stacktrace from previous slicing operation.
            # Why? Because, the merged slicing operation could have failed
            # because of either of the two original slicing operations
            orig_out = node.outputs[0]
            copy_stack_trace([orig_out, node.inputs[0]], out)

            # Restore original broadcastable dimensions that `subtens()` may
            # have been unable to infer again
            if out.type != orig_out.type:
                assert out.dtype == orig_out.dtype
                assert out.ndim == orig_out.ndim
                out = patternbroadcast(out, orig_out.broadcastable)
                copy_stack_trace([orig_out, node.inputs[0]], out)

            return [out]
@register_useless
@register_canonicalize
@register_specialize
@local_optimizer([Subtensor])
def local_subtensor_of_alloc(fgraph, node):
    """

    alloc(val)[x:y] -> alloc(val[...])
    alloc(val)[x:y] -> alloc(val)

    This can be seen as a lift, but it also reduce the number of computation/memory.

    """
    if not isinstance(node.op, Subtensor):
        return False
    u = node.inputs[0]
    if u.owner is None:
        return False
    if not isinstance(u.owner.op, Alloc):
        return False
    slices = get_idx_list(node.inputs, node.op.idx_list)
    val = u.owner.inputs[0]
    dims = u.owner.inputs[1:]
    assert len(slices) <= len(dims)

    # Number of dimensions added to val
    n_added_dims = u.ndim - val.ndim
    # Dimensions of the returned alloc
    nw_dims = []
    # Slices to take from val
    val_slices = []

    for i, (sl, dim) in enumerate(zip(slices, dims)):
        # If val was not copied over that dim,
        # we need to take the appropriate subtensor on it.
        if i >= n_added_dims:
            # We check that the corresponding val dimensions was
            # not a broadcasted dimensions.
            if (
                val.type.ndim > (i - n_added_dims)
                and val.type.broadcastable[i - n_added_dims]
            ):
                val_slices.append(slice(None))
            else:
                val_slices.append(sl)

        csl, _ = get_canonical_form_slice(sl, dim)
        if type(csl) is not slice:
            # That dimension is removed.
            pass
        else:
            nw_dim = csl.stop - csl.start

            if csl.step != 1:
                # Do not add the ceil_intdiv() graphs in the graphs
                # when this is not needed as it prevent detecting the
                # correct broadcast pattern.
                nw_dim = ceil_intdiv(nw_dim, csl.step)
            nw_dims += [nw_dim]

    nw_val = val[tuple(val_slices)]
    nw_dims += dims[len(slices) :]
    if nw_val.ndim > len(nw_dims):
        return False
    rval = alloc(nw_val, *nw_dims)
    if type(rval) not in (list, tuple):
        rval = [rval]
    if rval[0].type != node.outputs[0].type:
        # It happen that the make_node() isn't able to infer the same pattern.
        # We know it is safe, so fix that.
        rval[0] = patternbroadcast(rval[0], node.outputs[0].broadcastable)

    return rval
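The NumPy analogue of this lift, using broadcast_to in place of Alloc; a sketch:

import numpy as np

v = np.random.rand(4)

# alloc(v, 6, 4)[1:3] has the same contents as alloc(v, 2, 4): slicing an
# allocated/broadcasted dimension only changes the target shape.
assert np.array_equal(np.broadcast_to(v, (6, 4))[1:3], np.broadcast_to(v, (2, 4)))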
@register_canonicalize
@register_stabilize
@register_specialize
@local_optimizer([Subtensor])
def local_subtensor_of_dot(fgraph, node):
    """Rewrite ``aet.dot(A, B)[idxs]`` into ``aet.dot(A[idxs_a], B[idxs_b])``.

    ``idxs_a`` is the first ``A.ndim-1`` entries of ``idxs``, and ``idxs_b`` is
    the remaining entries of ``idxs`` (if any), modified to skip the
    second-to-last dimension of ``B`` (because dot sums over this dimension).

    """
    if not isinstance(node.op, Subtensor):
        return
    if not node.inputs[0].owner or not isinstance(node.inputs[0].owner.op, Dot):
        return
    # If there is other node that use the outputs of the dot
    # We don't want to compute twice the sub part.
    if len(fgraph.clients[node.inputs[0]]) > 1:
        return

    a = node.inputs[0].owner.inputs[0]
    b = node.inputs[0].owner.inputs[1]

    idx_list = get_idx_list(node.inputs, node.op.idx_list)

    num_a_indices = min(a.ndim - 1, len(idx_list))
    a_indices = idx_list[:num_a_indices]
    b_indices = idx_list[num_a_indices:]

    # This is necessary because np.dot sums the last index of a with the second to last of b
    # so we want to skip the second-to-last index into b.
    # This wasn't necessary for a, because we just omitted the last index.
    # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
    # (dot also handles b.ndim < 2 as a special case)
    if b.ndim > 1 and len(b_indices) >= b.ndim - 1:
        b_indices = (
            b_indices[: b.ndim - 2]
            + (slice(None, None, None),)
            + b_indices[b.ndim - 2 :]
        )

    a_sub = a.__getitem__(tuple(a_indices))
    b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b

    # Copy over previous output stacktrace to a_sub and b_sub,
    # because an error in the subtensor operation (e.g. an index error)
    # on either a or b must correspond to an error in the
    # subtensor operation on their dot product.
    copy_stack_trace(node.outputs[0], [a_sub, b_sub])

    # Copy over previous output stacktrace and previous dot product stacktrace,
    # because an error here may correspond to an either in either the original
    # dot product, or in the dot product after the subtensor operation.
    r = dot(a_sub, b_sub)
    copy_stack_trace([node.outputs[0], node.inputs[0]], r)

    return [r]
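The identity behind this rewrite is easy to verify numerically; a NumPy sketch:

import numpy as np

A = np.random.rand(5, 3)
B = np.random.rand(3, 4)

# Indexing only A's leading dimensions: dot(A, B)[idxs] == dot(A[idxs], B).
assert np.allclose(np.dot(A, B)[1:4], np.dot(A[1:4], B))

# An index reaching into B's trailing dimension skips B's second-to-last axis
# (the one dot sums over): dot(A, B)[i, j] == dot(A[i], B[:, j]).
assert np.allclose(np.dot(A, B)[2, 1], np.dot(A[2], B[:, 1]))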
@register_canonicalize
@local_optimizer([add])
def local_IncSubtensor_serialize(fgraph, node):
    """
    When using Subtensor, gradient graphs can be ugly.

    If we ask for grad(f(a[0]), a), we are going to get something like

        IncSubtensor(Elemwise{second}(a, 0), g(f(a[0])), [0])

    This might be ugly, but at least it's as fast as you could want.
    If we ask for grad(f(a[0], a[1], a[2]), a), it's much worse...

        Elemwise{Add}
            IncSubtensor(Elemwise{second}(a, 0), g(f(a[0])), [0])
            IncSubtensor(Elemwise{second}(a, 0), g(f(a[1])), [1])
            IncSubtensor(Elemwise{second}(a, 0), g(f(a[2])), [2])

    This is much worse because this time we have to produce 3 matrices
    the size of 'a', just so we can add them together.

    This Op rearranges IncSubtensor's that all work on the same
    initial argument (here, Elemwise{second}(a,0)) into a chain.  The
    advantage of the chain structure is that each one can be optimized
    later in the pipeline to operate inplace.

    Ideally, the op will do something like this:

    #
    #  add(x, incsubtensor(b, c), incsubtensor(b, d))
    #  -> incsubtensor(incsubtensor(add(x,b,b), c), d)

    """

    def movable(i):
        # Return True iff this is a incsubtensor that we can move
        return (
            i.owner
            and isinstance(
                i.owner.op,
                (
                    IncSubtensor,
                    AdvancedIncSubtensor1,
                    AdvancedIncSubtensor,
                ),
            )
            and i.type == o_type
            and len(fgraph.clients[i]) == 1
            and not i.owner.op.set_instead_of_inc
        )

    if node.op == add:
        o_type = node.outputs[0].type

        movable_inputs = [i for i in node.inputs if movable(i)]

        if movable_inputs:
            new_inputs = [i for i in node.inputs if not movable(i)] + [
                mi.owner.inputs[0] for mi in movable_inputs
            ]
            if len(new_inputs) == 0:
                new_add = new_inputs[0]
            else:
                new_add = add(*new_inputs)

                # Copy over stacktrace from original output, as an error
                # (e.g. an index error) in this add operation should
                # correspond to an error in the original add operation.
                copy_stack_trace(node.outputs[0], new_add)

            # stack up the new incsubtensors
            tip = new_add
            for mi in movable_inputs:
                assert tip.type == o_type
                assert tip.type == mi.owner.inputs[0].type
                tip = mi.owner.op(tip, *mi.owner.inputs[1:])
                # Copy over stacktrace from outputs of the original
                # "movable" operation to the new operation.
                copy_stack_trace(node.outputs + mi.owner.outputs, tip)

            return [tip]

    # print incsub_inputs, [id(i.owner.inputs[0]) for i in incsub_inputs]


# We register it in a TopoOptimizer inside the canonizer EQ optimizer.
# Otherwise in some cases it was making the EQ optimizer use 45. In
# the TopoOptimizer, the EQ only use 5 passes.
compile.optdb.register(
    "pre_local_IncSubtensor_serialize",
    in2out(local_IncSubtensor_serialize),
    # Just before canonizer
    0.99,
    "fast_run",
)
# after priority 50 Destructive inplace operations
# gemm is the first one now, at priority 70


@local_optimizer([IncSubtensor], inplace=True)
def local_inplace_setsubtensor(fgraph, node):
    if isinstance(node.op, IncSubtensor) and not node.op.inplace:
        dta = node.op.destroyhandler_tolerate_aliased
        new_op = node.op.__class__(
            node.op.idx_list,
            inplace=True,
            set_instead_of_inc=node.op.set_instead_of_inc,
            destroyhandler_tolerate_aliased=dta,
        )
        new_node = new_op(*node.inputs)
        val = getattr(node.outputs[0].tag, "nan_guard_mode_check", True)
        new_node.tag.nan_guard_mode_check = val

        # Copy stacktrace from original outputs to new outputs.
        # This is sensible, because the new operation is the
        # same as the old one, but now with different attributes.
        copy_stack_trace(node.outputs, new_node)
        return [new_node]
    return False


compile.optdb.register(
    "local_inplace_setsubtensor",
    TopoOptimizer(
        local_inplace_setsubtensor, failure_callback=TopoOptimizer.warn_inplace
    ),
    60,
    "fast_run",
    "inplace",
)


@local_optimizer([AdvancedIncSubtensor1], inplace=True)
def local_inplace_AdvancedIncSubtensor1(fgraph, node):
    if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace:
        new_op = node.op.clone_inplace()
        new_node = new_op(*node.inputs)
        copy_stack_trace(node.outputs, new_node)
        return [new_node]
    return False


compile.optdb.register(
    "local_inplace_AdvancedIncSubtensor1",
    TopoOptimizer(
        local_inplace_AdvancedIncSubtensor1,
        failure_callback=TopoOptimizer.warn_inplace,
    ),
    60,
    "fast_run",
    "inplace",
)


@local_optimizer([AdvancedIncSubtensor], inplace=True)
def local_inplace_AdvancedIncSubtensor(fgraph, node):
    if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace:
        new_op = type(node.op)(
            inplace=True, set_instead_of_inc=node.op.set_instead_of_inc
        )
        new_node = new_op(*node.inputs)
        copy_stack_trace(node.outputs, new_node)
        return [new_node]
    return False


compile.optdb.register(
    "local_inplace_AdvancedIncSubtensor",
    TopoOptimizer(
        local_inplace_AdvancedIncSubtensor,
        failure_callback=TopoOptimizer.warn_inplace,
    ),
    60,
    "fast_run",
    "inplace",
)
# Register old name
@register_canonicalize("local_incsubtensor_of_allocs")
@register_stabilize("local_incsubtensor_of_allocs")
@local_optimizer([IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1])
def local_incsubtensor_of_zeros(fgraph, node):
    """
    IncSubtensor(x, zeros, idx) -> x

    """
    if (
        isinstance(node.op, (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1))
        and not node.op.set_instead_of_inc
    ):
        x = node.inputs[0]
        y = node.inputs[1]
        try:
            # Don't use only_process_constants=True. We need to
            # investigate Alloc of 0s but with non constant shape.
            if get_scalar_constant_value(y, elemwise=False) == 0:
                # No need to copy over the stacktrace,
                # because x should already have a stacktrace
                return [x]
        except NotScalarConstantError:
            return
@register_canonicalize
@register_specialize
@local_optimizer([IncSubtensor])
def local_incsubtensor_of_zeros_to_setsubtensor(fgraph, node):
    """
    IncSubtensor(zeros, x, ...) -> SetSubtensor(zeros, x, ...)

    """
    if isinstance(node.op, (IncSubtensor)) and not node.op.set_instead_of_inc:
        x = node.inputs[0]

        if isinstance(x, Constant) and not np.any(x.data):
            return [
                IncSubtensor(
                    node.op.idx_list,
                    node.op.inplace,
                    set_instead_of_inc=True,
                    destroyhandler_tolerate_aliased=node.op.destroyhandler_tolerate_aliased,
                )(*node.inputs)
            ]
@register_canonicalize("local_setsubtensor_of_allocs")
@register_stabilize("local_setsubtensor_of_allocs")
@local_optimizer([IncSubtensor])
def local_setsubtensor_of_constants(fgraph, node):
    """
    SetSubtensor(x, x[idx], idx) -> x

    when x is constant or alloc.

    """
    if isinstance(node.op, IncSubtensor) and node.op.set_instead_of_inc:
        x = node.inputs[0]
        y = node.inputs[1]

        # Don't use only_process_constants=True. We need to
        # investigate Alloc of 0s but with non constant shape.
        try:
            replace_x = get_scalar_constant_value(x, elemwise=False)
        except NotScalarConstantError:
            return

        try:
            replace_y = get_scalar_constant_value(y, elemwise=False)
        except NotScalarConstantError:
            return

        if replace_x == replace_y:
            # No need to copy over the stacktrace,
            # because x should already have a stacktrace
            return [x]
        else:
            return False
@register_canonicalize
@register_specialize
@local_optimizer([AdvancedSubtensor1])
def local_adv_sub1_adv_inc_sub1(fgraph, node):
    """Optimize the possible AdvSub1(AdvSetSub1(...), ...).

    AdvancedSubtensor1(AdvancedSetSubtensor1(x, y, idx), idx) -> y

    Notes
    -----
    This opt add AssertOp. Otherwise, it would remove shape and
    index error. If you want to get rid of them, see the
    :ref:`unsafe_optimization` section.

    WARNING:
    A previous version of this optimization also matched
    AdvancedSubtensor1(AdvancedIncSubtensor1(0s, y, idx), idx) -> y
    This is incorrect when there are duplicate indices.
    The current version warns the user about potential past issues.

    """
    if not isinstance(node.op, AdvancedSubtensor1):
        return
    inp = node.inputs[0]
    if not inp.owner or not isinstance(inp.owner.op, AdvancedIncSubtensor1):
        return
    idx = node.inputs[1]
    idx2 = inp.owner.inputs[2]
    x = inp.owner.inputs[0]
    y = inp.owner.inputs[1]
    if idx is not idx2:
        return
    if (
        not inp.owner.op.set_instead_of_inc
        and
        # Don't use only_process_constants=True. We need to
        # investigate Alloc of 0s but with non constant shape.
        extract_constant(x, elemwise=False) != 0
    ):
        return

    if not inp.owner.op.set_instead_of_inc:
        return

    cond = [aet_all(and_(lt(idx, x.shape[0]), ge(idx, -x.shape[0])))]
    if not fgraph.shape_feature.same_shape(idx, y, 0, 0):
        cond.append(eq(idx.shape[0], y.shape[0]))
    r = Assert(
        "Bad indexing or shapes in a AdvancedIncSubtensor1 "
        "that was optimized away"
    )(y, *cond)
    copy_stack_trace(y, r)

    if r.dtype == node.outputs[0].dtype:
        return [r]
    # It is possible that y is upcast or downcast to x.dtype.
    # In all case, as we set or add with 0, we can just cast y.
    r2 = cast(r, node.outputs[0].dtype)

    # Copy over stacktrace from before casting, since
    # we don't expect problems in the casting operation,
    # and any problems in the indexing would have been spotted above.
    copy_stack_trace(r, r2)
    return [r2]
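The duplicate-index hazard mentioned in the docstring is visible with NumPy's unbuffered np.add.at, which matches the increment Op's semantics; a sketch:

import numpy as np

y = np.array([1.0, 1.0])
idx = np.array([0, 0])  # duplicate indices
z = np.zeros(3)
np.add.at(z, idx, y)    # analogue of AdvancedIncSubtensor1(0s, y, idx)

# Reading back z[idx] gives [2., 2.], not y: the old inc-variant rewrite to
# `y` was wrong, which is why only the set variant is matched now.
assert not np.array_equal(z[idx], y)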
@register_specialize
@register_stabilize
@register_canonicalize
@register_useless
@local_optimizer([IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1])
def local_useless_inc_subtensor_alloc(fgraph, node):
    """
    Replaces an [Advanced]IncSubtensor[1], whose increment is an `alloc` of
    a fully or partially broadcastable variable, by one that skips the
    intermediate `alloc` where possible.

    """
    if isinstance(node.op, (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)):
        x = node.inputs[0]
        y = node.inputs[1]
        i = node.inputs[2:]

        if y.owner is not None and isinstance(y.owner.op, Alloc):
            # `z` is the input of the Alloc op, i.e. aet.alloc(z, <shape>)
            z = y.owner.inputs[0]

            try:
                shape_feature = fgraph.shape_feature
            except AttributeError:
                # The shape feature may not be available in some mode, but we
                # need it for this optimization, so don't continue.
                return False

            shape_of = shape_feature.shape_of
            same_shape = shape_feature.same_shape

            # Get the subtensor of `x` indexed by `i` in order to compare
            # shapes later.
            if isinstance(node.op, IncSubtensor):
                xi = Subtensor(node.op.idx_list)(x, *i)
            elif isinstance(node.op, AdvancedIncSubtensor):
                xi = advanced_subtensor(x, *i)
            elif isinstance(node.op, AdvancedIncSubtensor1):
                xi = advanced_subtensor1(x, *i)
            else:
                raise Exception("Should never happen!")

            reason = "local_useless_incsubtensor_alloc"

            # Add `xi` to the shape feature `fgraph`. This is important for
            # shape inference later because the variable must be part of the
            # function graph in order to call `same_shape` on it.
            if xi not in shape_of:
                shape_feature.on_import(fgraph, xi.owner, f"{reason}: add `xi`")

            # `xi` may have more dimensions than `y` since the subtensor ops
            # do automatic broadcasting of the increment internally. Thus, we
            # need to make the leading implicitly broadcasted dimensions
            # explicit for shape comparison later.
            if xi.ndim > y.ndim:
                y = shape_padleft(y, xi.ndim - y.ndim)
                if y not in shape_of:
                    shape_feature.on_import(fgraph, y.owner, f"{reason}: add `y`")

            # Build `z_broad` explicitly to include extra implicit dimensions.
            z_broad = (True,) * (xi.ndim - z.ndim) + z.broadcastable

            cond = [
                # The shapes of `y` and `xi` must either agree or `y` may
                # also have shape equal to 1 which may be treated as a
                # broadcastable dimension by the subtensor op.
                or_(eq(y.shape[k], 1), eq(y.shape[k], xi.shape[k]))
                # Loop over all dimensions.
                for k in range(xi.ndim)
                # We need to check the above shapes, if
                # * the pre-alloc increment `z` is broadcastable in
                # dimension `k` (if it isn't, then the shapes of `z` and
                # `y` are the same by the definition of the `Alloc` op in
                # this dimension and replacing `y` by `z` will not hide a
                # shape error), and
                # * `xi` and `y` do not have the same shape in dimension
                # `k` or we cannot infer the shape statically (if the
                # shapes of `xi` and `y` are not the same, then replacing
                # `y` by `z` will hide the shape error of `y`), and
                # * the shape of `y` is not equal to 1 or we cannot infer
                # the shape statically (if the shape of `y` is equal to
                # 1, then `y` is broadcasted by the inc_subtensor op
                # internally, so the shapes of `xi` and `y` do not need
                # to match in dimension `k`; else we need to check at
                # runtime that the shape of `y` is either 1 or the same
                # as `xi` or otherwise replacing `y` by `z` will hide a
                # shape error).
                if (
                    z_broad[k]
                    and not same_shape(xi, y, dim_x=k, dim_y=k)
                    and shape_of[y][k] != 1
                )
            ]

            if len(cond) > 0:
                msg = "`x[i]` and `y` do not have the same shape."
                z = Assert(msg)(z, *cond)

            r = node.op(x, z, *i)
            # Copy over stacktrace from previous output, since
            # we don't expect problems when removing the intermediate
            # alloc operation and so we still want to point at the line
            # of the inc_subtensor operation.
            copy_stack_trace(node.outputs, r)

            return [r]


@register_useless
@register_canonicalize
@register_specialize
...
aesara/tensor/subtensor_opt.py
+import sys
+
+import numpy as np
+
 import aesara
+import aesara.scalar.basic as aes
+from aesara import compile
+from aesara.assert_op import Assert
+from aesara.graph.basic import Constant, Variable
-from aesara.graph.opt import copy_stack_trace, local_optimizer
+from aesara.graph.opt import TopoOptimizer, copy_stack_trace, in2out, local_optimizer
+from aesara.tensor.basic import (
+    Alloc,
+    ARange,
+    Rebroadcast,
+    ScalarFromTensor,
+    TensorFromScalar,
+    alloc,
+    cast,
+    extract_constant,
+    get_scalar_constant_value,
+    make_vector,
+    patternbroadcast,
+    switch,
+)
-from aesara.tensor.basic_opt import register_specialize
+from aesara.tensor.basic_opt import (
+    register_canonicalize,
+    register_specialize,
+    register_stabilize,
+)
+from aesara.tensor.elemwise import Elemwise
+from aesara.tensor.exceptions import NotScalarConstantError
+from aesara.tensor.math import Dot, add
+from aesara.tensor.math import all as aet_all
+from aesara.tensor.math import (
+    and_,
+    ceil_intdiv,
+    dot,
+    eq,
+    ge,
+    gt,
+    le,
+    lt,
+    maximum,
+    minimum,
+    or_,
+)
-from aesara.tensor.shape import shape_tuple
+from aesara.tensor.shape import shape_padleft, shape_tuple
 from aesara.tensor.sharedvar import TensorSharedVariable
 from aesara.tensor.subtensor import (
     AdvancedIncSubtensor,
+    AdvancedIncSubtensor1,
     AdvancedSubtensor,
+    AdvancedSubtensor1,
+    IncSubtensor,
+    Subtensor,
+    advanced_inc_subtensor1,
+    advanced_subtensor,
     advanced_subtensor1,
+    as_index_constant,
+    get_canonical_form_slice,
+    get_idx_list,
     inc_subtensor,
 )
+from aesara.tensor.type import TensorType
 from aesara.tensor.type_other import NoneTypeT, SliceConstant, SliceType
 from aesara.tensor.var import TensorConstant, TensorVariable
+def register_useless(lopt, *tags, **kwargs):
+    if type(lopt) == str:
+
+        def register(inner_lopt):
+            return register_useless(inner_lopt, lopt, *tags, **kwargs)
+
+        return register
+    else:
+        name = kwargs.pop("name", None) or lopt.__name__
+
+        compile.mode.local_useless.register(
+            name, lopt, "last", "fast_run", *tags, **kwargs
+        )
+        return lopt
+
+
 def transform_take(a, indices, axis):
     r"""Transform ``arr[:,:,:,indices,...]``-like operations into single-dimensional, vector index operations.
...
@@ -181,3 +250,1312 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
            )
        )
    copy_stack_trace(node.outputs[0], new_res)
    return [new_res]
@register_canonicalize
@register_stabilize
@register_specialize
@local_optimizer
([
Subtensor
])
def
local_subtensor_of_dot
(
fgraph
,
node
):
"""Rewrite ``aet.dot(A, B)[idxs]`` into ``aet.dot(A[idxs_a], B[idxs_b])``.
``idxs_a`` is the first ``A.ndim-1`` entries of ``idxs``, and ``idxs_b`` is
the remaining entries of ``idxs`` (if any), modified to skip the
second-to-last dimension of ``B`` (because dot sums over this dimension).
"""
if
not
isinstance
(
node
.
op
,
Subtensor
):
return
if
not
node
.
inputs
[
0
]
.
owner
or
not
isinstance
(
node
.
inputs
[
0
]
.
owner
.
op
,
Dot
):
return
# If there is other node that use the outputs of the dot
# We don't want to compute twice the sub part.
if
len
(
fgraph
.
clients
[
node
.
inputs
[
0
]])
>
1
:
return
a
=
node
.
inputs
[
0
]
.
owner
.
inputs
[
0
]
b
=
node
.
inputs
[
0
]
.
owner
.
inputs
[
1
]
idx_list
=
get_idx_list
(
node
.
inputs
,
node
.
op
.
idx_list
)
num_a_indices
=
min
(
a
.
ndim
-
1
,
len
(
idx_list
))
a_indices
=
idx_list
[:
num_a_indices
]
b_indices
=
idx_list
[
num_a_indices
:]
# This is necessary because np.dot sums the last index of a with the second to last of b
# so we want to skip the second-to-last index into b.
# This wasn't necessary for a, because we just omitted the last index.
# We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
# (dot also handles b.ndim < 2 as a special case)
if
b
.
ndim
>
1
and
len
(
b_indices
)
>=
b
.
ndim
-
1
:
b_indices
=
(
b_indices
[:
b
.
ndim
-
2
]
+
(
slice
(
None
,
None
,
None
),)
+
b_indices
[
b
.
ndim
-
2
:]
)
a_sub
=
a
.
__getitem__
(
tuple
(
a_indices
))
b_sub
=
b
.
__getitem__
(
tuple
(
b_indices
))
if
b_indices
else
b
# Copy over previous output stacktrace to a_sub and b_sub,
# because an error in the subtensor operation (e.g. an index error)
# on either a or b must correspond to an error in the
# subtensor operation on their dot product.
copy_stack_trace
(
node
.
outputs
[
0
],
[
a_sub
,
b_sub
])
# Copy over previous output stacktrace and previous dot product stacktrace,
# because an error here may correspond to an either in either the original
# dot product, or in the dot product after the subtensor operation.
r
=
dot
(
a_sub
,
b_sub
)
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
r
)
return
[
r
]
@register_useless
@register_canonicalize
@register_specialize
@local_optimizer
([
Subtensor
])
def
local_useless_slice
(
fgraph
,
node
):
"""
Remove Subtensor of the form X[0, :] -> X[0]
"""
if
isinstance
(
node
.
op
,
Subtensor
):
slices
=
get_idx_list
(
node
.
inputs
,
node
.
op
.
idx_list
)
last_slice
=
len
(
slices
)
for
s
in
slices
[::
-
1
]:
# check if slice and then check slice indices
if
(
isinstance
(
s
,
slice
)
and
s
.
start
is
None
and
s
.
stop
is
None
and
(
s
.
step
is
None
or
extract_constant
(
s
.
step
,
only_process_constants
=
True
)
==
1
)
):
last_slice
-=
1
else
:
break
# check if we removed something
if
last_slice
<
len
(
slices
):
subtens
=
Subtensor
(
slices
[:
last_slice
])
sl_ins
=
Subtensor
.
collapse
(
slices
[:
last_slice
],
lambda
x
:
isinstance
(
x
,
Variable
)
)
out
=
subtens
(
node
.
inputs
[
0
],
*
sl_ins
)
# Copy over previous output stacktrace
copy_stack_trace
(
node
.
outputs
,
out
)
return
[
out
]
# fast_compile to allow opt subtensor(cast{float32}(make_vector))
@register_canonicalize
(
"fast_compile"
)
@local_optimizer
([
Subtensor
])
def
local_subtensor_lift
(
fgraph
,
node
):
"""
unary(x)[idx] -> unary(x[idx])#any broadcast pattern.
Handles the following unary ops:
elemwise(x,...)[idx] -> elemwise(x[idx],...)
when x,... are broadcasted scalar or not broadcasted at all
rebroadcast(x)[idx] => rebroadcast(x[idx])
"""
if
isinstance
(
node
.
op
,
Subtensor
):
u
=
node
.
inputs
[
0
]
if
not
u
.
owner
or
len
(
fgraph
.
clients
[
u
])
>
1
:
return
False
if
isinstance
(
u
.
owner
.
op
,
Elemwise
)
and
len
(
u
.
owner
.
inputs
)
==
1
:
idx
=
node
.
inputs
[
1
:]
x_idx
=
node
.
op
(
u
.
owner
.
inputs
[
0
],
*
idx
)
# Copy over previous output stacktrace
copy_stack_trace
(
node
.
outputs
,
x_idx
)
ret
=
u
.
owner
.
op
(
x_idx
)
# Copy over previous output stacktrace
# and stacktrace from previous unary operation
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
ret
)
return
[
ret
]
if
isinstance
(
u
.
owner
.
op
,
Elemwise
):
new_inputs
=
[]
if
all
([
sum
(
i
.
type
.
broadcastable
)
==
0
for
i
in
u
.
owner
.
inputs
]):
# There is no broadcastable in the inputs
idx
=
node
.
inputs
[
1
:]
new_inputs
=
[
node
.
op
(
i
,
*
idx
)
for
i
in
u
.
owner
.
inputs
]
# Copy over previous output stacktrace
copy_stack_trace
(
node
.
outputs
[
0
],
new_inputs
)
ret
=
u
.
owner
.
op
(
*
new_inputs
)
# Copy over previous output stacktrace
# and stacktrace from previous unary operation
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
ret
)
return
[
ret
]
elif
all
(
[
sum
(
i
.
type
.
broadcastable
)
in
[
i
.
ndim
,
0
]
for
i
in
u
.
owner
.
inputs
]
):
# There is no broadcastable in the inputs or it is scalar
idx
=
node
.
inputs
[
1
:]
new_inputs
=
[]
for
i
in
u
.
owner
.
inputs
:
if
sum
(
i
.
type
.
broadcastable
)
==
0
:
new_inputs
.
append
(
node
.
op
(
i
,
*
idx
))
else
:
# If the subtensor remove some dims, we must
# lower the number of dimensions of this scalar.
if
node
.
outputs
[
0
]
.
ndim
==
i
.
ndim
:
new_inputs
.
append
(
i
)
else
:
new_inputs
.
append
(
i
.
dimshuffle
([
"x"
]
*
node
.
outputs
[
0
]
.
ndim
)
)
# Copy over previous output stacktrace
copy_stack_trace
(
node
.
outputs
[
0
],
new_inputs
)
ret
=
u
.
owner
.
op
(
*
new_inputs
)
# Copy over previous output stacktrace
# and stacktrace from previous unary operation
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
ret
)
return
[
ret
]
if
isinstance
(
u
.
owner
.
op
,
Rebroadcast
):
# make sure that Rebroadcast has only 1 input
assert
len
(
u
.
owner
.
inputs
)
==
1
# Subtensor might reduce dim., adapt broadcast pattern accordingly
new_axis
=
[]
# loop through indices being subtensor-ed
# i indexes broadcastable pattern before subtensor
# j indexes broadcastable pattern after subtensor
j
=
0
for
(
i
,
x
)
in
enumerate
(
node
.
op
.
idx_list
):
# if its not a slice, it will reduce the dimension, should
# not appear in the broascastable dimensions
if
isinstance
(
x
,
slice
):
new_axis
+=
[(
j
,
u
.
broadcastable
[
i
])]
j
+=
1
# now keep the broadcastable pattern of all
# items not appearing in subtensor list
for
i
in
range
(
len
(
node
.
op
.
idx_list
),
len
(
u
.
broadcastable
)):
new_axis
+=
[(
j
,
u
.
broadcastable
[
i
])]
j
+=
1
subt_x
=
node
.
op
(
u
.
owner
.
inputs
[
0
],
*
node
.
inputs
[
1
:])
# Copy over previous output stacktrace
copy_stack_trace
(
node
.
outputs
[
0
],
subt_x
)
rbcast_subt_x
=
Rebroadcast
(
*
new_axis
)(
subt_x
)
# Copy over previous output stacktrace
# and stacktrace from previous unary operation
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
rbcast_subt_x
)
return
[
rbcast_subt_x
]


@register_canonicalize
@register_specialize
@local_optimizer([Subtensor])
def local_subtensor_merge(fgraph, node):
    """
    Refactored optimization to deal with all cases of tensor merging.

    Given a subgraph of the form Subtensor(Subtensor(u)), the optimization
    expresses all slices in a canonical form, and then merges them together.

    """
    if isinstance(node.op, Subtensor):
        u = node.inputs[0]
        if u.owner and isinstance(u.owner.op, Subtensor):
            # We can merge :)
            # x: the actual tensor from which we are picking slices
            x = u.owner.inputs[0]
            # slices of the first applied subtensor
            slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
            slices2 = get_idx_list(node.inputs, node.op.idx_list)
            # Get the shapes of the vectors!
            try:
                # try not to introduce new shape into the graph
                xshape = fgraph.shape_feature.shape_of[x]
                ushape = fgraph.shape_feature.shape_of[u]
            except AttributeError:
                # Following the suggested use of shape_feature which should
                # consider the case when the compilation mode doesn't
                # include the ShapeFeature
                xshape = x.shape
                ushape = u.shape

            merged_slices = []
            pos_2 = 0
            pos_1 = 0
            while (pos_1 < len(slices1)) and (pos_2 < len(slices2)):
                slice1 = slices1[pos_1]
                if isinstance(slice1, slice):
                    merged_slices.append(
                        merge_two_slices(
                            fgraph,
                            slice1,
                            xshape[pos_1],
                            slices2[pos_2],
                            ushape[pos_2],
                        )
                    )
                    pos_2 += 1
                else:
                    merged_slices.append(slice1)
                pos_1 += 1

            if pos_2 < len(slices2):
                merged_slices += slices2[pos_2:]
            else:
                merged_slices += slices1[pos_1:]

            merged_slices = tuple(as_index_constant(s) for s in merged_slices)
            subtens = Subtensor(merged_slices)

            sl_ins = Subtensor.collapse(
                merged_slices, lambda x: isinstance(x, Variable)
            )
            # Do not call make_node for test_value
            out = subtens(x, *sl_ins)

            # Copy over previous output stacktrace
            # and stacktrace from previous slicing operation.
            # Why? Because the merged slicing operation could have failed
            # because of either of the two original slicing operations
            orig_out = node.outputs[0]
            copy_stack_trace([orig_out, node.inputs[0]], out)

            # Restore original broadcastable dimensions that `subtens()` may
            # have been unable to infer again
            if out.type != orig_out.type:
                assert out.dtype == orig_out.dtype
                assert out.ndim == orig_out.ndim
                out = patternbroadcast(out, orig_out.broadcastable)
                copy_stack_trace([orig_out, node.inputs[0]], out)

            return [out]
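
# An illustration of the merge (not part of the library; a minimal sketch
# assuming a standard Aesara install — variable names are hypothetical):
#
#   import aesara
#   import aesara.tensor as aet
#
#   x = aet.vector("x")
#   f = aesara.function([x], x[2:8][1:3])
#   aesara.printing.debugprint(f)
#
# For a vector ``x``, ``x[2:8][1:3]`` selects the same elements as ``x[3:5]``;
# with this rewrite (plus constant folding) the compiled graph applies a
# single merged Subtensor instead of two chained ones.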


@register_specialize
@register_canonicalize
@local_optimizer([Subtensor])
def local_subtensor_remove_broadcastable_index(fgraph, node):
    """
    Remove broadcastable dimension with index 0 or -1
    a[:,:,:,0] -> a.dimshuffle(0,1,2), when
        a.broadcastable = (False, False, False, True)
    a[0,:,-1,:] -> a.dimshuffle(1,3), when
        a.broadcastable = (True, False, True, False)

    """
    if isinstance(node.op, Subtensor):
        idx = node.op.idx_list
    else:
        return

    remove_dim = []
    node_inputs_idx = 1
    for dim, elem in enumerate(idx):
        if isinstance(elem, (aes.Scalar)):
            # The idx is a Scalar, i.e. a Type. This means the actual index
            # is contained in node.inputs[1]
            dim_index = node.inputs[node_inputs_idx]
            if type(dim_index) == aes.ScalarConstant:
                dim_index = dim_index.value
            if dim_index in [0, -1] and node.inputs[0].broadcastable[dim]:
                remove_dim.append(dim)
                node_inputs_idx += 1
            else:
                return
        elif isinstance(elem, slice):
            if elem != slice(None):
                return
        elif isinstance(elem, (int, np.integer)):
            if elem in [0, -1] and node.inputs[0].broadcastable[dim]:
                remove_dim.append(dim)
        else:
            raise TypeError("case not expected")

    if len(remove_dim) == 0:
        return
    else:
        all_dim = range(node.inputs[0].ndim)
        remain_dim = [x for x in all_dim if x not in remove_dim]
        return [node.inputs[0].dimshuffle(tuple(remain_dim))]
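
# A hedged sketch of the docstring's first case (assumes a standard Aesara
# install; the explicit TensorType construction is just one way to get a
# broadcastable last dimension):
#
#   from aesara.tensor.type import TensorType
#
#   a = TensorType("float64", (False, False, False, True))("a")
#   b = a[:, :, :, 0]
#   # After this rewrite, ``b`` is computed as ``a.dimshuffle(0, 1, 2)``:
#   # indexing a length-1 (broadcastable) dimension at 0 or -1 only drops
#   # that dimension, so no actual indexing is needed.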


@register_useless
@register_canonicalize
@register_specialize
@local_optimizer([Subtensor])
def local_subtensor_of_alloc(fgraph, node):
    """

    alloc(val)[x:y] -> alloc(val[...])
    alloc(val)[x:y] -> alloc(val)

    This can be seen as a lift, but it also reduces the amount of
    computation/memory.

    """
    if not isinstance(node.op, Subtensor):
        return False
    u = node.inputs[0]
    if u.owner is None:
        return False
    if not isinstance(u.owner.op, Alloc):
        return False
    slices = get_idx_list(node.inputs, node.op.idx_list)
    val = u.owner.inputs[0]
    dims = u.owner.inputs[1:]
    assert len(slices) <= len(dims)

    # Number of dimensions added to val
    n_added_dims = u.ndim - val.ndim
    # Dimensions of the returned alloc
    nw_dims = []
    # Slices to take from val
    val_slices = []

    for i, (sl, dim) in enumerate(zip(slices, dims)):
        # If val was not copied over that dim,
        # we need to take the appropriate subtensor on it.
        if i >= n_added_dims:
            # We check that the corresponding val dimension was
            # not a broadcasted dimension.
            if (
                val.type.ndim > (i - n_added_dims)
                and val.type.broadcastable[i - n_added_dims]
            ):
                val_slices.append(slice(None))
            else:
                val_slices.append(sl)

        csl, _ = get_canonical_form_slice(sl, dim)
        if type(csl) is not slice:
            # That dimension is removed.
            pass
        else:
            nw_dim = csl.stop - csl.start

            if csl.step != 1:
                # Do not add the ceil_intdiv() graphs in the graphs
                # when this is not needed as it prevents detecting the
                # correct broadcast pattern.
                nw_dim = ceil_intdiv(nw_dim, csl.step)
            nw_dims += [nw_dim]

    nw_val = val[tuple(val_slices)]
    nw_dims += dims[len(slices) :]
    if nw_val.ndim > len(nw_dims):
        return False
    rval = alloc(nw_val, *nw_dims)
    if type(rval) not in (list, tuple):
        rval = [rval]
    if rval[0].type != node.outputs[0].type:
        # It can happen that the make_node() isn't able to infer the same
        # broadcast pattern. We know it is safe, so fix that.
        rval[0] = patternbroadcast(rval[0], node.outputs[0].broadcastable)

    return rval
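
# A hedged sketch (assumes a standard Aesara install; names are
# illustrative):
#
#   import aesara.tensor as aet
#
#   v = aet.scalar("v")
#   out = aet.alloc(v, 10, 12)[:5]
#   # After this rewrite, ``out`` is computed as the equivalent of
#   # ``aet.alloc(v, 5, 12)``: the full (10, 12) intermediate is never
#   # materialized.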


@register_specialize
@register_canonicalize
@local_optimizer([Subtensor])
def local_subtensor_inc_subtensor(fgraph, node):
    """
    Subtensor(SetSubtensor(x, y, idx), idx) -> y

    """
    if isinstance(node.op, Subtensor):
        x = node.inputs[0]
        if not x.owner or not isinstance(x.owner.op, IncSubtensor):
            return
        if not x.owner.op.set_instead_of_inc:
            return

        if x.owner.inputs[2:] == node.inputs[1:] and tuple(
            x.owner.op.idx_list
        ) == tuple(node.op.idx_list):
            out = node.outputs[0]
            y = x.owner.inputs[1]
            # If the dtypes differ, cast y into x.dtype
            if x.dtype != y.dtype:
                y = y.astype(x.dtype)
            if out.type == y.type:
                # if x[idx] and y have the same type, directly return y
                return [y]
            else:
                # The difference is related to broadcasting pattern
                assert out.broadcastable != y.broadcastable
                # We have to alloc y to the shape of x[idx]
                x_subtensor = node.op(x.owner.inputs[0], *x.owner.inputs[2:])
                return [alloc(y, *x_subtensor.shape)]
        else:
            return
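
# A hedged sketch (assumes a standard Aesara install; names are
# illustrative):
#
#   import aesara.tensor as aet
#
#   x = aet.matrix("x")
#   y = aet.matrix("y")
#   out = aet.set_subtensor(x[1:3], y)[1:3]
#   # Reading back exactly the slice that was just set short-circuits to
#   # ``y`` itself (cast or alloc'd to the slice's type when they differ).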


@register_specialize
@register_canonicalize("fast_compile_gpu")
@register_useless
@local_optimizer([Subtensor, AdvancedSubtensor1])
def local_subtensor_make_vector(fgraph, node):
    """
    Replace all subtensor(make_vector) like:
    [a,b,c][0] -> a
    [a,b,c][0:2] -> [a,b]

    Replace all AdvancedSubtensor1(make_vector) like:
    [a,b,c][[0,2]] -> [a,c]

    We can do this for constant indexes.

    """
    x = node.inputs[0]
    if not x.owner or x.owner.op != make_vector:
        return

    if isinstance(node.op, Subtensor):
        # This optimization needs ShapeOpt and fgraph.shape_feature
        try:
            (idx,) = node.op.idx_list
        except Exception:
            # 'how can you have multiple indexes into a shape?'
            raise

        if isinstance(idx, (aes.Scalar, TensorType)):
            # The idx is a Scalar, i.e. a Type. This means the actual index
            # is contained in node.inputs[1]
            old_idx, idx = idx, node.inputs[1]
            assert idx.type == old_idx
    elif isinstance(node.op, AdvancedSubtensor1):
        idx = node.inputs[1]
    else:
        return

    if isinstance(idx, (int, np.integer)):
        # We don't need to copy over any stack traces here
        return [x.owner.inputs[idx]]
    elif isinstance(idx, Variable):
        if idx.ndim == 0:
            # if it is a constant we can do something with it
            try:
                v = get_scalar_constant_value(idx, only_process_constants=True)
                if isinstance(v, np.integer):
                    # Python 2.4 wants to index only with Python integers
                    v = int(v)
                # We don't need to copy over any stack traces here
                try:
                    ret = [x.owner.inputs[v]]
                except IndexError:
                    raise NotScalarConstantError("Bad user graph!")
                return ret
            except NotScalarConstantError:
                pass
        elif idx.ndim == 1 and isinstance(idx, Constant):
            values = list(map(int, list(idx.value)))
            ret = make_vector(*[x.owner.inputs[v] for v in values])

            # Copy over stack trace from previous output to new output
            copy_stack_trace(node.outputs[0], ret)

            ret = patternbroadcast(ret, node.outputs[0].broadcastable)
            return [ret]
        else:
            raise TypeError("case not expected")
    elif isinstance(idx, slice):
        # it is a slice of ints and/or Variables
        # check subtensor to see if it can contain constant variables, and if
        # it can, then try to unpack them.
        try:
            const_slice = node.op.get_constant_idx(
                node.inputs, allow_partial=False
            )[0]
            ret = make_vector(*x.owner.inputs[const_slice])

            # Copy over stack trace from previous outputs to new output
            copy_stack_trace(node.outputs, ret)

            ret = patternbroadcast(ret, node.outputs[0].broadcastable)
            return [ret]
        except NotScalarConstantError:
            pass
    else:
        raise TypeError("case not expected")
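
# A hedged sketch (assumes a standard Aesara install): ``aet.stack`` of
# scalars builds a MakeVector, so constant indexing into it resolves at
# compile time,
#
#   import aesara.tensor as aet
#
#   a, b, c = aet.scalars("a", "b", "c")
#   out = aet.stack([a, b, c])[1]
#   # ``out`` is rewritten to just ``b``; a constant slice such as [0:2]
#   # becomes MakeVector(a, b). Lookups like ``x.shape[i]`` simplify the
#   # same way once the shape graph is expressed as a MakeVector.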


@register_useless
@register_canonicalize
@register_specialize
@local_optimizer([IncSubtensor])
def local_useless_inc_subtensor(fgraph, node):
    """
    Remove IncSubtensor, when we overwrite the full inputs with the
    new value.

    """
    if not isinstance(node.op, IncSubtensor):
        return
    if node.op.set_instead_of_inc is False:
        # This is an IncSubtensor, so the init value must be zeros
        try:
            c = get_scalar_constant_value(
                node.inputs[0], only_process_constants=True
            )
            if c != 0:
                return
        except NotScalarConstantError:
            return
    if (
        node.inputs[0].ndim != node.inputs[1].ndim
        or node.inputs[0].broadcastable != node.inputs[1].broadcastable
    ):
        # FB: I didn't check if this case can happen, but this opt
        # doesn't support it.
        return
    # We have a SetSubtensor or an IncSubtensor on zeros.
    # Is this IncSubtensor useful?

    # Check that we keep all the original data.
    # Put the constant inputs in the slice.
    idx_cst = get_idx_list(node.inputs[1:], node.op.idx_list)
    if all(
        isinstance(e, slice)
        and e.start is None
        and e.stop is None
        and (
            e.step is None
            or extract_constant(e.step, only_process_constants=True) == -1
        )
        for e in idx_cst
    ):
        # IncSubtensor broadcasts node.inputs[1] on node.inputs[0]
        # based on run time shapes, so we must check they are the same.
        if not hasattr(fgraph, "shape_feature"):
            return
        if not fgraph.shape_feature.same_shape(node.inputs[0], node.inputs[1]):
            return

        # There is no reverse, so we don't need a replacement.
        if all(e.step is None for e in node.op.idx_list):
            # They are the same shape, so we can remove this IncSubtensor
            return [node.inputs[1]]
        ret = Subtensor(node.op.idx_list)(*node.inputs[1:])
        # Copy over previous output stacktrace
        copy_stack_trace(node.outputs, ret)
        return [ret]
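
# A hedged sketch (assumes a standard Aesara install; names are
# illustrative):
#
#   import aesara.tensor as aet
#
#   x = aet.matrix("x")
#   y = aet.matrix("y")
#   out = aet.set_subtensor(x[:], y)
#   # When the shape feature can prove x and y have the same shape, ``out``
#   # is just ``y``; an overwrite through x[::-1] instead becomes the
#   # reversing Subtensor applied to ``y``.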


@register_canonicalize
@register_specialize
@local_optimizer([AdvancedIncSubtensor1])
def local_set_to_inc_subtensor(fgraph, node):
    r"""
    AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) ->
    AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False)

    TODO FIXME: Why doesn't this apply to all `*IncSubtensor*` `Op`\s? If it
    did, this wouldn't need to also be included in the "specialize" pass.

    """
    if (
        isinstance(node.op, AdvancedIncSubtensor1)
        and node.op.set_instead_of_inc
        and node.inputs[1].owner
        and isinstance(node.inputs[1].owner.op, Elemwise)
        and isinstance(node.inputs[1].owner.op.scalar_op, aes.Add)
    ):
        addn = node.inputs[1].owner
        subn = None
        other = None

        if addn.inputs[0].owner and isinstance(
            addn.inputs[0].owner.op, AdvancedSubtensor1
        ):
            subn = addn.inputs[0].owner
            other = addn.inputs[1]
        elif addn.inputs[1].owner and isinstance(
            addn.inputs[1].owner.op, AdvancedSubtensor1
        ):
            subn = addn.inputs[1].owner
            other = addn.inputs[0]
        else:
            return

        if subn.inputs[1] != node.inputs[2] or subn.inputs[0] != node.inputs[0]:
            return

        ret = advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2])

        copy_stack_trace(node.outputs, ret)

        return [ret]
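
# A hedged sketch (assumes a standard Aesara install; names are
# illustrative):
#
#   import aesara.tensor as aet
#
#   x = aet.matrix("x")
#   ilist = aet.lvector("ilist")
#   out = aet.set_subtensor(x[ilist], x[ilist] + 1.0)
#   # "Set to the old value plus an increment" is recognized as a plain
#   # increment: the equivalent of aet.inc_subtensor(x[ilist], 1.0).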


@register_canonicalize
@register_specialize
@local_optimizer([Subtensor, AdvancedSubtensor1])
def local_useless_subtensor(fgraph, node):
    """
    Remove Subtensor/AdvancedSubtensor1 if it takes the full input. In the
    AdvancedSubtensor1 case, the full input is taken when the indices are
    equivalent to `arange(0, input.shape[0], 1)` using either an explicit
    list/vector or the ARange op.

    """
    # This optimization needs ShapeOpt and fgraph.shape_feature
    if not hasattr(fgraph, "shape_feature"):
        return

    shape_of = fgraph.shape_feature.shape_of

    if isinstance(node.op, Subtensor):
        cdata = node.op.get_constant_idx(
            node.inputs, allow_partial=True, only_process_constants=True
        )
        for pos, idx in enumerate(cdata):
            if not isinstance(idx, slice):
                # If idx is not a slice, this means we remove this dimension
                # from the output, so the subtensor is not useless
                return False
            if idx.start is not None and idx.start != 0:
                # If the start of the slice is different from 0, or is a
                # variable, then we assume the subtensor is not useless
                return False
            if idx.step is not None and idx.step != 1:
                # If we are going backwards, or skipping elements, then this
                # is not a useless subtensor
                return False

        for pos, idx in enumerate(cdata):
            length_pos = shape_of[node.inputs[0]][pos]

            if isinstance(idx.stop, (int, np.integer)):
                length_pos_data = sys.maxsize
                try:
                    length_pos_data = get_scalar_constant_value(
                        length_pos, only_process_constants=True
                    )
                except NotScalarConstantError:
                    pass

                if idx.stop < length_pos_data:
                    return False
            elif isinstance(idx.stop, Variable):
                length_pos_shape_i = idx.stop
                # length_pos is a tensor variable, but length_pos_shape_i
                # is a scalar variable. We try to see if they represent
                # the same underlying variable.
                if length_pos_shape_i.owner and isinstance(
                    length_pos_shape_i.owner.op, ScalarFromTensor
                ):
                    length_pos_shape_i = length_pos_shape_i.owner.inputs[0]
                elif length_pos.owner and isinstance(
                    length_pos.owner.op, TensorFromScalar
                ):
                    length_pos = length_pos.owner.inputs[0]
                else:
                    # We did not find underlying variables of the same type
                    return False

                # The type can be different: int32 vs int64. length_pos
                # should always be int64 as that is what the shape
                # tracker keeps. Subtensor accepts any scalar
                # int{8,16,32,64} as index type.
                assert str(length_pos.type.dtype) == "int64"
                assert str(length_pos_shape_i.type.dtype) in [
                    "int8",
                    "int16",
                    "int32",
                    "int64",
                ]

                # length_pos_shape_i cannot be None
                if length_pos_shape_i != length_pos:
                    return False
            elif idx.stop is None:
                pass
            else:
                return False
    elif isinstance(node.op, AdvancedSubtensor1):
        # get length of the indexed tensor along the first axis
        try:
            length = get_scalar_constant_value(
                shape_of[node.inputs[0]][0], only_process_constants=True
            )
        except NotScalarConstantError:
            return False

        # get index (which must be a vector by definition)
        idx = node.inputs[1]

        # `idx` must be equivalent to [0,1,...,shape[0] - 1] to qualify for
        # this optimization
        if isinstance(idx, Constant):
            idx = idx.value
            if len(idx) != length:
                return False
            if np.any(idx != np.arange(length)):
                return False
        elif idx.owner is not None and isinstance(idx.owner.op, ARange):
            try:
                start, stop, step = map(
                    lambda x: get_scalar_constant_value(
                        x, only_process_constants=True
                    ),
                    idx.owner.inputs,
                )
            except NotScalarConstantError:
                return False

            if start != 0:
                return False
            if stop != length:
                return False
            if step != 1:
                return False
        else:
            return False
    else:
        return False

    # We don't need to copy over any stacktrace here,
    # because previous stacktrace should suffice.
    return [node.inputs[0]]
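
# A hedged sketch (assumes a standard Aesara install; names are
# illustrative):
#
#   import aesara.tensor as aet
#
#   x = aet.matrix("x")
#   out1 = x[:]               # full default slice
#   out2 = x[0 : x.shape[0]]  # stop is the symbolic length itself
#   # Both are rewritten to ``x``. The AdvancedSubtensor1 case fires when
#   # the index vector is a constant [0, 1, ..., n-1] or an equivalent
#   # ARange and the length of ``x`` is statically known.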


def merge_two_slices(fgraph, slice1, len1, slice2, len2):
    """
    This function merges two slices into a single slice. The code works on
    the assumption that:

    a) slice1 is actually a slice and not an index, while slice2
       can be just an index.

    b) the two slices **have been applied consecutively** on the same
       tensor

    The output slice is **not** in canonical form, but actually just a slice
    that can be applied to a tensor to produce the same output as applying
    the two consecutive slices.

    ``len1`` is the length of the tensor **before** applying the first slice,
    while ``len2`` is the length **after** applying the first slice.
    """
    if not isinstance(slice1, slice):
        raise ValueError(
            "First provided slice should actually be of type slice and not an index!",
            slice1,
        )

    sl1, reverse1 = get_canonical_form_slice(slice1, len1)
    sl2, reverse2 = get_canonical_form_slice(slice2, len2)

    if not isinstance(sl2, slice):
        if reverse1 is None:
            # The first slice is not in reverse, which makes things a lot
            # more clear.
            # In this case we need to take care only of the special cases:
            # len2 <= 0   -> throw index error regardless of sl2
            # sl2 > len2  -> throw index error
            # sl2 < -len2 -> throw index error
            # To get an index error we simply use len1+1 to indicate we are
            # out of bounds, because passing this index through the formula
            # of getting the mixed slice is not guaranteed to result in an
            # index error. The **issue though** is that the error will
            # complain about accessing element len1+1, which is probably not
            # too intuitive for the user
            val = sl1.start + sl2 * sl1.step
            val = switch(le(len2, 0), len1 + 1, val)
            val = switch(ge(sl2, len2), len1 + 1, val)
            val = switch(lt(sl2, 0), -len1 - 1, val)
            if sl1.step:
                val = switch(eq(sl1.step, 0), len1 + 1, val)
            return val
        else:
            # We are in the more complex case when we do not actually know
            # if the first slice was in reverse or not.
            # in case it was not in reverse:
            p_val = sl1.start + sl2 * sl1.step
            # case it was in reverse we need to realize that we do not want
            # the k-th element from sl.start but the k-th element from
            # sl.stop backwards
            n_val = sl1.stop - 1 - sl2 * sl1.step
            # we need to pick either n_val or p_val and then follow same
            # steps as above for covering the index error cases
            val = switch(lt(reverse1, 0), n_val, p_val)
            val = switch(le(len2, 0), len1 + 1, val)
            val = switch(ge(sl2, len2), len1 + 1, val)
            val = switch(lt(sl2, 0), -len1 - 1, val)
            if sl1.step:
                val = switch(eq(sl1.step, 0), len1 + 1, val)
            return val
    else:
        # We are dealing with two slices that need to be put together
        # according to the two steps. We have 4 different combinations of
        # positive/negative. I will denote the case I'm looking at by
        # suffixes to the variables (nn, np, pn, pp):
        flen = sl2.stop - sl2.start
        p_step = sl1.step * sl2.step
        n_step = sl1.step * sl2.step * -1

        pp_start = minimum(sl1.start + sl2.start * sl1.step, sl1.stop)
        pp_stop = minimum(sl1.start + sl2.stop * sl1.step, sl1.stop)

        pn_stop = sl1.start + (sl2.start - 1) * sl1.step
        pn_stop = switch(
            and_(lt(pn_stop, 0), gt(flen, 0)),
            -len1 - 1,
            minimum(pn_stop, sl1.stop),
        )
        pn_start = sl1.start + (sl2.stop - 1) * sl1.step
        pn_start = minimum(pn_start, sl1.stop)
        pn_start = maximum(pn_start, 0)

        np_stop = sl1.stop - sl2.stop * sl1.step - 1
        np_stop = switch(
            and_(lt(np_stop, 0), gt(flen, 0)),
            -len1 - 1,
            maximum(sl1.start - 1, np_stop),
        )
        np_start = maximum(sl1.start, sl1.stop - sl2.start * sl1.step - 1)

        nn_start = maximum(sl1.start, (sl1.stop - 1) - (sl2.stop - 1) * sl1.step)
        nn_stop = maximum(sl1.start, sl1.stop - sl2.start * sl1.step)

        start = switch(
            lt(reverse2 * reverse1, 0),
            switch(lt(reverse1, 0), np_start, pn_start),
            switch(lt(reverse1, 0), nn_start, pp_start),
        )

        stop = switch(
            lt(reverse2 * reverse1, 0),
            switch(lt(reverse1, 0), np_stop, pn_stop),
            switch(lt(reverse1, 0), nn_stop, pp_stop),
        )

        step = switch(lt(reverse2 * reverse1, 0), n_step, p_step)
        start = switch(le(flen, 0), 0, start)
        stop = switch(le(flen, 0), 0, stop)

        return slice(start, stop, step)
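
# A concrete check of the semantics this function encodes (plain NumPy, not
# part of the library):
#
#   import numpy as np
#
#   a = np.arange(10)
#   assert (a[1:9][2:5] == a[3:6]).all()        # both give [3, 4, 5]
#   assert (a[8:0:-2][1:3] == a[6:2:-2]).all()  # both give [6, 4]
#
# The second case is why the code above tracks the ``reverse1``/``reverse2``
# signs: composing a negative-step slice flips which end the start and stop
# are measured from.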


@register_canonicalize
@local_optimizer([add])
def local_IncSubtensor_serialize(fgraph, node):
    """
    When using Subtensor, gradient graphs can be ugly.

    If we ask for grad(f(a[0]), a), we are going to get something like

        IncSubtensor(Elemwise{second}(a, 0), g(f(a[0])), [0])

    This might be ugly, but at least it's as fast as you could want.
    If we ask for grad(f(a[0], a[1], a[2]), a), it's much worse...

        Elemwise{Add}
            IncSubtensor(Elemwise{second}(a, 0), g(f(a[0])), [0])
            IncSubtensor(Elemwise{second}(a, 0), g(f(a[1])), [1])
            IncSubtensor(Elemwise{second}(a, 0), g(f(a[2])), [2])

    This is much worse because this time we have to produce 3 matrices
    the size of 'a', just so we can add them together.

    This Op rearranges IncSubtensor's that all work on the same
    initial argument (here, Elemwise{second}(a,0)) into a chain. The
    advantage of the chain structure is that each one can be optimized
    later in the pipeline to operate inplace.

    Ideally, the op will do something like this:

        add(x, incsubtensor(b, c), incsubtensor(b, d))
        -> incsubtensor(incsubtensor(add(x, b, b), c), d)

    """

    def movable(i):
        # Return True iff this is an incsubtensor that we can move
        return (
            i.owner
            and isinstance(
                i.owner.op,
                (
                    IncSubtensor,
                    AdvancedIncSubtensor1,
                    AdvancedIncSubtensor,
                ),
            )
            and i.type == o_type
            and len(fgraph.clients[i]) == 1
            and not i.owner.op.set_instead_of_inc
        )

    if node.op == add:
        o_type = node.outputs[0].type

        movable_inputs = [i for i in node.inputs if movable(i)]

        if movable_inputs:
            new_inputs = [i for i in node.inputs if not movable(i)] + [
                mi.owner.inputs[0] for mi in movable_inputs
            ]
            # `new_inputs` is never empty here (each movable input
            # contributes its base tensor), so only skip the `add` when a
            # single input remains.
            if len(new_inputs) == 1:
                new_add = new_inputs[0]
            else:
                new_add = add(*new_inputs)

                # Copy over stacktrace from original output, as an error
                # (e.g. an index error) in this add operation should
                # correspond to an error in the original add operation.
                copy_stack_trace(node.outputs[0], new_add)

            # stack up the new incsubtensors
            tip = new_add
            for mi in movable_inputs:
                assert tip.type == o_type
                assert tip.type == mi.owner.inputs[0].type
                tip = mi.owner.op(tip, *mi.owner.inputs[1:])
                # Copy over stacktrace from outputs of the original
                # "movable" operation to the new operation.
                copy_stack_trace(node.outputs + mi.owner.outputs, tip)

            return [tip]

    # print incsub_inputs, [id(i.owner.inputs[0]) for i in incsub_inputs]


# We register it in a TopoOptimizer inside the canonizer EQ optimizer.
# Otherwise, in some cases, it made the EQ optimizer use 45 passes; in
# the TopoOptimizer, the EQ only uses 5 passes.
compile.optdb.register(
    "pre_local_IncSubtensor_serialize",
    in2out(local_IncSubtensor_serialize),
    # Just before canonizer
    0.99,
    "fast_run",
)
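
# A hedged sketch of the motivating case (assumes a standard Aesara
# install):
#
#   import aesara
#   import aesara.tensor as aet
#
#   a = aet.vector("a")
#   g = aesara.grad(a[0] ** 2 + a[1] ** 2 + a[2] ** 2, a)
#   # Without serialization, ``g`` sums three vectors the size of ``a``,
#   # one per IncSubtensor; after it, the increments form a chain that
#   # later inplace optimizations can apply to a single buffer.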


# after priority 50 Destructive inplace operations
# gemm is the first one now, at priority 70
@local_optimizer([IncSubtensor], inplace=True)
def local_inplace_setsubtensor(fgraph, node):
    if isinstance(node.op, IncSubtensor) and not node.op.inplace:
        dta = node.op.destroyhandler_tolerate_aliased
        new_op = node.op.__class__(
            node.op.idx_list,
            inplace=True,
            set_instead_of_inc=node.op.set_instead_of_inc,
            destroyhandler_tolerate_aliased=dta,
        )
        new_node = new_op(*node.inputs)
        val = getattr(node.outputs[0].tag, "nan_guard_mode_check", True)
        new_node.tag.nan_guard_mode_check = val

        # Copy stacktrace from original outputs to new outputs.
        # This is sensible, because the new operation is the
        # same as the old one, but now with different attributes.
        copy_stack_trace(node.outputs, new_node)
        return [new_node]
    return False


compile.optdb.register(
    "local_inplace_setsubtensor",
    TopoOptimizer(
        local_inplace_setsubtensor, failure_callback=TopoOptimizer.warn_inplace
    ),
    60,
    "fast_run",
    "inplace",
)


@local_optimizer([AdvancedIncSubtensor1], inplace=True)
def local_inplace_AdvancedIncSubtensor1(fgraph, node):
    if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace:
        new_op = node.op.clone_inplace()
        new_node = new_op(*node.inputs)
        copy_stack_trace(node.outputs, new_node)
        return [new_node]
    return False


compile.optdb.register(
    "local_inplace_AdvancedIncSubtensor1",
    TopoOptimizer(
        local_inplace_AdvancedIncSubtensor1,
        failure_callback=TopoOptimizer.warn_inplace,
    ),
    60,
    "fast_run",
    "inplace",
)


@local_optimizer([AdvancedIncSubtensor], inplace=True)
def local_inplace_AdvancedIncSubtensor(fgraph, node):
    if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace:
        new_op = type(node.op)(
            inplace=True, set_instead_of_inc=node.op.set_instead_of_inc
        )
        new_node = new_op(*node.inputs)
        copy_stack_trace(node.outputs, new_node)
        return [new_node]
    return False


compile.optdb.register(
    "local_inplace_AdvancedIncSubtensor",
    TopoOptimizer(
        local_inplace_AdvancedIncSubtensor,
        failure_callback=TopoOptimizer.warn_inplace,
    ),
    60,
    "fast_run",
    "inplace",
)


# Register old name
@register_canonicalize("local_incsubtensor_of_allocs")
@register_stabilize("local_incsubtensor_of_allocs")
@local_optimizer([IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1])
def local_incsubtensor_of_zeros(fgraph, node):
    """
    IncSubtensor(x, zeros, idx) -> x

    """
    if (
        isinstance(
            node.op, (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)
        )
        and not node.op.set_instead_of_inc
    ):
        x = node.inputs[0]
        y = node.inputs[1]
        try:
            # Don't use only_process_constants=True. We need to
            # investigate Alloc of 0s but with non constant shape.
            if get_scalar_constant_value(y, elemwise=False) == 0:
                # No need to copy over the stacktrace,
                # because x should already have a stacktrace
                return [x]
        except NotScalarConstantError:
            return
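
# A hedged sketch (assumes a standard Aesara install; whether the rewrite
# fires depends on the increment being recognizable as a constant 0, e.g.
# an Alloc of zeros):
#
#   import aesara.tensor as aet
#
#   x = aet.matrix("x")
#   out = aet.inc_subtensor(x[1:], aet.zeros((2, 3)))
#   # Adding zeros changes nothing, so ``out`` collapses to ``x``.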


@register_canonicalize
@register_specialize
@local_optimizer([IncSubtensor])
def local_incsubtensor_of_zeros_to_setsubtensor(fgraph, node):
    """
    IncSubtensor(zeros, x, ...) -> SetSubtensor(zeros, x, ...)

    """
    if isinstance(node.op, (IncSubtensor)) and not node.op.set_instead_of_inc:
        x = node.inputs[0]

        if isinstance(x, Constant) and not np.any(x.data):
            return [
                IncSubtensor(
                    node.op.idx_list,
                    node.op.inplace,
                    set_instead_of_inc=True,
                    destroyhandler_tolerate_aliased=node.op.destroyhandler_tolerate_aliased,
                )(*node.inputs)
            ]


@register_canonicalize("local_setsubtensor_of_allocs")
@register_stabilize("local_setsubtensor_of_allocs")
@local_optimizer([IncSubtensor])
def local_setsubtensor_of_constants(fgraph, node):
    """
    SetSubtensor(x, x[idx], idx) -> x

    when x is constant or alloc.

    """
    if isinstance(node.op, IncSubtensor) and node.op.set_instead_of_inc:
        x = node.inputs[0]
        y = node.inputs[1]

        # Don't use only_process_constants=True. We need to
        # investigate Alloc of 0s but with non constant shape.
        try:
            replace_x = get_scalar_constant_value(x, elemwise=False)
        except NotScalarConstantError:
            return

        try:
            replace_y = get_scalar_constant_value(y, elemwise=False)
        except NotScalarConstantError:
            return

        if replace_x == replace_y:
            # No need to copy over the stacktrace,
            # because x should already have a stacktrace
            return [x]
        else:
            return False


@register_canonicalize
@register_specialize
@local_optimizer([AdvancedSubtensor1])
def local_adv_sub1_adv_inc_sub1(fgraph, node):
    """Optimize the possible AdvSub1(AdvSetSub1(...), ...).

    AdvancedSubtensor1(AdvancedSetSubtensor1(x, y, idx), idx) -> y

    Notes
    -----
    This opt adds an AssertOp. Otherwise, it would remove shape and
    index errors. If you want to get rid of them, see the
    :ref:`unsafe_optimization` section.

    WARNING:
    A previous version of this optimization also matched
    AdvancedSubtensor1(AdvancedIncSubtensor1(0s, y, idx), idx) -> y
    This is incorrect when there are duplicate indices.
    The current version warns the user about potential past issues.

    """
    if not isinstance(node.op, AdvancedSubtensor1):
        return
    inp = node.inputs[0]
    if not inp.owner or not isinstance(inp.owner.op, AdvancedIncSubtensor1):
        return
    idx = node.inputs[1]
    idx2 = inp.owner.inputs[2]
    x = inp.owner.inputs[0]
    y = inp.owner.inputs[1]
    if idx is not idx2:
        return
    if (
        not inp.owner.op.set_instead_of_inc
        and
        # Don't use only_process_constants=True. We need to
        # investigate Alloc of 0s but with non constant shape.
        extract_constant(x, elemwise=False) != 0
    ):
        return

    if not inp.owner.op.set_instead_of_inc:
        return

    cond = [aet_all(and_(lt(idx, x.shape[0]), ge(idx, -x.shape[0])))]
    if not fgraph.shape_feature.same_shape(idx, y, 0, 0):
        cond.append(eq(idx.shape[0], y.shape[0]))
    r = Assert(
        "Bad indexing or shapes in an AdvancedIncSubtensor1 that was optimized away"
    )(y, *cond)
    copy_stack_trace(y, r)

    if r.dtype == node.outputs[0].dtype:
        return [r]
    # It is possible that y is upcast or downcast to x.dtype.
    # In all cases, as we set or add with 0, we can just cast y.
    r2 = cast(r, node.outputs[0].dtype)

    # Copy over stacktrace from before casting, since
    # we don't expect problems in the casting operation,
    # and any problems in the indexing would have been spotted above.
    copy_stack_trace(r, r2)
    return [r2]
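
# A hedged sketch (assumes a standard Aesara install; names are
# illustrative):
#
#   import aesara.tensor as aet
#
#   x = aet.matrix("x")
#   y = aet.matrix("y")
#   idx = aet.lvector("idx")
#   out = aet.set_subtensor(x[idx], y)[idx]
#   # ``out`` becomes ``y`` wrapped in an Assert that the indices are in
#   # bounds (and, when it can't be proven statically, that ``idx`` and
#   # ``y`` have matching leading lengths).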


@register_specialize
@register_stabilize
@register_canonicalize
@register_useless
@local_optimizer([IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1])
def local_useless_inc_subtensor_alloc(fgraph, node):
    """
    Replaces an [Advanced]IncSubtensor[1], whose increment is an `alloc` of
    a fully or partially broadcastable variable, by one that skips the
    intermediate `alloc` where possible.

    """
    if isinstance(
        node.op, (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)
    ):
        x = node.inputs[0]
        y = node.inputs[1]
        i = node.inputs[2:]

        if y.owner is not None and isinstance(y.owner.op, Alloc):
            # `z` is the input of the Alloc op, i.e. aet.alloc(z, <shape>)
            z = y.owner.inputs[0]

            try:
                shape_feature = fgraph.shape_feature
            except AttributeError:
                # The shape feature may not be available in some mode, but we
                # need it for this optimization, so don't continue.
                return False

            shape_of = shape_feature.shape_of
            same_shape = shape_feature.same_shape

            # Get the subtensor of `x` indexed by `i` in order to compare
            # shapes later.
            if isinstance(node.op, IncSubtensor):
                xi = Subtensor(node.op.idx_list)(x, *i)
            elif isinstance(node.op, AdvancedIncSubtensor):
                xi = advanced_subtensor(x, *i)
            elif isinstance(node.op, AdvancedIncSubtensor1):
                xi = advanced_subtensor1(x, *i)
            else:
                raise Exception("Should never happen!")

            reason = "local_useless_incsubtensor_alloc"

            # Add `xi` to the shape feature `fgraph`. This is important for
            # shape inference later because the variable must be part of the
            # function graph in order to call `same_shape` on it.
            if xi not in shape_of:
                shape_feature.on_import(fgraph, xi.owner, f"{reason}: add `xi`")

            # `xi` may have more dimensions than `y` since the subtensor ops
            # do automatic broadcasting of the increment internally. Thus, we
            # need to make the leading implicitly broadcasted dimensions
            # explicit for shape comparison later.
            if xi.ndim > y.ndim:
                y = shape_padleft(y, xi.ndim - y.ndim)
                if y not in shape_of:
                    shape_feature.on_import(fgraph, y.owner, f"{reason}: add `y`")

            # Build `z_broad` explicitly to include extra implicit dimensions.
            z_broad = (True,) * (xi.ndim - z.ndim) + z.broadcastable

            cond = [
                # The shapes of `y` and `xi` must either agree or `y` may
                # also have shape equal to 1 which may be treated as a
                # broadcastable dimension by the subtensor op.
                or_(eq(y.shape[k], 1), eq(y.shape[k], xi.shape[k]))
                # Loop over all dimensions.
                for k in range(xi.ndim)
                # We need to check the above shapes, if
                # * the pre-alloc increment `z` is broadcastable in
                #   dimension `k` (if it isn't, then the shapes of `z` and
                #   `y` are the same by the definition of the `Alloc` op in
                #   this dimension and replacing `y` by `z` will not hide a
                #   shape error), and
                # * `xi` and `y` do not have the same shape in dimension
                #   `k` or we cannot infer the shape statically (if the
                #   shapes of `xi` and `y` are not the same, then replacing
                #   `y` by `z` will hide the shape error of `y`), and
                # * the shape of `y` is not equal to 1 or we cannot infer
                #   the shape statically (if the shape of `y` is equal to
                #   1, then `y` is broadcasted by the inc_subtensor op
                #   internally, so the shapes of `xi` and `y` do not need
                #   to match in dimension `k`; else we need to check at
                #   runtime that the shape of `y` is either 1 or the same
                #   as `xi` or otherwise replacing `y` by `z` will hide a
                #   shape error).
                if (
                    z_broad[k]
                    and not same_shape(xi, y, dim_x=k, dim_y=k)
                    and shape_of[y][k] != 1
                )
            ]

            if len(cond) > 0:
                msg = "`x[i]` and `y` do not have the same shape."
                z = Assert(msg)(z, *cond)

            r = node.op(x, z, *i)
            # Copy over stacktrace from previous output, since
            # we don't expect problems when removing the intermediate
            # alloc operation and so we still want to point at the line
            # of the inc_subtensor operation.
            copy_stack_trace(node.outputs, r)

            return [r]
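
# A hedged sketch (assumes a standard Aesara install; names are
# illustrative):
#
#   import aesara.tensor as aet
#
#   x = aet.matrix("x")
#   z = aet.scalar("z")
#   out = aet.inc_subtensor(x[:, 1], aet.alloc(z, x.shape[0]))
#   # The intermediate alloc is skipped: ``out`` becomes the equivalent of
#   # aet.inc_subtensor(x[:, 1], z), relying on the op's internal
#   # broadcasting, with Asserts added when shape equality can't be proven.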