Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ccbab653
提交
ccbab653
authored
12月 08, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 09, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move subtensor lift rewrites to their own module
上级
5fa5c9ba
隐藏空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
920 行增加
和
847 行删除
+920
-847
__init__.py
pytensor/tensor/rewriting/__init__.py
+1
-0
subtensor.py
pytensor/tensor/rewriting/subtensor.py
+0
-381
subtensor_lift.py
pytensor/tensor/rewriting/subtensor_lift.py
+411
-0
test_subtensor.py
tests/tensor/rewriting/test_subtensor.py
+5
-466
test_subtensor_lift.py
tests/tensor/rewriting/test_subtensor_lift.py
+503
-0
没有找到文件。
pytensor/tensor/rewriting/__init__.py
浏览文件 @
ccbab653
...
@@ -15,4 +15,5 @@ import pytensor.tensor.rewriting.ofg
...
@@ -15,4 +15,5 @@ import pytensor.tensor.rewriting.ofg
import
pytensor.tensor.rewriting.shape
import
pytensor.tensor.rewriting.shape
import
pytensor.tensor.rewriting.special
import
pytensor.tensor.rewriting.special
import
pytensor.tensor.rewriting.subtensor
import
pytensor.tensor.rewriting.subtensor
import
pytensor.tensor.rewriting.subtensor_lift
import
pytensor.tensor.rewriting.uncanonicalize
import
pytensor.tensor.rewriting.uncanonicalize
pytensor/tensor/rewriting/subtensor.py
浏览文件 @
ccbab653
import
itertools
import
itertools
import
sys
import
sys
from
collections.abc
import
Iterable
import
numpy
as
np
import
numpy
as
np
...
@@ -21,11 +20,9 @@ from pytensor.scalar import constant as scalar_constant
...
@@ -21,11 +20,9 @@ from pytensor.scalar import constant as scalar_constant
from
pytensor.tensor.basic
import
(
from
pytensor.tensor.basic
import
(
Alloc
,
Alloc
,
Join
,
Join
,
MakeVector
,
ScalarFromTensor
,
ScalarFromTensor
,
TensorFromScalar
,
TensorFromScalar
,
alloc
,
alloc
,
as_tensor
,
cast
,
cast
,
concatenate
,
concatenate
,
get_scalar_constant_value
,
get_scalar_constant_value
,
...
@@ -38,11 +35,8 @@ from pytensor.tensor.blockwise import Blockwise
...
@@ -38,11 +35,8 @@ from pytensor.tensor.blockwise import Blockwise
from
pytensor.tensor.elemwise
import
Elemwise
from
pytensor.tensor.elemwise
import
Elemwise
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.math
import
(
from
pytensor.tensor.math
import
(
Dot
,
add
,
add
,
and_
,
and_
,
ceil_intdiv
,
dot
,
eq
,
eq
,
ge
,
ge
,
gt
,
gt
,
...
@@ -60,11 +54,8 @@ from pytensor.tensor.rewriting.basic import (
...
@@ -60,11 +54,8 @@ from pytensor.tensor.rewriting.basic import (
register_stabilize
,
register_stabilize
,
)
)
from
pytensor.tensor.shape
import
(
from
pytensor.tensor.shape
import
(
Shape
,
SpecifyShape
,
shape_padleft
,
shape_padleft
,
shape_tuple
,
shape_tuple
,
specify_shape
,
)
)
from
pytensor.tensor.sharedvar
import
TensorSharedVariable
from
pytensor.tensor.sharedvar
import
TensorSharedVariable
from
pytensor.tensor.subtensor
import
(
from
pytensor.tensor.subtensor
import
(
...
@@ -78,7 +69,6 @@ from pytensor.tensor.subtensor import (
...
@@ -78,7 +69,6 @@ from pytensor.tensor.subtensor import (
advanced_subtensor
,
advanced_subtensor
,
advanced_subtensor1
,
advanced_subtensor1
,
as_index_constant
,
as_index_constant
,
as_index_literal
,
get_canonical_form_slice
,
get_canonical_form_slice
,
get_constant_idx
,
get_constant_idx
,
get_idx_list
,
get_idx_list
,
...
@@ -277,64 +267,6 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
...
@@ -277,64 +267,6 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
return
[
new_res
]
return
[
new_res
]
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter
([
Subtensor
])
def
local_subtensor_of_dot
(
fgraph
,
node
):
"""Rewrite ``at.dot(A, B)[idxs]`` into ``at.dot(A[idxs_a], B[idxs_b])``.
``idxs_a`` is the first ``A.ndim-1`` entries of ``idxs``, and ``idxs_b`` is
the remaining entries of ``idxs`` (if any), modified to skip the
second-to-last dimension of ``B`` (because dot sums over this dimension).
"""
if
not
isinstance
(
node
.
op
,
Subtensor
):
return
if
not
(
node
.
inputs
[
0
]
.
owner
and
isinstance
(
node
.
inputs
[
0
]
.
owner
.
op
,
Dot
)):
return
# If there is other node that use the outputs of the dot
# We don't want to compute twice the sub part.
if
len
(
fgraph
.
clients
[
node
.
inputs
[
0
]])
>
1
:
return
a
=
node
.
inputs
[
0
]
.
owner
.
inputs
[
0
]
b
=
node
.
inputs
[
0
]
.
owner
.
inputs
[
1
]
idx_list
=
get_idx_list
(
node
.
inputs
,
node
.
op
.
idx_list
)
num_a_indices
=
min
(
a
.
ndim
-
1
,
len
(
idx_list
))
a_indices
=
idx_list
[:
num_a_indices
]
b_indices
=
idx_list
[
num_a_indices
:]
# This is necessary because np.dot sums the last index of a with the second to last of b
# so we want to skip the second-to-last index into b.
# This wasn't necessary for a, because we just omitted the last index.
# We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
# (dot also handles b.ndim < 2 as a special case)
if
b
.
ndim
>
1
and
len
(
b_indices
)
>=
b
.
ndim
-
1
:
b_indices
=
(
b_indices
[:
b
.
ndim
-
2
]
+
(
slice
(
None
,
None
,
None
),)
+
b_indices
[
b
.
ndim
-
2
:]
)
a_sub
=
a
.
__getitem__
(
tuple
(
a_indices
))
b_sub
=
b
.
__getitem__
(
tuple
(
b_indices
))
if
b_indices
else
b
# Copy over previous output stacktrace to a_sub and b_sub,
# because an error in the subtensor operation (e.g. an index error)
# on either a or b must correspond to an error in the
# subtensor operation on their dot product.
copy_stack_trace
(
node
.
outputs
[
0
],
[
a_sub
,
b_sub
])
# Copy over previous output stacktrace and previous dot product stacktrace,
# because an error here may correspond to an either in either the original
# dot product, or in the dot product after the subtensor operation.
r
=
dot
(
a_sub
,
b_sub
)
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
r
)
return
[
r
]
@register_infer_shape
@register_infer_shape
@register_useless
@register_useless
@register_canonicalize
@register_canonicalize
...
@@ -420,75 +352,6 @@ def local_useless_slice(fgraph, node):
...
@@ -420,75 +352,6 @@ def local_useless_slice(fgraph, node):
return
[
out
]
return
[
out
]
# fast_compile to allow opt subtensor(cast{float32}(make_vector))
@register_canonicalize
(
"fast_compile"
)
@node_rewriter
([
Subtensor
])
def
local_subtensor_lift
(
fgraph
,
node
):
"""
unary(x)[idx] -> unary(x[idx])#any broadcast pattern.
Handles the following unary ops:
elemwise(x,...)[idx] -> elemwise(x[idx],...)
when x,... are broadcasted scalar or not broadcasted at all
"""
if
isinstance
(
node
.
op
,
Subtensor
):
u
=
node
.
inputs
[
0
]
if
u
.
owner
is
None
or
len
(
fgraph
.
clients
[
u
])
>
1
:
return
False
if
isinstance
(
u
.
owner
.
op
,
Elemwise
)
and
len
(
u
.
owner
.
inputs
)
==
1
:
idx
=
node
.
inputs
[
1
:]
x_idx
=
node
.
op
(
u
.
owner
.
inputs
[
0
],
*
idx
)
# Copy over previous output stacktrace
copy_stack_trace
(
node
.
outputs
,
x_idx
)
ret
=
u
.
owner
.
op
(
x_idx
)
# Copy over previous output stacktrace
# and stacktrace from previous unary operation
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
ret
)
return
[
ret
]
if
isinstance
(
u
.
owner
.
op
,
Elemwise
):
new_inputs
=
[]
if
all
(
sum
(
i
.
type
.
broadcastable
)
==
0
for
i
in
u
.
owner
.
inputs
):
# There is no broadcastable in the inputs
idx
=
node
.
inputs
[
1
:]
new_inputs
=
[
node
.
op
(
i
,
*
idx
)
for
i
in
u
.
owner
.
inputs
]
# Copy over previous output stacktrace
copy_stack_trace
(
node
.
outputs
[
0
],
new_inputs
)
ret
=
u
.
owner
.
op
(
*
new_inputs
)
# Copy over previous output stacktrace
# and stacktrace from previous unary operation
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
ret
)
return
[
ret
]
elif
all
(
sum
(
i
.
type
.
broadcastable
)
in
[
i
.
ndim
,
0
]
for
i
in
u
.
owner
.
inputs
):
# There is no broadcastable in the inputs or it is scalar
idx
=
node
.
inputs
[
1
:]
new_inputs
=
[]
for
i
in
u
.
owner
.
inputs
:
if
sum
(
i
.
type
.
broadcastable
)
==
0
:
new_inputs
.
append
(
node
.
op
(
i
,
*
idx
))
else
:
# If the subtensor remove some dims, we must
# lower the number of dimensions of this scalar.
if
node
.
outputs
[
0
]
.
ndim
==
i
.
ndim
:
new_inputs
.
append
(
i
)
else
:
new_inputs
.
append
(
i
.
dimshuffle
([
"x"
]
*
node
.
outputs
[
0
]
.
ndim
)
)
# Copy over previous output stacktrace
copy_stack_trace
(
node
.
outputs
[
0
],
new_inputs
)
ret
=
u
.
owner
.
op
(
*
new_inputs
)
# Copy over previous output stacktrace
# and stacktrace from previous unary operation
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
ret
)
return
[
ret
]
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@node_rewriter
([
Subtensor
])
@node_rewriter
([
Subtensor
])
...
@@ -619,76 +482,6 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
...
@@ -619,76 +482,6 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
return
[
node
.
inputs
[
0
]
.
dimshuffle
(
tuple
(
remain_dim
))]
return
[
node
.
inputs
[
0
]
.
dimshuffle
(
tuple
(
remain_dim
))]
@register_infer_shape
@register_useless
@register_canonicalize
@register_specialize
@node_rewriter
([
Subtensor
])
def
local_subtensor_of_alloc
(
fgraph
,
node
):
"""
alloc(val)[x:y] -> alloc(val[...])
alloc(val)[x:y] -> alloc(val)
This can be seen as a lift, but it also reduce the number of computation/memory.
"""
if
not
isinstance
(
node
.
op
,
Subtensor
):
return
False
u
=
node
.
inputs
[
0
]
if
u
.
owner
is
None
:
return
False
if
not
isinstance
(
u
.
owner
.
op
,
Alloc
):
return
False
slices
=
get_idx_list
(
node
.
inputs
,
node
.
op
.
idx_list
)
val
=
u
.
owner
.
inputs
[
0
]
dims
=
u
.
owner
.
inputs
[
1
:]
assert
len
(
slices
)
<=
len
(
dims
)
# Number of dimensions added to val
n_added_dims
=
u
.
ndim
-
val
.
ndim
# Dimensions of the returned alloc
nw_dims
=
[]
# Slices to take from val
val_slices
=
[]
for
i
,
(
sl
,
dim
)
in
enumerate
(
zip
(
slices
,
dims
,
strict
=
False
)):
# If val was not copied over that dim,
# we need to take the appropriate subtensor on it.
if
i
>=
n_added_dims
:
# We check that the corresponding val dimensions was
# not a broadcasted dimensions.
if
(
val
.
type
.
ndim
>
(
i
-
n_added_dims
)
and
val
.
type
.
broadcastable
[
i
-
n_added_dims
]
):
val_slices
.
append
(
slice
(
None
))
else
:
val_slices
.
append
(
sl
)
csl
,
_
=
get_canonical_form_slice
(
sl
,
dim
)
if
type
(
csl
)
is
not
slice
:
# That dimension is removed.
pass
else
:
nw_dim
=
csl
.
stop
-
csl
.
start
if
csl
.
step
!=
1
:
# Do not add the ceil_intdiv() graphs in the graphs
# when this is not needed as it prevent detecting the
# correct broadcast pattern.
nw_dim
=
ceil_intdiv
(
nw_dim
,
csl
.
step
)
nw_dims
+=
[
nw_dim
]
nw_val
=
val
[
tuple
(
val_slices
)]
nw_dims
+=
dims
[
len
(
slices
)
:]
if
nw_val
.
ndim
>
len
(
nw_dims
):
return
False
rval
=
alloc
(
nw_val
,
*
nw_dims
)
if
not
isinstance
(
rval
,
list
|
tuple
):
rval
=
[
rval
]
return
rval
@register_specialize
@register_specialize
@register_canonicalize
@register_canonicalize
@node_rewriter
([
Subtensor
])
@node_rewriter
([
Subtensor
])
...
@@ -728,91 +521,6 @@ def local_subtensor_inc_subtensor(fgraph, node):
...
@@ -728,91 +521,6 @@ def local_subtensor_inc_subtensor(fgraph, node):
return
return
@register_infer_shape
@register_specialize
@register_canonicalize
(
"fast_compile"
)
@register_useless
@node_rewriter
([
Subtensor
,
AdvancedSubtensor1
])
def
local_subtensor_make_vector
(
fgraph
,
node
):
"""Perform ``*Subtensor*`` operations on ``MakeVector`` outputs when the indices are constant.
Replace all ``Subtensor`` and ``MakeVector`` cases like:
[a,b,c][0] -> a
[a,b,c][0:2] -> [a,b]
Replace all ``AdvancedSubtensor1`` and ``MakeVector`` cases like:
[a,b,c][[0,2]] -> [a,c]
We can do this for constant indexes.
.. note:
This optimization implicitly relies on shape optimizations.
TODO: This only applies to a single indexed dimension; we should have
something more general for constant ``*Subtensor*`` graphs (or perhaps
include this kind of work in the constant folding).
"""
if
not
isinstance
(
node
.
op
,
Subtensor
|
AdvancedSubtensor1
):
return
False
x
=
node
.
inputs
[
0
]
if
not
(
x
.
owner
and
isinstance
(
x
.
owner
.
op
,
MakeVector
)):
return
False
make_vector_op
=
x
.
owner
.
op
if
isinstance
(
node
.
op
,
Subtensor
):
idxs
=
node
.
op
.
idx_list
# Subtensor has no indexes, return make_vector
if
not
idxs
:
return
[
x
]
(
idx
,)
=
idxs
if
isinstance
(
idx
,
ScalarType
|
TensorType
):
old_idx
,
idx
=
idx
,
node
.
inputs
[
1
]
assert
idx
.
type
.
is_super
(
old_idx
)
elif
isinstance
(
node
.
op
,
AdvancedSubtensor1
):
idx
=
node
.
inputs
[
1
]
if
isinstance
(
idx
,
int
|
np
.
integer
):
return
[
x
.
owner
.
inputs
[
idx
]]
elif
isinstance
(
idx
,
Variable
):
if
idx
.
ndim
==
0
:
try
:
v
=
get_underlying_scalar_constant_value
(
idx
,
only_process_constants
=
True
)
try
:
ret
=
[
x
.
owner
.
inputs
[
v
]]
except
IndexError
:
raise
NotScalarConstantError
(
"Bad user graph!"
)
return
ret
except
NotScalarConstantError
:
pass
elif
idx
.
ndim
==
1
and
isinstance
(
idx
,
Constant
):
values
=
list
(
map
(
int
,
list
(
idx
.
value
)))
ret
=
make_vector_op
(
*
[
x
.
owner
.
inputs
[
v
]
for
v
in
values
])
copy_stack_trace
(
node
.
outputs
[
0
],
ret
)
return
[
ret
]
elif
isinstance
(
idx
,
slice
):
# The index is a slice. If it's a constant slice, we can perform the
# index operation here.
try
:
const_slice
=
get_constant_idx
(
node
.
op
.
idx_list
,
node
.
inputs
,
allow_partial
=
False
)[
0
]
ret
=
make_vector_op
(
*
x
.
owner
.
inputs
[
const_slice
])
copy_stack_trace
(
node
.
outputs
,
ret
)
return
[
ret
]
except
NotScalarConstantError
:
pass
@register_infer_shape
@register_infer_shape
@register_useless
@register_useless
@register_canonicalize
@register_canonicalize
...
@@ -1615,95 +1323,6 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
...
@@ -1615,95 +1323,6 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
return
[
r
]
return
[
r
]
@register_specialize
@register_canonicalize
@node_rewriter
([
Subtensor
])
def
local_subtensor_shape_constant
(
fgraph
,
node
):
r"""Simplify constant `Subtensor`\s on `Shape`\s dimensions that are known.
We want to convert graphs like
Subtensor{int64} [id A] ''
|Shape [id B] ''
| |<TensorType(float64, row)> [id C]
|ScalarConstant{0} [id D]
into
TensorConstant{1}
TODO: Something like `local_shape_to_shape_i` should be a general
canonicalization, and not a `ShapeFeature`-dependent rewrite. If that were
the case, we could change this to only operate on `Shape_i`\s.
Currently, we're not handling them because they should only appear when
`ShapeFeature` is present, and it will also simplify/remove them.
"""
if
not
isinstance
(
node
.
op
,
Subtensor
):
return
False
shape
=
node
.
inputs
[
0
]
if
not
(
shape
.
owner
and
isinstance
(
shape
.
owner
.
op
,
Shape
)):
return
False
shape_arg
=
shape
.
owner
.
inputs
[
0
]
(
idx
,)
=
get_idx_list
(
node
.
inputs
,
node
.
op
.
idx_list
)
try
:
idx_val
=
as_index_literal
(
idx
)
except
NotScalarConstantError
:
return
False
assert
idx_val
!=
np
.
newaxis
if
not
isinstance
(
shape_arg
.
type
,
TensorType
):
return
False
shape_parts
=
shape_arg
.
type
.
broadcastable
[
idx_val
]
if
isinstance
(
shape_parts
,
Iterable
):
if
all
(
shape_parts
):
return
[
as_tensor
([
1
]
*
len
(
shape_parts
),
dtype
=
np
.
int64
,
ndim
=
1
)]
elif
shape_parts
:
return
[
as_tensor
(
1
,
dtype
=
np
.
int64
)]
@register_canonicalize
@node_rewriter
([
Subtensor
])
def
local_subtensor_SpecifyShape_lift
(
fgraph
,
node
):
"""Lift ``specify_shape(x, s)[i_1, ..., i_n]`` to ``specify_shape(x[i1, ... , i_n], s[n:])``."""
if
not
isinstance
(
node
.
op
,
Subtensor
):
return
False
specify_shape_node
=
node
.
inputs
[
0
]
if
not
(
specify_shape_node
.
owner
and
isinstance
(
specify_shape_node
.
owner
.
op
,
SpecifyShape
)
):
return
False
obj_arg
=
specify_shape_node
.
owner
.
inputs
[
0
]
shape_arg
=
specify_shape_node
.
owner
.
inputs
[
1
:]
indices
=
get_idx_list
(
node
.
inputs
,
node
.
op
.
idx_list
)
if
any
(
isinstance
(
index
,
slice
)
or
isinstance
(
getattr
(
index
,
"type"
,
None
),
SliceType
)
for
index
in
indices
):
return
False
new_obj_arg
=
obj_arg
[
indices
]
# No need to specify shape for scalar outputs
if
new_obj_arg
.
ndim
==
0
:
return
[
new_obj_arg
]
return
[
specify_shape
(
new_obj_arg
,
shape_arg
[
len
(
indices
)
:])]
@register_specialize
@register_specialize
@node_rewriter
([
Join
])
@node_rewriter
([
Join
])
def
local_join_subtensors
(
fgraph
,
node
):
def
local_join_subtensors
(
fgraph
,
node
):
...
...
pytensor/tensor/rewriting/subtensor_lift.py
0 → 100644
浏览文件 @
ccbab653
from
collections.abc
import
Iterable
import
numpy
as
np
from
pytensor
import
Variable
from
pytensor.graph
import
Constant
,
node_rewriter
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
from
pytensor.scalar
import
basic
as
ps
from
pytensor.tensor.basic
import
(
Alloc
,
MakeVector
,
alloc
,
as_tensor
,
get_underlying_scalar_constant_value
,
register_infer_shape
,
)
from
pytensor.tensor.elemwise
import
Elemwise
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.math
import
Dot
,
ceil_intdiv
,
dot
from
pytensor.tensor.rewriting.basic
import
(
register_canonicalize
,
register_specialize
,
register_stabilize
,
)
from
pytensor.tensor.rewriting.subtensor
import
register_useless
from
pytensor.tensor.shape
import
(
Shape
,
SpecifyShape
,
specify_shape
,
)
from
pytensor.tensor.subtensor
import
(
AdvancedSubtensor1
,
Subtensor
,
as_index_literal
,
get_canonical_form_slice
,
get_constant_idx
,
get_idx_list
,
)
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.type_other
import
SliceType
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter
([
Subtensor
])
def
local_subtensor_of_dot
(
fgraph
,
node
):
"""Rewrite ``at.dot(A, B)[idxs]`` into ``at.dot(A[idxs_a], B[idxs_b])``.
``idxs_a`` is the first ``A.ndim-1`` entries of ``idxs``, and ``idxs_b`` is
the remaining entries of ``idxs`` (if any), modified to skip the
second-to-last dimension of ``B`` (because dot sums over this dimension).
"""
if
not
isinstance
(
node
.
op
,
Subtensor
):
return
if
not
(
node
.
inputs
[
0
]
.
owner
and
isinstance
(
node
.
inputs
[
0
]
.
owner
.
op
,
Dot
)):
return
# If there is other node that use the outputs of the dot
# We don't want to compute twice the sub part.
if
len
(
fgraph
.
clients
[
node
.
inputs
[
0
]])
>
1
:
return
a
=
node
.
inputs
[
0
]
.
owner
.
inputs
[
0
]
b
=
node
.
inputs
[
0
]
.
owner
.
inputs
[
1
]
idx_list
=
get_idx_list
(
node
.
inputs
,
node
.
op
.
idx_list
)
num_a_indices
=
min
(
a
.
ndim
-
1
,
len
(
idx_list
))
a_indices
=
idx_list
[:
num_a_indices
]
b_indices
=
idx_list
[
num_a_indices
:]
# This is necessary because np.dot sums the last index of a with the second to last of b
# so we want to skip the second-to-last index into b.
# This wasn't necessary for a, because we just omitted the last index.
# We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
# (dot also handles b.ndim < 2 as a special case)
if
b
.
ndim
>
1
and
len
(
b_indices
)
>=
b
.
ndim
-
1
:
b_indices
=
(
b_indices
[:
b
.
ndim
-
2
]
+
(
slice
(
None
,
None
,
None
),)
+
b_indices
[
b
.
ndim
-
2
:]
)
a_sub
=
a
.
__getitem__
(
tuple
(
a_indices
))
b_sub
=
b
.
__getitem__
(
tuple
(
b_indices
))
if
b_indices
else
b
# Copy over previous output stacktrace to a_sub and b_sub,
# because an error in the subtensor operation (e.g. an index error)
# on either a or b must correspond to an error in the
# subtensor operation on their dot product.
copy_stack_trace
(
node
.
outputs
[
0
],
[
a_sub
,
b_sub
])
# Copy over previous output stacktrace and previous dot product stacktrace,
# because an error here may correspond to an either in either the original
# dot product, or in the dot product after the subtensor operation.
r
=
dot
(
a_sub
,
b_sub
)
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
r
)
return
[
r
]
# fast_compile to allow opt subtensor(cast{float32}(make_vector))
@register_canonicalize
(
"fast_compile"
)
@node_rewriter
([
Subtensor
])
def
local_subtensor_lift
(
fgraph
,
node
):
"""
unary(x)[idx] -> unary(x[idx])#any broadcast pattern.
Handles the following unary ops:
elemwise(x,...)[idx] -> elemwise(x[idx],...)
when x,... are broadcasted scalar or not broadcasted at all
"""
if
isinstance
(
node
.
op
,
Subtensor
):
u
=
node
.
inputs
[
0
]
if
u
.
owner
is
None
or
len
(
fgraph
.
clients
[
u
])
>
1
:
return
False
if
isinstance
(
u
.
owner
.
op
,
Elemwise
)
and
len
(
u
.
owner
.
inputs
)
==
1
:
idx
=
node
.
inputs
[
1
:]
x_idx
=
node
.
op
(
u
.
owner
.
inputs
[
0
],
*
idx
)
# Copy over previous output stacktrace
copy_stack_trace
(
node
.
outputs
,
x_idx
)
ret
=
u
.
owner
.
op
(
x_idx
)
# Copy over previous output stacktrace
# and stacktrace from previous unary operation
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
ret
)
return
[
ret
]
if
isinstance
(
u
.
owner
.
op
,
Elemwise
):
new_inputs
=
[]
if
all
(
sum
(
i
.
type
.
broadcastable
)
==
0
for
i
in
u
.
owner
.
inputs
):
# There is no broadcastable in the inputs
idx
=
node
.
inputs
[
1
:]
new_inputs
=
[
node
.
op
(
i
,
*
idx
)
for
i
in
u
.
owner
.
inputs
]
# Copy over previous output stacktrace
copy_stack_trace
(
node
.
outputs
[
0
],
new_inputs
)
ret
=
u
.
owner
.
op
(
*
new_inputs
)
# Copy over previous output stacktrace
# and stacktrace from previous unary operation
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
ret
)
return
[
ret
]
elif
all
(
sum
(
i
.
type
.
broadcastable
)
in
[
i
.
ndim
,
0
]
for
i
in
u
.
owner
.
inputs
):
# There is no broadcastable in the inputs or it is scalar
idx
=
node
.
inputs
[
1
:]
new_inputs
=
[]
for
i
in
u
.
owner
.
inputs
:
if
sum
(
i
.
type
.
broadcastable
)
==
0
:
new_inputs
.
append
(
node
.
op
(
i
,
*
idx
))
else
:
# If the subtensor remove some dims, we must
# lower the number of dimensions of this scalar.
if
node
.
outputs
[
0
]
.
ndim
==
i
.
ndim
:
new_inputs
.
append
(
i
)
else
:
new_inputs
.
append
(
i
.
dimshuffle
([
"x"
]
*
node
.
outputs
[
0
]
.
ndim
)
)
# Copy over previous output stacktrace
copy_stack_trace
(
node
.
outputs
[
0
],
new_inputs
)
ret
=
u
.
owner
.
op
(
*
new_inputs
)
# Copy over previous output stacktrace
# and stacktrace from previous unary operation
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
ret
)
return
[
ret
]
@register_infer_shape
@register_useless
@register_canonicalize
@register_specialize
@node_rewriter
([
Subtensor
])
def
local_subtensor_of_alloc
(
fgraph
,
node
):
"""
alloc(val)[x:y] -> alloc(val[...])
alloc(val)[x:y] -> alloc(val)
This can be seen as a lift, but it also reduce the number of computation/memory.
"""
if
not
isinstance
(
node
.
op
,
Subtensor
):
return
False
u
=
node
.
inputs
[
0
]
if
u
.
owner
is
None
:
return
False
if
not
isinstance
(
u
.
owner
.
op
,
Alloc
):
return
False
slices
=
get_idx_list
(
node
.
inputs
,
node
.
op
.
idx_list
)
val
=
u
.
owner
.
inputs
[
0
]
dims
=
u
.
owner
.
inputs
[
1
:]
assert
len
(
slices
)
<=
len
(
dims
)
# Number of dimensions added to val
n_added_dims
=
u
.
ndim
-
val
.
ndim
# Dimensions of the returned alloc
nw_dims
=
[]
# Slices to take from val
val_slices
=
[]
for
i
,
(
sl
,
dim
)
in
enumerate
(
zip
(
slices
,
dims
,
strict
=
False
)):
# If val was not copied over that dim,
# we need to take the appropriate subtensor on it.
if
i
>=
n_added_dims
:
# We check that the corresponding val dimensions was
# not a broadcasted dimensions.
if
(
val
.
type
.
ndim
>
(
i
-
n_added_dims
)
and
val
.
type
.
broadcastable
[
i
-
n_added_dims
]
):
val_slices
.
append
(
slice
(
None
))
else
:
val_slices
.
append
(
sl
)
csl
,
_
=
get_canonical_form_slice
(
sl
,
dim
)
if
type
(
csl
)
is
not
slice
:
# That dimension is removed.
pass
else
:
nw_dim
=
csl
.
stop
-
csl
.
start
if
csl
.
step
!=
1
:
# Do not add the ceil_intdiv() graphs in the graphs
# when this is not needed as it prevent detecting the
# correct broadcast pattern.
nw_dim
=
ceil_intdiv
(
nw_dim
,
csl
.
step
)
nw_dims
+=
[
nw_dim
]
nw_val
=
val
[
tuple
(
val_slices
)]
nw_dims
+=
dims
[
len
(
slices
)
:]
if
nw_val
.
ndim
>
len
(
nw_dims
):
return
False
rval
=
alloc
(
nw_val
,
*
nw_dims
)
if
not
isinstance
(
rval
,
list
|
tuple
):
rval
=
[
rval
]
return
rval
@register_canonicalize
@node_rewriter
([
Subtensor
])
def
local_subtensor_SpecifyShape_lift
(
fgraph
,
node
):
"""Lift ``specify_shape(x, s)[i_1, ..., i_n]`` to ``specify_shape(x[i1, ... , i_n], s[n:])``."""
if
not
isinstance
(
node
.
op
,
Subtensor
):
return
False
specify_shape_node
=
node
.
inputs
[
0
]
if
not
(
specify_shape_node
.
owner
and
isinstance
(
specify_shape_node
.
owner
.
op
,
SpecifyShape
)
):
return
False
obj_arg
=
specify_shape_node
.
owner
.
inputs
[
0
]
shape_arg
=
specify_shape_node
.
owner
.
inputs
[
1
:]
indices
=
get_idx_list
(
node
.
inputs
,
node
.
op
.
idx_list
)
if
any
(
isinstance
(
index
,
slice
)
or
isinstance
(
getattr
(
index
,
"type"
,
None
),
SliceType
)
for
index
in
indices
):
return
False
new_obj_arg
=
obj_arg
[
indices
]
# No need to specify shape for scalar outputs
if
new_obj_arg
.
ndim
==
0
:
return
[
new_obj_arg
]
return
[
specify_shape
(
new_obj_arg
,
shape_arg
[
len
(
indices
)
:])]
@register_infer_shape
@register_specialize
@register_canonicalize
(
"fast_compile"
)
@register_useless
@node_rewriter
([
Subtensor
,
AdvancedSubtensor1
])
def
local_subtensor_make_vector
(
fgraph
,
node
):
"""Perform ``*Subtensor*`` operations on ``MakeVector`` outputs when the indices are constant.
Replace all ``Subtensor`` and ``MakeVector`` cases like:
[a,b,c][0] -> a
[a,b,c][0:2] -> [a,b]
Replace all ``AdvancedSubtensor1`` and ``MakeVector`` cases like:
[a,b,c][[0,2]] -> [a,c]
We can do this for constant indexes.
.. note:
This optimization implicitly relies on shape optimizations.
TODO: This only applies to a single indexed dimension; we should have
something more general for constant ``*Subtensor*`` graphs (or perhaps
include this kind of work in the constant folding).
"""
if
not
isinstance
(
node
.
op
,
Subtensor
|
AdvancedSubtensor1
):
return
False
x
=
node
.
inputs
[
0
]
if
not
(
x
.
owner
and
isinstance
(
x
.
owner
.
op
,
MakeVector
)):
return
False
make_vector_op
=
x
.
owner
.
op
if
isinstance
(
node
.
op
,
Subtensor
):
idxs
=
node
.
op
.
idx_list
# Subtensor has no indexes, return make_vector
if
not
idxs
:
return
[
x
]
(
idx
,)
=
idxs
if
isinstance
(
idx
,
ps
.
ScalarType
|
TensorType
):
old_idx
,
idx
=
idx
,
node
.
inputs
[
1
]
assert
idx
.
type
.
is_super
(
old_idx
)
elif
isinstance
(
node
.
op
,
AdvancedSubtensor1
):
idx
=
node
.
inputs
[
1
]
if
isinstance
(
idx
,
int
|
np
.
integer
):
return
[
x
.
owner
.
inputs
[
idx
]]
elif
isinstance
(
idx
,
Variable
):
if
idx
.
ndim
==
0
:
try
:
v
=
get_underlying_scalar_constant_value
(
idx
,
only_process_constants
=
True
)
try
:
ret
=
[
x
.
owner
.
inputs
[
v
]]
except
IndexError
:
raise
NotScalarConstantError
(
"Bad user graph!"
)
return
ret
except
NotScalarConstantError
:
pass
elif
idx
.
ndim
==
1
and
isinstance
(
idx
,
Constant
):
values
=
list
(
map
(
int
,
list
(
idx
.
value
)))
ret
=
make_vector_op
(
*
[
x
.
owner
.
inputs
[
v
]
for
v
in
values
])
copy_stack_trace
(
node
.
outputs
[
0
],
ret
)
return
[
ret
]
elif
isinstance
(
idx
,
slice
):
# The index is a slice. If it's a constant slice, we can perform the
# index operation here.
try
:
const_slice
=
get_constant_idx
(
node
.
op
.
idx_list
,
node
.
inputs
,
allow_partial
=
False
)[
0
]
ret
=
make_vector_op
(
*
x
.
owner
.
inputs
[
const_slice
])
copy_stack_trace
(
node
.
outputs
,
ret
)
return
[
ret
]
except
NotScalarConstantError
:
pass
@register_specialize
@register_canonicalize
@node_rewriter
([
Subtensor
])
def
local_subtensor_shape_constant
(
fgraph
,
node
):
r"""Simplify constant `Subtensor`\s on `Shape`\s dimensions that are known.
We want to convert graphs like
Subtensor{int64} [id A] ''
|Shape [id B] ''
| |<TensorType(float64, row)> [id C]
|ScalarConstant{0} [id D]
into
TensorConstant{1}
TODO: Something like `local_shape_to_shape_i` should be a general
canonicalization, and not a `ShapeFeature`-dependent rewrite. If that were
the case, we could change this to only operate on `Shape_i`\s.
Currently, we're not handling them because they should only appear when
`ShapeFeature` is present, and it will also simplify/remove them.
"""
if
not
isinstance
(
node
.
op
,
Subtensor
):
return
False
shape
=
node
.
inputs
[
0
]
if
not
(
shape
.
owner
and
isinstance
(
shape
.
owner
.
op
,
Shape
)):
return
False
shape_arg
=
shape
.
owner
.
inputs
[
0
]
(
idx
,)
=
get_idx_list
(
node
.
inputs
,
node
.
op
.
idx_list
)
try
:
idx_val
=
as_index_literal
(
idx
)
except
NotScalarConstantError
:
return
False
assert
idx_val
!=
np
.
newaxis
if
not
isinstance
(
shape_arg
.
type
,
TensorType
):
return
False
shape_parts
=
shape_arg
.
type
.
broadcastable
[
idx_val
]
if
isinstance
(
shape_parts
,
Iterable
):
if
all
(
shape_parts
):
return
[
as_tensor
([
1
]
*
len
(
shape_parts
),
dtype
=
np
.
int64
,
ndim
=
1
)]
elif
shape_parts
:
return
[
as_tensor
(
1
,
dtype
=
np
.
int64
)]
tests/tensor/rewriting/test_subtensor.py
浏览文件 @
ccbab653
...
@@ -9,27 +9,19 @@ from pytensor.compile.function import function
...
@@ -9,27 +9,19 @@ from pytensor.compile.function import function
from
pytensor.compile.mode
import
Mode
,
get_default_mode
,
get_mode
from
pytensor.compile.mode
import
Mode
,
get_default_mode
,
get_mode
from
pytensor.compile.ops
import
DeepCopyOp
from
pytensor.compile.ops
import
DeepCopyOp
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph
import
FunctionGraph
,
vectorize_graph
from
pytensor.graph
import
vectorize_graph
from
pytensor.graph.basic
import
Constant
,
Variable
,
ancestors
,
equal_computations
from
pytensor.graph.basic
import
Constant
,
Variable
,
ancestors
,
equal_computations
from
pytensor.graph.rewriting.basic
import
check_stack_trace
from
pytensor.graph.rewriting.basic
import
check_stack_trace
from
pytensor.graph.rewriting.db
import
RewriteDatabaseQuery
from
pytensor.graph.rewriting.utils
import
rewrite_graph
from
pytensor.graph.type
import
Type
from
pytensor.raise_op
import
Assert
from
pytensor.raise_op
import
Assert
from
pytensor.tensor
import
inplace
from
pytensor.tensor.basic
import
Alloc
,
_convert_to_int8
from
pytensor.tensor.basic
import
Alloc
,
MakeVector
,
_convert_to_int8
,
make_vector
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.elemwise
import
DimShuffle
,
Elemwise
from
pytensor.tensor.elemwise
import
Elemwise
from
pytensor.tensor.math
import
Dot
,
add
,
dot
,
exp
,
sqr
from
pytensor.tensor.math
import
Dot
,
dot
,
exp
,
sqr
from
pytensor.tensor.rewriting.subtensor
import
(
from
pytensor.tensor.rewriting.subtensor
import
(
local_replace_AdvancedSubtensor
,
local_replace_AdvancedSubtensor
,
local_subtensor_make_vector
,
local_subtensor_shape_constant
,
)
)
from
pytensor.tensor.shape
import
(
from
pytensor.tensor.shape
import
(
SpecifyShape
,
SpecifyShape
,
_shape
,
shape
,
specify_shape
,
specify_shape
,
)
)
from
pytensor.tensor.subtensor
import
(
from
pytensor.tensor.subtensor
import
(
...
@@ -49,10 +41,7 @@ from pytensor.tensor.type import (
...
@@ -49,10 +41,7 @@ from pytensor.tensor.type import (
dmatrix
,
dmatrix
,
fmatrix
,
fmatrix
,
iscalar
,
iscalar
,
iscalars
,
ivector
,
ivector
,
lscalar
,
lscalars
,
matrix
,
matrix
,
scalar
,
scalar
,
tensor
,
tensor
,
...
@@ -60,7 +49,7 @@ from pytensor.tensor.type import (
...
@@ -60,7 +49,7 @@ from pytensor.tensor.type import (
tensor4
,
tensor4
,
vector
,
vector
,
)
)
from
pytensor.tensor.type_other
import
make_slice
,
slicetype
from
pytensor.tensor.type_other
import
make_slice
from
tests
import
unittest_tools
as
utt
from
tests
import
unittest_tools
as
utt
from
tests.unittest_tools
import
create_pytensor_param
from
tests.unittest_tools
import
create_pytensor_param
...
@@ -664,262 +653,6 @@ class TestSubtensorIncSubtensor:
...
@@ -664,262 +653,6 @@ class TestSubtensorIncSubtensor:
assert
np
.
array_equal
(
f
(
x_
,
i_
,
v_
),
v_
.
astype
(
"int8"
))
assert
np
.
array_equal
(
f
(
x_
,
i_
,
v_
),
v_
.
astype
(
"int8"
))
class
TestLocalSubtensorMakeVector
:
mode
=
get_mode
(
"FAST_RUN"
)
.
including
(
"local_subtensor_make_vector"
)
def
test_scalar_idx
(
self
):
x
,
y
,
z
=
lscalars
(
"xyz"
)
v
=
make_vector
(
x
,
y
,
z
)
f
=
function
([
x
,
y
,
z
],
v
[
0
],
mode
=
self
.
mode
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
(
prog
)
==
1
assert
isinstance
(
prog
[
0
]
.
op
,
DeepCopyOp
)
assert
f
(
0
,
1
,
2
)
==
0
def
test_idx_symbolic
(
self
):
x
,
y
,
z
=
iscalars
(
"xyz"
)
v
=
MakeVector
(
"int32"
)(
x
,
y
,
z
)
idx
=
pt
.
as_tensor
([
0
],
dtype
=
np
.
int64
)
f
=
function
([
x
,
y
,
z
],
v
[
idx
],
mode
=
self
.
mode
)
opt_fgraph
=
f
.
maker
.
fgraph
assert
opt_fgraph
.
outputs
[
0
]
.
dtype
==
"int32"
assert
isinstance
(
opt_fgraph
.
outputs
[
0
]
.
owner
.
op
,
MakeVector
)
assert
f
(
0
,
1
,
2
)
==
np
.
array
([
0
],
dtype
=
np
.
int32
)
def
test_slice_idx_start
(
self
):
x
,
y
,
z
=
iscalars
(
"xyz"
)
v
=
MakeVector
(
"int32"
)(
x
,
y
,
z
)
f
=
function
([
x
,
y
,
z
],
v
[
1
:],
mode
=
self
.
mode
,
on_unused_input
=
"ignore"
)
opt_fgraph
=
f
.
maker
.
fgraph
assert
opt_fgraph
.
outputs
[
0
]
.
dtype
==
"int32"
assert
isinstance
(
opt_fgraph
.
outputs
[
0
]
.
owner
.
op
,
MakeVector
)
assert
len
(
opt_fgraph
.
outputs
[
0
]
.
owner
.
inputs
)
==
2
r
=
f
(
0
,
1
,
2
)
assert
r
[
0
]
==
1
and
r
[
1
]
==
2
def
test_slice_idx_stop
(
self
):
x
,
y
,
z
=
lscalars
(
"xyz"
)
v
=
make_vector
(
x
,
y
,
z
)
f
=
function
([
x
,
y
,
z
],
v
[:
2
],
mode
=
self
.
mode
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
(
prog
)
==
1
assert
isinstance
(
prog
[
0
]
.
op
,
MakeVector
)
assert
len
(
prog
[
0
]
.
inputs
)
==
2
r
=
f
(
0
,
1
,
2
)
assert
r
[
0
]
==
0
and
r
[
1
]
==
1
def
test_slice_idx_step
(
self
):
x
,
y
,
z
=
lscalars
(
"xyz"
)
v
=
make_vector
(
x
,
y
,
z
)
f
=
function
([
x
,
y
,
z
],
v
[::
2
],
mode
=
self
.
mode
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
(
prog
)
==
1
assert
isinstance
(
prog
[
0
]
.
op
,
MakeVector
)
assert
len
(
prog
[
0
]
.
inputs
)
==
2
r
=
f
(
0
,
1
,
2
)
assert
r
[
0
]
==
0
and
r
[
1
]
==
2
def
test_AdvancedSubtensor1_idx
(
self
):
x
,
y
,
z
=
lscalars
(
"xyz"
)
v
=
make_vector
(
x
,
y
,
z
)
f
=
function
([
x
,
y
,
z
],
v
[[
0
,
2
]],
mode
=
self
.
mode
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
(
prog
)
==
1
assert
isinstance
(
prog
[
0
]
.
op
,
MakeVector
)
assert
len
(
prog
[
0
]
.
inputs
)
==
2
r
=
f
(
0
,
1
,
2
)
assert
r
[
0
]
==
0
and
r
[
1
]
==
2
def
test_MakeVector_idx
(
self
):
x
,
y
,
z
,
q
=
lscalars
(
"xyzq"
)
v
=
make_vector
(
x
,
y
,
z
)
q
=
make_vector
(
0
,
2
)
f
=
function
([
x
,
y
,
z
],
v
[
q
],
mode
=
self
.
mode
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
(
prog
)
==
1
assert
isinstance
(
prog
[
0
]
.
op
,
MakeVector
)
assert
len
(
prog
[
0
]
.
inputs
)
==
2
r
=
f
(
0
,
1
,
2
)
assert
r
[
0
]
==
0
and
r
[
1
]
==
2
def
test_stack_trace
(
self
):
x
,
y
,
z
=
lscalars
(
"xyz"
)
v
=
make_vector
(
x
,
y
,
z
)
mode
=
get_default_mode
()
.
including
(
"local_subtensor_make_vector"
)
# list of subtensor cases, where local_subtensor_make_vector
# inserts a new MakeVector node
v_subtensors
=
[
v
[:
2
],
v
[::
2
],
v
[[
0
,
2
]]]
for
v_subtensor
in
v_subtensors
:
f
=
function
([
x
,
y
,
z
],
v_subtensor
,
mode
=
mode
)
assert
check_stack_trace
(
f
,
ops_to_check
=
"all"
)
def
test_empty_subtensor
(
self
):
x
,
y
=
lscalars
(
"xy"
)
v
=
make_vector
(
x
,
y
)
out
=
v
[()]
fgraph
=
FunctionGraph
(
outputs
=
[
out
],
clone
=
False
)
node
=
fgraph
.
outputs
[
0
]
.
owner
assert
isinstance
(
node
.
op
,
Subtensor
)
assert
local_subtensor_make_vector
.
transform
(
fgraph
,
node
)
==
[
v
]
class
TestLocalSubtensorLift
:
def
test_basic
(
self
):
# basic test that the Op works
x
=
matrix
(
"x"
)
f
=
function
([
x
],
exp
(
x
)[
0
],
mode
=
mode_opt
)
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f
,
ops_to_check
=
"all"
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
# first subtensor
assert
prog
[
1
]
.
op
==
exp
assert
len
(
prog
)
==
2
f
([[
0
,
1
],
[
2
,
3
]])
# let debugmode test something
def
test_basic_1
(
self
):
# as test0, but we reuse the output of the elemwise
# So we should not lift the subtensor
x
=
matrix
(
"x"
)
f
=
function
([
x
],
[
exp
(
x
)[
0
],
exp
(
x
)],
mode
=
mode_opt
)
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f
,
ops_to_check
=
[
Subtensor
,
Elemwise
])
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
prog
[
0
]
.
op
==
exp
assert
isinstance
(
prog
[
1
]
.
op
,
Subtensor
)
# first subtensor
assert
isinstance
(
prog
[
2
]
.
op
,
DeepCopyOp
)
assert
len
(
prog
)
==
3
f
([[
0
,
1
],
[
2
,
3
]])
# let debugmode test something
def
test_basic_2
(
self
):
# basic test that the optimization work with scalar broadcasted
x
=
matrix
(
"x"
)
y
=
scalar
(
"y"
)
z
=
matrix
(
"z"
)
f
=
function
([
x
,
y
,
z
],
exp
(
x
+
y
+
z
)[
0
],
mode
=
mode_opt
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
1
]
.
op
,
DimShuffle
)
assert
isinstance
(
prog
[
2
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
3
]
.
op
.
scalar_op
,
ps
.
Composite
)
# Composite{add,add}
assert
len
(
prog
)
==
4
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f
,
ops_to_check
=
[
Subtensor
])
# let debugmode test something
f
([[
0
,
1
],
[
2
,
3
]],
4
,
[[
4
,
5
],
[
6
,
7
]])
def
test_basic_3
(
self
):
# as 1, but take a slice
x
=
matrix
(
"x"
)
y
=
scalar
(
"y"
)
z
=
matrix
(
"z"
)
f
=
function
([
x
,
y
,
z
],
exp
(
x
+
y
+
z
)[
0
:
2
],
mode
=
mode_opt
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
1
]
.
op
,
DimShuffle
)
assert
isinstance
(
prog
[
2
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
3
]
.
op
.
scalar_op
,
ps
.
Composite
)
# Composite{add,add}
assert
len
(
prog
)
==
4
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f
,
ops_to_check
=
[
Subtensor
])
# let debugmode test something
f
([[
0
,
1
],
[
2
,
3
]],
4
,
[[
4
,
5
],
[
6
,
7
]])
def
test_basic_4
(
self
):
# basic test that the optimization does work with broadcasting
# for unary elemwise.
y
=
vector
(
"y"
)
f
=
function
([
y
],
exp
(
y
.
dimshuffle
(
0
,
"x"
))[
0
],
mode
=
mode_opt
)
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f
,
ops_to_check
=
"all"
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
DimShuffle
)
assert
isinstance
(
prog
[
1
]
.
op
,
Subtensor
)
assert
prog
[
2
]
.
op
==
exp
assert
len
(
prog
)
==
3
f
([
4
,
5
])
# let debugmode test something
@utt.assertFailure_fast
def
test_basic_5
(
self
):
# basic test that the optimization doesn't work with broadcasting
# ... It *could* be extended to,
# ... but right now it doesn't, so it shouldn't try.
x
=
matrix
(
"x"
)
y
=
vector
(
"y"
)
f
=
function
([
x
,
y
],
exp
(
x
+
y
)[
0
],
mode
=
mode_opt
)
# Opt doesn't apply, so no need for check_stack_trace
# assert check_stack_trace(f, ops_to_check='all')
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
DimShuffle
)
assert
prog
[
1
]
.
op
==
add
assert
isinstance
(
prog
[
2
]
.
op
,
Subtensor
)
# first subtensor
assert
prog
[
3
]
.
op
==
inplace
.
exp_inplace
assert
len
(
prog
)
==
4
f
([[
0
,
1
],
[
2
,
3
]],
[
4
,
5
])
# let debugmode test something
def
test_basic_6
(
self
):
# test that we don't lift when we reuse the output of the
# elemwise for other computation.
x
=
matrix
(
"x"
)
y
=
vector
(
"y"
)
f
=
function
([
x
,
y
],
[
exp
(
x
+
y
)[
0
],
exp
(
x
+
y
)
+
x
],
mode
=
mode_opt
)
# Opt doesn't apply, so no need for check_stack_trace
# assert check_stack_trace(f, ops_to_check=Subtensor)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
DimShuffle
)
assert
isinstance
(
prog
[
1
]
.
op
.
scalar_op
,
ps
.
Composite
)
# Composite{add,exp}
# first subtensor
assert
isinstance
(
prog
[
2
]
.
op
,
Subtensor
)
assert
len
(
prog
)
==
3
f
([[
0
,
1
],
[
2
,
3
]],
[
4
,
5
])
# let debugmode test something
def
test_basic_7
(
self
):
# basic test that the optimization works with a scalar as input,
# and a scalar as output (no broadcasting of the scalar needed).
# The optimization used to fail and display an ERROR message.
x
=
vector
(
"x"
)
y
=
scalar
(
"y"
)
f
=
function
([
x
,
y
],
exp
(
x
+
y
)[
0
],
mode
=
mode_opt
)
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f
,
ops_to_check
=
Subtensor
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
# Composite{add,exp}
assert
isinstance
(
prog
[
1
]
.
op
.
scalar_op
,
ps
.
Composite
)
assert
len
(
prog
)
==
2
f
([
1
,
2
,
3
],
4
)
# let debugmode test something
class
TestLocalSubtensorMerge
:
class
TestLocalSubtensorMerge
:
def
setup_method
(
self
):
def
setup_method
(
self
):
self
.
x_shapes
=
[(
2
,
2
),
(
5
,
3
),
(
4
,
1
),
(
1
,
2
),
(
0
,
2
),
(
2
,
0
),
(
1
,
0
),
(
0
,
0
)]
self
.
x_shapes
=
[(
2
,
2
),
(
5
,
3
),
(
4
,
1
),
(
1
,
2
),
(
0
,
2
),
(
2
,
0
),
(
1
,
0
),
(
0
,
0
)]
...
@@ -1803,200 +1536,6 @@ def test_local_set_to_inc_subtensor():
...
@@ -1803,200 +1536,6 @@ def test_local_set_to_inc_subtensor():
assert
check_stack_trace
(
f2
,
ops_to_check
=
"all"
)
assert
check_stack_trace
(
f2
,
ops_to_check
=
"all"
)
def
test_local_subtensor_of_alloc
():
# DebugMode should detect if something goes wrong.
# test shape combination of odd and event shape.
for
s
in
[(
3
,
5
),
(
4
,
6
),
(
3
,
8
),
(
4
,
7
),
(
1
,
5
),
(
5
,
1
)]:
x
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
1
if
s
[
0
]
==
1
else
None
,
1
if
s
[
1
]
==
1
else
None
),
)
xval
=
np
.
zeros
(
s
,
dtype
=
config
.
floatX
)
yval
=
np
.
arange
(
s
[
1
],
dtype
=
config
.
floatX
)
for
y
in
[
shared
(
yval
),
pt
.
constant
([
1.0
])]:
# The rows of yx are copies of y
yx
=
pt
.
alloc
(
y
,
x
.
shape
[
0
],
x
.
shape
[
1
])
# Slice of each row
z_mat
=
yx
[:,
3
:]
assert
z_mat
.
ndim
==
2
# Only one column
z_vec
=
yx
[:,
3
]
assert
z_vec
.
ndim
==
1
# results are vector
slicess
=
[]
if
s
[
0
]
!=
1
:
slicess
.
append
((
2
,
slice
(
None
)))
if
s
[
1
]
!=
1
:
slicess
.
append
((
slice
(
None
),
3
))
# results are matrix
slicess
+=
[
(
slice
(
None
),
slice
(
3
,
None
)),
(
slice
(
3
,
None
),),
(
slice
(
3
,
None
),
slice
(
3
,
None
)),
(
slice
(
1
,
3
),
slice
(
None
,
-
1
)),
(
slice
(
None
,
None
,
2
)),
(
slice
(
1
,
None
,
2
)),
]
for
slices
in
slicess
:
z
=
yx
.
__getitem__
(
slices
)
f
=
function
([
x
],
z
)
if
config
.
mode
!=
"FAST_COMPILE"
:
# Subtensor can be in the input of Alloc
assert
not
isinstance
(
f
.
maker
.
fgraph
.
toposort
()[
-
1
]
.
op
,
Subtensor
)
val
=
f
(
xval
)
assert
xval
.
__getitem__
(
slices
)
.
shape
==
val
.
shape
def
test_local_subtensor_shape_constant
():
x
=
tensor
(
dtype
=
np
.
float64
,
shape
=
(
1
,
None
))
.
shape
[
0
]
(
res
,)
=
local_subtensor_shape_constant
.
transform
(
None
,
x
.
owner
)
assert
isinstance
(
res
,
Constant
)
assert
res
.
data
==
1
# Make sure it's part of the canonicalizations
res
=
rewrite_graph
(
x
)
assert
isinstance
(
res
,
Constant
)
assert
res
.
data
==
1
x
=
_shape
(
tensor
(
dtype
=
np
.
float64
,
shape
=
(
1
,
None
)))[
lscalar
()]
assert
not
local_subtensor_shape_constant
.
transform
(
None
,
x
.
owner
)
x
=
_shape
(
tensor
(
dtype
=
np
.
float64
,
shape
=
(
1
,
None
)))[
0
:]
assert
not
local_subtensor_shape_constant
.
transform
(
None
,
x
.
owner
)
x
=
_shape
(
tensor
(
dtype
=
np
.
float64
,
shape
=
(
1
,
None
)))[
lscalar
()
:]
assert
not
local_subtensor_shape_constant
.
transform
(
None
,
x
.
owner
)
x
=
_shape
(
tensor
(
dtype
=
np
.
float64
,
shape
=
(
1
,
1
)))[
1
:]
(
res
,)
=
local_subtensor_shape_constant
.
transform
(
None
,
x
.
owner
)
assert
isinstance
(
res
,
Constant
)
assert
np
.
array_equal
(
res
.
data
,
[
1
])
x
=
_shape
(
tensor
(
dtype
=
np
.
float64
,
shape
=
(
None
,
1
,
1
)))[
1
:]
(
res
,)
=
local_subtensor_shape_constant
.
transform
(
None
,
x
.
owner
)
assert
isinstance
(
res
,
Constant
)
assert
np
.
array_equal
(
res
.
data
,
[
1
,
1
])
# A test for a non-`TensorType`
class
MyType
(
Type
):
def
filter
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
()
def
__eq__
(
self
,
other
):
return
isinstance
(
other
,
MyType
)
and
other
.
thingy
==
self
.
thingy
x
=
shape
(
Variable
(
MyType
(),
None
,
None
))[
0
]
assert
not
local_subtensor_shape_constant
.
transform
(
None
,
x
.
owner
)
@pytest.mark.parametrize
(
"x, s, idx, x_val, s_val"
,
[
(
vector
(),
(
iscalar
(),),
(
1
,),
np
.
array
([
1
,
2
],
dtype
=
config
.
floatX
),
np
.
array
([
2
],
dtype
=
np
.
int64
),
),
(
matrix
(),
(
iscalar
(),
iscalar
()),
(
1
,),
np
.
array
([[
1
,
2
],
[
3
,
4
]],
dtype
=
config
.
floatX
),
np
.
array
([
2
,
2
],
dtype
=
np
.
int64
),
),
(
matrix
(),
(
iscalar
(),
iscalar
()),
(
0
,),
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
config
.
floatX
),
np
.
array
([
2
,
3
],
dtype
=
np
.
int64
),
),
(
matrix
(),
(
iscalar
(),
iscalar
()),
(
1
,
1
),
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
config
.
floatX
),
np
.
array
([
2
,
3
],
dtype
=
np
.
int64
),
),
(
tensor3
(),
(
iscalar
(),
iscalar
(),
iscalar
()),
(
-
1
,),
np
.
arange
(
2
*
3
*
5
,
dtype
=
config
.
floatX
)
.
reshape
((
2
,
3
,
5
)),
np
.
array
([
2
,
3
,
5
],
dtype
=
np
.
int64
),
),
(
tensor3
(),
(
iscalar
(),
iscalar
(),
iscalar
()),
(
-
1
,
0
),
np
.
arange
(
2
*
3
*
5
,
dtype
=
config
.
floatX
)
.
reshape
((
2
,
3
,
5
)),
np
.
array
([
2
,
3
,
5
],
dtype
=
np
.
int64
),
),
],
)
def
test_local_subtensor_SpecifyShape_lift
(
x
,
s
,
idx
,
x_val
,
s_val
):
y
=
specify_shape
(
x
,
s
)[
idx
]
assert
isinstance
(
y
.
owner
.
inputs
[
0
]
.
owner
.
op
,
SpecifyShape
)
rewrites
=
RewriteDatabaseQuery
(
include
=
[
None
])
no_rewrites_mode
=
Mode
(
optimizer
=
rewrites
)
y_val_fn
=
function
([
x
,
*
s
],
y
,
on_unused_input
=
"ignore"
,
mode
=
no_rewrites_mode
)
y_val
=
y_val_fn
(
*
([
x_val
,
*
s_val
]))
# This optimization should appear in the canonicalizations
y_opt
=
rewrite_graph
(
y
,
clone
=
False
)
if
y
.
ndim
==
0
:
# SpecifyShape should be removed altogether
assert
isinstance
(
y_opt
.
owner
.
op
,
Subtensor
)
assert
y_opt
.
owner
.
inputs
[
0
]
is
x
else
:
assert
isinstance
(
y_opt
.
owner
.
op
,
SpecifyShape
)
y_opt_fn
=
function
([
x
,
*
s
],
y_opt
,
on_unused_input
=
"ignore"
)
y_opt_val
=
y_opt_fn
(
*
([
x_val
,
*
s_val
]))
assert
np
.
allclose
(
y_val
,
y_opt_val
)
@pytest.mark.parametrize
(
"x, s, idx"
,
[
(
matrix
(),
(
iscalar
(),
iscalar
()),
(
slice
(
1
,
None
),),
),
(
matrix
(),
(
iscalar
(),
iscalar
()),
(
slicetype
(),),
),
(
matrix
(),
(
iscalar
(),
iscalar
()),
(
1
,
0
),
),
],
)
def
test_local_subtensor_SpecifyShape_lift_fail
(
x
,
s
,
idx
):
y
=
specify_shape
(
x
,
s
)[
idx
]
# This optimization should appear in the canonicalizations
y_opt
=
rewrite_graph
(
y
,
clone
=
False
)
assert
not
isinstance
(
y_opt
.
owner
.
op
,
SpecifyShape
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"axis, slices_fn, expected_nodes"
,
"axis, slices_fn, expected_nodes"
,
[
[
...
...
tests/tensor/rewriting/test_subtensor_lift.py
0 → 100644
浏览文件 @
ccbab653
import
numpy
as
np
import
pytest
import
unittest_tools
as
utt
from
pytensor
import
(
Mode
,
Variable
,
config
,
function
,
shared
,
)
from
pytensor
import
scalar
as
ps
from
pytensor
import
tensor
as
pt
from
pytensor.compile
import
DeepCopyOp
,
get_default_mode
,
get_mode
from
pytensor.graph
import
(
Constant
,
FunctionGraph
,
RewriteDatabaseQuery
,
Type
,
rewrite_graph
,
)
from
pytensor.graph.rewriting.basic
import
check_stack_trace
from
pytensor.tensor
import
(
add
,
exp
,
inplace
,
iscalar
,
iscalars
,
lscalar
,
lscalars
,
matrix
,
scalar
,
shape
,
slicetype
,
specify_shape
,
tensor
,
tensor3
,
vector
,
)
from
pytensor.tensor.basic
import
MakeVector
,
make_vector
from
pytensor.tensor.elemwise
import
DimShuffle
,
Elemwise
from
pytensor.tensor.rewriting.subtensor_lift
import
(
local_subtensor_make_vector
,
local_subtensor_shape_constant
,
)
from
pytensor.tensor.shape
import
SpecifyShape
,
_shape
from
pytensor.tensor.subtensor
import
Subtensor
mode_opt
=
config
.
mode
if
mode_opt
==
"FAST_COMPILE"
:
mode_opt
=
"FAST_RUN"
mode_opt
=
get_mode
(
mode_opt
)
class
TestLocalSubtensorLift
:
def
test_basic
(
self
):
# basic test that the Op works
x
=
matrix
(
"x"
)
f
=
function
([
x
],
exp
(
x
)[
0
],
mode
=
mode_opt
)
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f
,
ops_to_check
=
"all"
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
# first subtensor
assert
prog
[
1
]
.
op
==
exp
assert
len
(
prog
)
==
2
f
([[
0
,
1
],
[
2
,
3
]])
# let debugmode test something
def
test_basic_1
(
self
):
# as test0, but we reuse the output of the elemwise
# So we should not lift the subtensor
x
=
matrix
(
"x"
)
f
=
function
([
x
],
[
exp
(
x
)[
0
],
exp
(
x
)],
mode
=
mode_opt
)
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f
,
ops_to_check
=
[
Subtensor
,
Elemwise
])
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
prog
[
0
]
.
op
==
exp
assert
isinstance
(
prog
[
1
]
.
op
,
Subtensor
)
# first subtensor
assert
isinstance
(
prog
[
2
]
.
op
,
DeepCopyOp
)
assert
len
(
prog
)
==
3
f
([[
0
,
1
],
[
2
,
3
]])
# let debugmode test something
def
test_basic_2
(
self
):
# basic test that the optimization work with scalar broadcasted
x
=
matrix
(
"x"
)
y
=
scalar
(
"y"
)
z
=
matrix
(
"z"
)
f
=
function
([
x
,
y
,
z
],
exp
(
x
+
y
+
z
)[
0
],
mode
=
mode_opt
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
1
]
.
op
,
DimShuffle
)
assert
isinstance
(
prog
[
2
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
3
]
.
op
.
scalar_op
,
ps
.
Composite
)
# Composite{add,add}
assert
len
(
prog
)
==
4
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f
,
ops_to_check
=
[
Subtensor
])
# let debugmode test something
f
([[
0
,
1
],
[
2
,
3
]],
4
,
[[
4
,
5
],
[
6
,
7
]])
def
test_basic_3
(
self
):
# as 1, but take a slice
x
=
matrix
(
"x"
)
y
=
scalar
(
"y"
)
z
=
matrix
(
"z"
)
f
=
function
([
x
,
y
,
z
],
exp
(
x
+
y
+
z
)[
0
:
2
],
mode
=
mode_opt
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
1
]
.
op
,
DimShuffle
)
assert
isinstance
(
prog
[
2
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
3
]
.
op
.
scalar_op
,
ps
.
Composite
)
# Composite{add,add}
assert
len
(
prog
)
==
4
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f
,
ops_to_check
=
[
Subtensor
])
# let debugmode test something
f
([[
0
,
1
],
[
2
,
3
]],
4
,
[[
4
,
5
],
[
6
,
7
]])
def
test_basic_4
(
self
):
# basic test that the optimization does work with broadcasting
# for unary elemwise.
y
=
vector
(
"y"
)
f
=
function
([
y
],
exp
(
y
.
dimshuffle
(
0
,
"x"
))[
0
],
mode
=
mode_opt
)
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f
,
ops_to_check
=
"all"
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
DimShuffle
)
assert
isinstance
(
prog
[
1
]
.
op
,
Subtensor
)
assert
prog
[
2
]
.
op
==
exp
assert
len
(
prog
)
==
3
f
([
4
,
5
])
# let debugmode test something
@utt.assertFailure_fast
def
test_basic_5
(
self
):
# basic test that the optimization doesn't work with broadcasting
# ... It *could* be extended to,
# ... but right now it doesn't, so it shouldn't try.
x
=
matrix
(
"x"
)
y
=
vector
(
"y"
)
f
=
function
([
x
,
y
],
exp
(
x
+
y
)[
0
],
mode
=
mode_opt
)
# Opt doesn't apply, so no need for check_stack_trace
# assert check_stack_trace(f, ops_to_check='all')
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
DimShuffle
)
assert
prog
[
1
]
.
op
==
add
assert
isinstance
(
prog
[
2
]
.
op
,
Subtensor
)
# first subtensor
assert
prog
[
3
]
.
op
==
inplace
.
exp_inplace
assert
len
(
prog
)
==
4
f
([[
0
,
1
],
[
2
,
3
]],
[
4
,
5
])
# let debugmode test something
def
test_basic_6
(
self
):
# test that we don't lift when we reuse the output of the
# elemwise for other computation.
x
=
matrix
(
"x"
)
y
=
vector
(
"y"
)
f
=
function
([
x
,
y
],
[
exp
(
x
+
y
)[
0
],
exp
(
x
+
y
)
+
x
],
mode
=
mode_opt
)
# Opt doesn't apply, so no need for check_stack_trace
# assert check_stack_trace(f, ops_to_check=Subtensor)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
DimShuffle
)
assert
isinstance
(
prog
[
1
]
.
op
.
scalar_op
,
ps
.
Composite
)
# Composite{add,exp}
# first subtensor
assert
isinstance
(
prog
[
2
]
.
op
,
Subtensor
)
assert
len
(
prog
)
==
3
f
([[
0
,
1
],
[
2
,
3
]],
[
4
,
5
])
# let debugmode test something
def
test_basic_7
(
self
):
# basic test that the optimization works with a scalar as input,
# and a scalar as output (no broadcasting of the scalar needed).
# The optimization used to fail and display an ERROR message.
x
=
vector
(
"x"
)
y
=
scalar
(
"y"
)
f
=
function
([
x
,
y
],
exp
(
x
+
y
)[
0
],
mode
=
mode_opt
)
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f
,
ops_to_check
=
Subtensor
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
# Composite{add,exp}
assert
isinstance
(
prog
[
1
]
.
op
.
scalar_op
,
ps
.
Composite
)
assert
len
(
prog
)
==
2
f
([
1
,
2
,
3
],
4
)
# let debugmode test something
def
test_local_subtensor_of_alloc
():
# DebugMode should detect if something goes wrong.
# test shape combination of odd and event shape.
for
s
in
[(
3
,
5
),
(
4
,
6
),
(
3
,
8
),
(
4
,
7
),
(
1
,
5
),
(
5
,
1
)]:
x
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
1
if
s
[
0
]
==
1
else
None
,
1
if
s
[
1
]
==
1
else
None
),
)
xval
=
np
.
zeros
(
s
,
dtype
=
config
.
floatX
)
yval
=
np
.
arange
(
s
[
1
],
dtype
=
config
.
floatX
)
for
y
in
[
shared
(
yval
),
pt
.
constant
([
1.0
])]:
# The rows of yx are copies of y
yx
=
pt
.
alloc
(
y
,
x
.
shape
[
0
],
x
.
shape
[
1
])
# Slice of each row
z_mat
=
yx
[:,
3
:]
assert
z_mat
.
ndim
==
2
# Only one column
z_vec
=
yx
[:,
3
]
assert
z_vec
.
ndim
==
1
# results are vector
slicess
=
[]
if
s
[
0
]
!=
1
:
slicess
.
append
((
2
,
slice
(
None
)))
if
s
[
1
]
!=
1
:
slicess
.
append
((
slice
(
None
),
3
))
# results are matrix
slicess
+=
[
(
slice
(
None
),
slice
(
3
,
None
)),
(
slice
(
3
,
None
),),
(
slice
(
3
,
None
),
slice
(
3
,
None
)),
(
slice
(
1
,
3
),
slice
(
None
,
-
1
)),
(
slice
(
None
,
None
,
2
)),
(
slice
(
1
,
None
,
2
)),
]
for
slices
in
slicess
:
z
=
yx
.
__getitem__
(
slices
)
f
=
function
([
x
],
z
)
if
config
.
mode
!=
"FAST_COMPILE"
:
# Subtensor can be in the input of Alloc
assert
not
isinstance
(
f
.
maker
.
fgraph
.
toposort
()[
-
1
]
.
op
,
Subtensor
)
val
=
f
(
xval
)
assert
xval
.
__getitem__
(
slices
)
.
shape
==
val
.
shape
@pytest.mark.parametrize
(
"x, s, idx, x_val, s_val"
,
[
(
vector
(),
(
iscalar
(),),
(
1
,),
np
.
array
([
1
,
2
],
dtype
=
config
.
floatX
),
np
.
array
([
2
],
dtype
=
np
.
int64
),
),
(
matrix
(),
(
iscalar
(),
iscalar
()),
(
1
,),
np
.
array
([[
1
,
2
],
[
3
,
4
]],
dtype
=
config
.
floatX
),
np
.
array
([
2
,
2
],
dtype
=
np
.
int64
),
),
(
matrix
(),
(
iscalar
(),
iscalar
()),
(
0
,),
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
config
.
floatX
),
np
.
array
([
2
,
3
],
dtype
=
np
.
int64
),
),
(
matrix
(),
(
iscalar
(),
iscalar
()),
(
1
,
1
),
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
config
.
floatX
),
np
.
array
([
2
,
3
],
dtype
=
np
.
int64
),
),
(
tensor3
(),
(
iscalar
(),
iscalar
(),
iscalar
()),
(
-
1
,),
np
.
arange
(
2
*
3
*
5
,
dtype
=
config
.
floatX
)
.
reshape
((
2
,
3
,
5
)),
np
.
array
([
2
,
3
,
5
],
dtype
=
np
.
int64
),
),
(
tensor3
(),
(
iscalar
(),
iscalar
(),
iscalar
()),
(
-
1
,
0
),
np
.
arange
(
2
*
3
*
5
,
dtype
=
config
.
floatX
)
.
reshape
((
2
,
3
,
5
)),
np
.
array
([
2
,
3
,
5
],
dtype
=
np
.
int64
),
),
],
)
def
test_local_subtensor_SpecifyShape_lift
(
x
,
s
,
idx
,
x_val
,
s_val
):
y
=
specify_shape
(
x
,
s
)[
idx
]
assert
isinstance
(
y
.
owner
.
inputs
[
0
]
.
owner
.
op
,
SpecifyShape
)
rewrites
=
RewriteDatabaseQuery
(
include
=
[
None
])
no_rewrites_mode
=
Mode
(
optimizer
=
rewrites
)
y_val_fn
=
function
([
x
,
*
s
],
y
,
on_unused_input
=
"ignore"
,
mode
=
no_rewrites_mode
)
y_val
=
y_val_fn
(
*
([
x_val
,
*
s_val
]))
# This optimization should appear in the canonicalizations
y_opt
=
rewrite_graph
(
y
,
clone
=
False
)
if
y
.
ndim
==
0
:
# SpecifyShape should be removed altogether
assert
isinstance
(
y_opt
.
owner
.
op
,
Subtensor
)
assert
y_opt
.
owner
.
inputs
[
0
]
is
x
else
:
assert
isinstance
(
y_opt
.
owner
.
op
,
SpecifyShape
)
y_opt_fn
=
function
([
x
,
*
s
],
y_opt
,
on_unused_input
=
"ignore"
)
y_opt_val
=
y_opt_fn
(
*
([
x_val
,
*
s_val
]))
assert
np
.
allclose
(
y_val
,
y_opt_val
)
@pytest.mark.parametrize
(
"x, s, idx"
,
[
(
matrix
(),
(
iscalar
(),
iscalar
()),
(
slice
(
1
,
None
),),
),
(
matrix
(),
(
iscalar
(),
iscalar
()),
(
slicetype
(),),
),
(
matrix
(),
(
iscalar
(),
iscalar
()),
(
1
,
0
),
),
],
)
def
test_local_subtensor_SpecifyShape_lift_fail
(
x
,
s
,
idx
):
y
=
specify_shape
(
x
,
s
)[
idx
]
# This optimization should appear in the canonicalizations
y_opt
=
rewrite_graph
(
y
,
clone
=
False
)
assert
not
isinstance
(
y_opt
.
owner
.
op
,
SpecifyShape
)
class
TestLocalSubtensorMakeVector
:
mode
=
get_mode
(
"FAST_RUN"
)
.
including
(
"local_subtensor_make_vector"
)
def
test_scalar_idx
(
self
):
x
,
y
,
z
=
lscalars
(
"xyz"
)
v
=
make_vector
(
x
,
y
,
z
)
f
=
function
([
x
,
y
,
z
],
v
[
0
],
mode
=
self
.
mode
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
(
prog
)
==
1
assert
isinstance
(
prog
[
0
]
.
op
,
DeepCopyOp
)
assert
f
(
0
,
1
,
2
)
==
0
def
test_idx_symbolic
(
self
):
x
,
y
,
z
=
iscalars
(
"xyz"
)
v
=
MakeVector
(
"int32"
)(
x
,
y
,
z
)
idx
=
pt
.
as_tensor
([
0
],
dtype
=
np
.
int64
)
f
=
function
([
x
,
y
,
z
],
v
[
idx
],
mode
=
self
.
mode
)
opt_fgraph
=
f
.
maker
.
fgraph
assert
opt_fgraph
.
outputs
[
0
]
.
dtype
==
"int32"
assert
isinstance
(
opt_fgraph
.
outputs
[
0
]
.
owner
.
op
,
MakeVector
)
assert
f
(
0
,
1
,
2
)
==
np
.
array
([
0
],
dtype
=
np
.
int32
)
def
test_slice_idx_start
(
self
):
x
,
y
,
z
=
iscalars
(
"xyz"
)
v
=
MakeVector
(
"int32"
)(
x
,
y
,
z
)
f
=
function
([
x
,
y
,
z
],
v
[
1
:],
mode
=
self
.
mode
,
on_unused_input
=
"ignore"
)
opt_fgraph
=
f
.
maker
.
fgraph
assert
opt_fgraph
.
outputs
[
0
]
.
dtype
==
"int32"
assert
isinstance
(
opt_fgraph
.
outputs
[
0
]
.
owner
.
op
,
MakeVector
)
assert
len
(
opt_fgraph
.
outputs
[
0
]
.
owner
.
inputs
)
==
2
r
=
f
(
0
,
1
,
2
)
assert
r
[
0
]
==
1
and
r
[
1
]
==
2
def
test_slice_idx_stop
(
self
):
x
,
y
,
z
=
lscalars
(
"xyz"
)
v
=
make_vector
(
x
,
y
,
z
)
f
=
function
([
x
,
y
,
z
],
v
[:
2
],
mode
=
self
.
mode
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
(
prog
)
==
1
assert
isinstance
(
prog
[
0
]
.
op
,
MakeVector
)
assert
len
(
prog
[
0
]
.
inputs
)
==
2
r
=
f
(
0
,
1
,
2
)
assert
r
[
0
]
==
0
and
r
[
1
]
==
1
def
test_slice_idx_step
(
self
):
x
,
y
,
z
=
lscalars
(
"xyz"
)
v
=
make_vector
(
x
,
y
,
z
)
f
=
function
([
x
,
y
,
z
],
v
[::
2
],
mode
=
self
.
mode
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
(
prog
)
==
1
assert
isinstance
(
prog
[
0
]
.
op
,
MakeVector
)
assert
len
(
prog
[
0
]
.
inputs
)
==
2
r
=
f
(
0
,
1
,
2
)
assert
r
[
0
]
==
0
and
r
[
1
]
==
2
def
test_AdvancedSubtensor1_idx
(
self
):
x
,
y
,
z
=
lscalars
(
"xyz"
)
v
=
make_vector
(
x
,
y
,
z
)
f
=
function
([
x
,
y
,
z
],
v
[[
0
,
2
]],
mode
=
self
.
mode
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
(
prog
)
==
1
assert
isinstance
(
prog
[
0
]
.
op
,
MakeVector
)
assert
len
(
prog
[
0
]
.
inputs
)
==
2
r
=
f
(
0
,
1
,
2
)
assert
r
[
0
]
==
0
and
r
[
1
]
==
2
def
test_MakeVector_idx
(
self
):
x
,
y
,
z
,
q
=
lscalars
(
"xyzq"
)
v
=
make_vector
(
x
,
y
,
z
)
q
=
make_vector
(
0
,
2
)
f
=
function
([
x
,
y
,
z
],
v
[
q
],
mode
=
self
.
mode
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
len
(
prog
)
==
1
assert
isinstance
(
prog
[
0
]
.
op
,
MakeVector
)
assert
len
(
prog
[
0
]
.
inputs
)
==
2
r
=
f
(
0
,
1
,
2
)
assert
r
[
0
]
==
0
and
r
[
1
]
==
2
def
test_stack_trace
(
self
):
x
,
y
,
z
=
lscalars
(
"xyz"
)
v
=
make_vector
(
x
,
y
,
z
)
mode
=
get_default_mode
()
.
including
(
"local_subtensor_make_vector"
)
# list of subtensor cases, where local_subtensor_make_vector
# inserts a new MakeVector node
v_subtensors
=
[
v
[:
2
],
v
[::
2
],
v
[[
0
,
2
]]]
for
v_subtensor
in
v_subtensors
:
f
=
function
([
x
,
y
,
z
],
v_subtensor
,
mode
=
mode
)
assert
check_stack_trace
(
f
,
ops_to_check
=
"all"
)
def
test_empty_subtensor
(
self
):
x
,
y
=
lscalars
(
"xy"
)
v
=
make_vector
(
x
,
y
)
out
=
v
[()]
fgraph
=
FunctionGraph
(
outputs
=
[
out
],
clone
=
False
)
node
=
fgraph
.
outputs
[
0
]
.
owner
assert
isinstance
(
node
.
op
,
Subtensor
)
assert
local_subtensor_make_vector
.
transform
(
fgraph
,
node
)
==
[
v
]
def
test_local_subtensor_shape_constant
():
x
=
tensor
(
dtype
=
np
.
float64
,
shape
=
(
1
,
None
))
.
shape
[
0
]
(
res
,)
=
local_subtensor_shape_constant
.
transform
(
None
,
x
.
owner
)
assert
isinstance
(
res
,
Constant
)
assert
res
.
data
==
1
# Make sure it's part of the canonicalizations
res
=
rewrite_graph
(
x
)
assert
isinstance
(
res
,
Constant
)
assert
res
.
data
==
1
x
=
_shape
(
tensor
(
dtype
=
np
.
float64
,
shape
=
(
1
,
None
)))[
lscalar
()]
assert
not
local_subtensor_shape_constant
.
transform
(
None
,
x
.
owner
)
x
=
_shape
(
tensor
(
dtype
=
np
.
float64
,
shape
=
(
1
,
None
)))[
0
:]
assert
not
local_subtensor_shape_constant
.
transform
(
None
,
x
.
owner
)
x
=
_shape
(
tensor
(
dtype
=
np
.
float64
,
shape
=
(
1
,
None
)))[
lscalar
()
:]
assert
not
local_subtensor_shape_constant
.
transform
(
None
,
x
.
owner
)
x
=
_shape
(
tensor
(
dtype
=
np
.
float64
,
shape
=
(
1
,
1
)))[
1
:]
(
res
,)
=
local_subtensor_shape_constant
.
transform
(
None
,
x
.
owner
)
assert
isinstance
(
res
,
Constant
)
assert
np
.
array_equal
(
res
.
data
,
[
1
])
x
=
_shape
(
tensor
(
dtype
=
np
.
float64
,
shape
=
(
None
,
1
,
1
)))[
1
:]
(
res
,)
=
local_subtensor_shape_constant
.
transform
(
None
,
x
.
owner
)
assert
isinstance
(
res
,
Constant
)
assert
np
.
array_equal
(
res
.
data
,
[
1
,
1
])
# A test for a non-`TensorType`
class
MyType
(
Type
):
def
filter
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
()
def
__eq__
(
self
,
other
):
return
isinstance
(
other
,
MyType
)
and
other
.
thingy
==
self
.
thingy
x
=
shape
(
Variable
(
MyType
(),
None
,
None
))[
0
]
assert
not
local_subtensor_shape_constant
.
transform
(
None
,
x
.
owner
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论