Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
73d798ab
提交
73d798ab
authored
9月 16, 2020
作者:
Brandon T. Willard
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add general shape inference for all types of non-boolean indexing
上级
b210efbc
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
326 行增加
和
128 行删除
+326
-128
test_subtensor.py
tests/tensor/test_subtensor.py
+115
-15
subtensor.py
theano/tensor/subtensor.py
+211
-113
没有找到文件。
tests/tensor/test_subtensor.py
浏览文件 @
73d798ab
import
logging
import
logging
import
sys
import
sys
import
pytest
import
pytest
import
numpy
as
np
import
numpy
as
np
import
theano
import
theano
import
theano.scalar
as
scal
import
theano.scalar
as
scal
import
theano.tensor
as
tensor
import
theano.tensor
as
tensor
from
six
import
StringIO
from
numpy.testing
import
assert_array_equal
from
numpy.testing
import
assert_array_equal
from
theano
import
config
from
six
import
StringIO
from
theano
import
config
,
change_flags
from
theano.compile
import
DeepCopyOp
from
theano.compile
import
DeepCopyOp
from
theano.gof.op
import
get_test_value
from
theano.gof.toolbox
import
is_same_graph
from
theano.gof.toolbox
import
is_same_graph
from
theano.tensor
import
(
from
theano.tensor
import
(
_shared
,
_shared
,
...
@@ -32,6 +38,8 @@ from theano.tensor import (
...
@@ -32,6 +38,8 @@ from theano.tensor import (
)
)
from
theano.tensor.basic
import
DimShuffle
from
theano.tensor.basic
import
DimShuffle
from
theano.tensor.subtensor
import
(
from
theano.tensor.subtensor
import
(
basic_shape
,
indexed_result_shape
,
AdvancedIncSubtensor
,
AdvancedIncSubtensor
,
AdvancedIncSubtensor1
,
AdvancedIncSubtensor1
,
AdvancedSubtensor
,
AdvancedSubtensor
,
...
@@ -45,7 +53,7 @@ from theano.tensor.subtensor import (
...
@@ -45,7 +53,7 @@ from theano.tensor.subtensor import (
inc_subtensor
,
inc_subtensor
,
set_subtensor
,
set_subtensor
,
)
)
from
theano
import
change_flags
from
theano
.tensor.type_other
import
make_slice
from
tests
import
unittest_tools
as
utt
from
tests
import
unittest_tools
as
utt
from
tests.tensor.test_basic
import
inplace_func
,
rand
,
randint_ranged
from
tests.tensor.test_basic
import
inplace_func
,
rand
,
randint_ranged
...
@@ -1815,8 +1823,7 @@ class TestAdvancedSubtensor:
...
@@ -1815,8 +1823,7 @@ class TestAdvancedSubtensor:
],
],
),
aval
),
aval
def
test_advanced_indexing
(
self
):
def
test_2d_3d_tensors
(
self
):
# tests advanced indexing in Theano for 2D and 3D tensors
rng
=
np
.
random
.
RandomState
(
utt
.
fetch_seed
())
rng
=
np
.
random
.
RandomState
(
utt
.
fetch_seed
())
a
=
rng
.
uniform
(
size
=
(
3
,
3
))
a
=
rng
.
uniform
(
size
=
(
3
,
3
))
b
=
theano
.
shared
(
a
)
b
=
theano
.
shared
(
a
)
...
@@ -1920,9 +1927,7 @@ class TestAdvancedSubtensor:
...
@@ -1920,9 +1927,7 @@ class TestAdvancedSubtensor:
class
TestInferShape
(
utt
.
InferShapeTester
):
class
TestInferShape
(
utt
.
InferShapeTester
):
@pytest.mark.slow
def
test_IncSubtensor
(
self
):
def
test_infer_shape
(
self
):
# IncSubtensor
admat
=
dmatrix
()
admat
=
dmatrix
()
bdmat
=
dmatrix
()
bdmat
=
dmatrix
()
advec
=
dvector
()
advec
=
dvector
()
...
@@ -2044,7 +2049,7 @@ class TestInferShape(utt.InferShapeTester):
...
@@ -2044,7 +2049,7 @@ class TestInferShape(utt.InferShapeTester):
IncSubtensor
,
IncSubtensor
,
)
)
# AdvancedIncSubtensor1
def
test_AdvancedIncSubtensor1
(
self
):
admat
=
dmatrix
()
admat
=
dmatrix
()
bdmat
=
dmatrix
()
bdmat
=
dmatrix
()
advec
=
dvector
()
advec
=
dvector
()
...
@@ -2074,6 +2079,7 @@ class TestInferShape(utt.InferShapeTester):
...
@@ -2074,6 +2079,7 @@ class TestInferShape(utt.InferShapeTester):
AdvancedIncSubtensor1
,
AdvancedIncSubtensor1
,
)
)
adtens4
=
dtensor4
()
bdtens4
=
dtensor4
()
bdtens4
=
dtensor4
()
adtens4_val
=
rand
(
4
,
3
,
2
,
5
)
adtens4_val
=
rand
(
4
,
3
,
2
,
5
)
aivec_val
=
[
2
,
3
]
aivec_val
=
[
2
,
3
]
...
@@ -2152,7 +2158,10 @@ class TestInferShape(utt.InferShapeTester):
...
@@ -2152,7 +2158,10 @@ class TestInferShape(utt.InferShapeTester):
AdvancedIncSubtensor1
,
AdvancedIncSubtensor1
,
)
)
# AdvancedIncSubtensor
def
test_AdvancedIncSubtensor
(
self
):
admat
=
dmatrix
()
advec
=
dvector
()
admat_val
=
rand
(
5
,
4
)
aivec_val
=
[
1
,
3
,
2
]
aivec_val
=
[
1
,
3
,
2
]
bivec_val
=
[
0
,
3
,
3
]
bivec_val
=
[
0
,
3
,
3
]
advec_val
=
[
23
,
24
,
25
]
advec_val
=
[
23
,
24
,
25
]
...
@@ -2163,7 +2172,7 @@ class TestInferShape(utt.InferShapeTester):
...
@@ -2163,7 +2172,7 @@ class TestInferShape(utt.InferShapeTester):
AdvancedIncSubtensor
,
AdvancedIncSubtensor
,
)
)
def
test_
adv_sub
(
self
):
def
test_
AdvancedSubtensor
(
self
):
admat
=
dmatrix
()
admat
=
dmatrix
()
aivec
=
lvector
()
aivec
=
lvector
()
bivec
=
lvector
()
bivec
=
lvector
()
...
@@ -2177,23 +2186,20 @@ class TestInferShape(utt.InferShapeTester):
...
@@ -2177,23 +2186,20 @@ class TestInferShape(utt.InferShapeTester):
[
admat_val
,
aivec_val
,
bivec_val
],
[
admat_val
,
aivec_val
,
bivec_val
],
AdvancedSubtensor
,
AdvancedSubtensor
,
)
)
# Test case that aren't implemented, but make sure they do not crash.
self
.
_compile_and_check
(
self
.
_compile_and_check
(
[
admat
,
aivec
],
[
admat
,
aivec
],
[
admat
[
aivec
,
1
:
3
]],
[
admat
[
aivec
,
1
:
3
]],
[
admat_val
,
aivec_val
],
[
admat_val
,
aivec_val
],
AdvancedSubtensor
,
AdvancedSubtensor
,
check_topo
=
False
,
)
)
self
.
_compile_and_check
(
self
.
_compile_and_check
(
[
admat
,
aivec
],
[
admat
,
aivec
],
[
admat
[
1
:
3
,
aivec
]],
[
admat
[
1
:
3
,
aivec
]],
[
admat_val
,
aivec_val
],
[
admat_val
,
aivec_val
],
AdvancedSubtensor
,
AdvancedSubtensor
,
check_topo
=
False
,
)
)
def
test_
boolean
(
self
):
def
test_
AdvancedBooleanSubtensor
(
self
):
n
=
dmatrix
()
n
=
dmatrix
()
n_val
=
np
.
arange
(
6
)
.
reshape
((
2
,
3
))
n_val
=
np
.
arange
(
6
)
.
reshape
((
2
,
3
))
...
@@ -2212,3 +2218,97 @@ class TestInferShape(utt.InferShapeTester):
...
@@ -2212,3 +2218,97 @@ class TestInferShape(utt.InferShapeTester):
tensor
.
AdvancedBooleanSubtensor
,
tensor
.
AdvancedBooleanSubtensor
,
check_topo
=
False
,
check_topo
=
False
,
)
)
@change_flags
(
compute_test_value
=
"raise"
)
def
test_basic_shape
():
test_shape
=
(
5
,
4
)
test_indices
=
(
make_slice
(
1
,
3
,
None
),)
res
=
basic_shape
(
test_shape
,
test_indices
)
assert
get_test_value
(
res
)
==
(
2
,)
@change_flags
(
compute_test_value
=
"raise"
)
def
test_indexed_result_shape
():
_test_idx
=
np
.
ix_
(
np
.
array
([
True
,
True
]),
np
.
array
([
True
]),
np
.
array
([
True
,
True
]))
test_shape
=
(
5
,
6
,
7
,
8
)
test_array
=
np
.
arange
(
np
.
prod
(
test_shape
))
.
reshape
(
test_shape
)
def
idx_as_tensor
(
x
):
if
isinstance
(
x
,
(
slice
,
type
(
None
))):
return
x
else
:
return
tensor
.
as_tensor
(
x
)
def
bcast_shape_tuple
(
x
):
if
not
hasattr
(
x
,
"shape"
):
return
x
return
tuple
(
s
if
not
bcast
else
1
for
s
,
bcast
in
zip
(
tuple
(
x
.
shape
),
x
.
broadcastable
)
)
def
compare_index_shapes
(
test_array
,
test_idx
):
res
=
indexed_result_shape
(
tensor
.
as_tensor
(
test_array
)
.
shape
,
[
idx_as_tensor
(
i
)
for
i
in
test_idx
]
)
exp_res
=
test_array
[
test_idx
]
.
shape
assert
np
.
array_equal
(
tuple
(
get_test_value
(
r
)
for
r
in
res
),
exp_res
)
# Test shape-only version
res
=
indexed_result_shape
(
tensor
.
as_tensor
(
test_array
)
.
shape
,
[
bcast_shape_tuple
(
idx_as_tensor
(
i
))
for
i
in
test_idx
],
indices_are_shapes
=
True
,
)
exp_res
=
test_array
[
test_idx
]
.
shape
assert
np
.
array_equal
(
tuple
(
get_test_value
(
r
)
for
r
in
res
),
exp_res
)
# Simple basic indices
test_idx
=
(
slice
(
None
,
None
),)
compare_index_shapes
(
test_array
,
test_idx
)
# Advanced indices
test_idx
=
(
2
,)
compare_index_shapes
(
test_array
,
test_idx
)
test_idx
=
_test_idx
[:
1
]
compare_index_shapes
(
test_array
,
test_idx
)
test_idx
=
_test_idx
[:
2
]
compare_index_shapes
(
test_array
,
test_idx
)
# A Mix of advanced and basic indices
test_idx
=
_test_idx
[:
2
]
+
(
slice
(
None
,
None
),)
compare_index_shapes
(
test_array
,
test_idx
)
test_idx
=
(
slice
(
None
,
None
),)
+
_test_idx
[
1
:]
compare_index_shapes
(
test_array
,
test_idx
)
test_idx
=
(
slice
(
None
,
None
),
None
)
+
_test_idx
[
1
:
2
]
compare_index_shapes
(
test_array
,
test_idx
)
test_idx
=
(
np
.
array
(
1
),
slice
(
None
,
None
),
None
)
compare_index_shapes
(
test_array
,
test_idx
)
test_idx
=
(
slice
(
None
,
None
),
None
,
np
.
array
(
1
))
compare_index_shapes
(
test_array
,
test_idx
)
test_idx
=
_test_idx
[:
1
]
+
(
slice
(
None
,
None
),)
+
_test_idx
[
1
:
2
]
compare_index_shapes
(
test_array
,
test_idx
)
test_idx
=
(
_test_idx
[:
1
]
+
(
slice
(
None
,
None
),)
+
_test_idx
[
1
:
2
]
+
(
slice
(
None
,
None
),)
)
compare_index_shapes
(
test_array
,
test_idx
)
test_idx
=
_test_idx
[:
1
]
+
(
None
,)
+
_test_idx
[
1
:
2
]
compare_index_shapes
(
test_array
,
test_idx
)
test_shape
=
(
5
,
4
)
test_array
=
np
.
arange
(
np
.
prod
(
test_shape
))
.
reshape
(
test_shape
)
test_idx
=
([
1
,
3
,
2
],
slice
(
1
,
3
))
compare_index_shapes
(
test_array
,
test_idx
)
test_idx
=
(
slice
(
1
,
3
),
[
1
,
3
,
2
])
compare_index_shapes
(
test_array
,
test_idx
)
theano/tensor/subtensor.py
浏览文件 @
73d798ab
import
sys
import
sys
from
textwrap
import
dedent
import
warnings
import
warnings
import
logging
import
logging
import
numpy
as
np
import
numpy
as
np
from
six
import
integer_types
import
theano
import
theano
from
theano.gradient
import
DisconnectedType
from
textwrap
import
dedent
from
theano
import
gof
from
itertools
import
groupby
,
chain
from
collections.abc
import
Iterable
from
six
import
integer_types
from
theano
import
gof
,
scalar
as
scal
,
config
from
theano.gof
import
Apply
,
hashtype
,
Op
,
Type
,
MethodNotDefined
,
ParamsType
from
theano.gof
import
Apply
,
hashtype
,
Op
,
Type
,
MethodNotDefined
,
ParamsType
from
theano.gradient
import
DisconnectedType
from
theano.printing
import
pprint
from
theano.printing
import
pprint
from
theano
import
scalar
as
scal
from
theano.tensor.basic
import
alloc
from
theano.tensor.basic
import
(
from
theano.tensor.basic
import
(
alloc
,
addbroadcast
,
addbroadcast
,
clip
,
clip
,
get_scalar_constant_value
,
get_scalar_constant_value
,
...
@@ -23,17 +26,12 @@ from theano.tensor.basic import (
...
@@ -23,17 +26,12 @@ from theano.tensor.basic import (
NotScalarConstantError
,
NotScalarConstantError
,
)
)
from
theano.tensor.elemwise
import
DimShuffle
from
theano.tensor.elemwise
import
DimShuffle
from
theano.tensor.inc_code
import
inc_code
from
theano.tensor.extra_ops
import
broadcast_shape
from
theano.tensor.type_other
import
NoneConst
,
SliceType
,
NoneTypeT
,
make_slice
from
theano.tensor.type_other
import
NoneConst
,
SliceType
,
NoneTypeT
,
make_slice
from
theano
import
config
from
theano.compat
import
Iterable
from
.inc_code
import
inc_code
_logger
=
logging
.
getLogger
(
"theano.tensor.subtensor"
)
_logger
=
logging
.
getLogger
(
"theano.tensor.subtensor"
)
# Do a lazy import of the sparse module
sparse_module_ref
=
None
class
AdvancedIndexingError
(
TypeError
):
class
AdvancedIndexingError
(
TypeError
):
"""
"""
...
@@ -53,16 +51,8 @@ class AdvancedBooleanIndexingError(TypeError):
...
@@ -53,16 +51,8 @@ class AdvancedBooleanIndexingError(TypeError):
pass
pass
##########
# Helpful functions to deal with Subtensor and IncSubtensor
##########
def
make_constant
(
args
):
def
make_constant
(
args
):
"""
"""Convert Python literals to Theano constants in Subtensor arguments."""
Convert python litterals to theano constants in subtensor arguments.
"""
def
conv
(
a
):
def
conv
(
a
):
if
a
is
None
:
if
a
is
None
:
...
@@ -72,6 +62,7 @@ def make_constant(args):
...
@@ -72,6 +62,7 @@ def make_constant(args):
elif
isinstance
(
a
,
(
integer_types
,
np
.
integer
)):
elif
isinstance
(
a
,
(
integer_types
,
np
.
integer
)):
return
scal
.
ScalarConstant
(
scal
.
int64
,
a
)
return
scal
.
ScalarConstant
(
scal
.
int64
,
a
)
else
:
else
:
# Use `tensor.scalar_from_tensor`?
return
a
return
a
return
tuple
(
map
(
conv
,
args
))
return
tuple
(
map
(
conv
,
args
))
...
@@ -112,7 +103,8 @@ def get_idx_list(inputs, idx_list, get_count=False):
...
@@ -112,7 +103,8 @@ def get_idx_list(inputs, idx_list, get_count=False):
def
get_canonical_form_slice
(
theslice
,
length
):
def
get_canonical_form_slice
(
theslice
,
length
):
"""
"""Convert slices to canonical form.
Given a slice [start:stop:step] transform it into a canonical form
Given a slice [start:stop:step] transform it into a canonical form
that respects the conventions imposed by python and numpy.
that respects the conventions imposed by python and numpy.
...
@@ -277,6 +269,162 @@ def get_canonical_form_slice(theslice, length):
...
@@ -277,6 +269,162 @@ def get_canonical_form_slice(theslice, length):
return
value
,
1
return
value
,
1
def
range_len
(
slc
):
"""Length of a `range` object.
Adapted from CPython.
"""
from
theano.tensor
import
switch
,
and_
,
lt
,
gt
start
,
stop
,
step
=
make_constant
([
slc
.
start
,
slc
.
stop
,
slc
.
step
])
return
switch
(
and_
(
gt
(
step
,
0
),
lt
(
start
,
stop
)),
1
+
(
stop
-
1
-
start
)
//
step
,
switch
(
and_
(
lt
(
step
,
0
),
gt
(
start
,
stop
)),
1
+
(
start
-
1
-
stop
)
//
(
-
step
),
scal
.
ScalarConstant
(
scal
.
int64
,
0
),
),
)
def
slice_len
(
slc
,
n
):
"""Compute the length of a slice for an array of a given length.
We're essentially computing `len(range(*slc.indices(n)))`.
"""
# TODO: Do we need to do this or should we expect `slc` to
# already be canonicalized?
canon_slc
,
_
=
get_canonical_form_slice
(
slc
,
n
)
return
range_len
(
canon_slc
)
def
is_basic_idx
(
idx
):
"""Determine if an index is of the NumPy basic type.
XXX: This only checks a single index, so an integers is *not* considered a
basic index, because--depending on the other indices its used with--an
integer can indicate advanced indexing.
"""
return
isinstance
(
idx
,
(
slice
,
type
(
None
)))
or
isinstance
(
getattr
(
idx
,
"type"
,
None
),
(
SliceType
,
NoneTypeT
)
)
def
basic_shape
(
shape
,
indices
):
"""Computes the shape resulting from basic NumPy indexing.
Basic indices are either `slice`s or `None`s.
`Ellipsis` are not supported here; convert them to `slice`s first.
Parameters
----------
shape: Tuple[int]
The shape of the array being indexed
indices: Sequence[Or[slice, NoneType]]
A sequence of basic indices used to index an array.
"""
res_shape
=
()
for
idx
,
n
in
zip
(
indices
,
shape
):
if
isinstance
(
idx
,
slice
):
res_shape
+=
(
slice_len
(
idx
,
n
),)
elif
isinstance
(
getattr
(
idx
,
"type"
,
None
),
SliceType
):
if
idx
.
owner
:
idx_inputs
=
idx
.
owner
.
inputs
else
:
idx_inputs
=
(
None
,)
res_shape
+=
(
slice_len
(
slice
(
*
idx_inputs
),
n
),)
elif
idx
is
None
:
res_shape
+=
(
scal
.
ScalarConstant
(
scal
.
int64
,
1
),)
elif
isinstance
(
getattr
(
idx
,
"type"
,
None
),
NoneTypeT
):
res_shape
+=
(
scal
.
ScalarConstant
(
scal
.
int64
,
1
),)
else
:
raise
ValueError
(
"Invalid index type: {}"
.
format
(
idx
))
return
res_shape
def
group_indices
(
indices
):
"""Group indices sequentially by whether or not they're basic or advanced.
Returns
-------
Tuple[Boolean, List[Tuple[Integer, Any]]]
The boolean indicates whether or not the group is a set of basic
indices. The list contains the contiguous set of indices paired with their
corresponding dimension number in the array being indexed.
"""
idx_groups
=
[]
dim_num
=
-
1
for
basic
,
grp_indices
in
groupby
(
indices
,
key
=
is_basic_idx
):
enum_grp_indices
=
[]
for
idx
in
grp_indices
:
# We "zip" the dimension number to each index, which means we can't
# count indices that add new axes
if
(
idx
is
not
None
)
and
not
isinstance
(
getattr
(
idx
,
"type"
,
None
),
NoneTypeT
):
dim_num
+=
1
enum_grp_indices
.
append
((
dim_num
,
idx
))
idx_groups
.
append
((
basic
,
enum_grp_indices
))
return
idx_groups
def
indexed_result_shape
(
array_shape
,
indices
,
indices_are_shapes
=
False
):
"""Compute the symbolic shape resulting from `a[indices]` for `a.shape == array_shape`.
This function uses NumPy's basic and advanced indexing logic. It can also
handle combinations of advanced and basic indices.
Parameters
----------
array_shape: Tuple[Variable]
Shape of the array being indexed.
indices: Sequence[Union[TensorVariable, Tuple[Union[None, slice, Variable]]]]
Either the indices themselves or the shapes of each index--depending
on the value of `indices_are_shapes`.
indices_are_shapes: bool (Optional)
Indicates whether or not the `indices` contains shape tuples instead of
the actual index arrays. If you use this approach, make sure that the
broadcastable dimensions are (scalar) constants with the value `1`, or `1`
exactly.
"""
res_shape
=
()
remaining_dims
=
range
(
theano
.
tensor
.
basic
.
get_vector_length
(
array_shape
))
idx_groups
=
group_indices
(
indices
)
if
len
(
idx_groups
)
>
2
or
len
(
idx_groups
)
>
1
and
not
idx_groups
[
0
][
0
]:
# Bring adv. index groups to the front and merge each group
idx_groups
=
sorted
(
idx_groups
,
key
=
lambda
x
:
x
[
0
])
idx_groups
=
groupby
(
chain
.
from_iterable
(
d_idx
for
_
,
d_idx
in
idx_groups
),
key
=
lambda
x
:
is_basic_idx
(
x
[
1
]),
)
for
basic
,
grp_dim_indices
in
idx_groups
:
dim_nums
,
grp_indices
=
zip
(
*
grp_dim_indices
)
remaining_dims
=
tuple
(
dim
for
dim
in
remaining_dims
if
dim
not
in
dim_nums
)
if
basic
:
grp_shapes
=
tuple
(
array_shape
[
dim
]
for
dim
in
dim_nums
)
res_shape
+=
basic_shape
(
grp_shapes
,
grp_indices
)
else
:
res_shape
+=
broadcast_shape
(
*
grp_indices
,
arrays_are_shapes
=
indices_are_shapes
)
res_shape
+=
tuple
(
array_shape
[
dim
]
for
dim
in
remaining_dims
)
return
res_shape
class
Subtensor
(
Op
):
class
Subtensor
(
Op
):
"""
"""
Return a subtensor view.
Return a subtensor view.
...
@@ -1783,14 +1931,6 @@ def _sum_grad_over_bcasted_dims(x, gx):
...
@@ -1783,14 +1931,6 @@ def _sum_grad_over_bcasted_dims(x, gx):
return
gx
return
gx
#########################
# Advanced indexing
#########################
#
# Should reproduce numpy's behaviour, see url:
# docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
class
AdvancedSubtensor1
(
Op
):
class
AdvancedSubtensor1
(
Op
):
"""
"""
Implement x[ilist] where ilist is a vector of integers.
Implement x[ilist] where ilist is a vector of integers.
...
@@ -1858,7 +1998,6 @@ class AdvancedSubtensor1(Op):
...
@@ -1858,7 +1998,6 @@ class AdvancedSubtensor1(Op):
return
rval
return
rval
def
grad
(
self
,
inputs
,
grads
):
def
grad
(
self
,
inputs
,
grads
):
global
sparse_module_ref
x
,
ilist
=
inputs
x
,
ilist
=
inputs
(
gz
,)
=
grads
(
gz
,)
=
grads
assert
len
(
inputs
)
==
2
assert
len
(
inputs
)
==
2
...
@@ -1868,10 +2007,8 @@ class AdvancedSubtensor1(Op):
...
@@ -1868,10 +2007,8 @@ class AdvancedSubtensor1(Op):
"AdvancedSubtensor1: you can't take the sparse grad"
"AdvancedSubtensor1: you can't take the sparse grad"
" from a tensor with ndim != 2. ndim is "
+
str
(
x
.
type
.
ndim
)
" from a tensor with ndim != 2. ndim is "
+
str
(
x
.
type
.
ndim
)
)
)
if
sparse_module_ref
is
None
:
import
theano.sparse
as
sparse_module_ref
rval1
=
[
sparse_module_ref
.
construct_sparse_from_list
(
x
,
gz
,
ilist
)]
rval1
=
[
theano
.
sparse
.
construct_sparse_from_list
(
x
,
gz
,
ilist
)]
else
:
else
:
if
x
.
dtype
in
theano
.
tensor
.
discrete_dtypes
:
if
x
.
dtype
in
theano
.
tensor
.
discrete_dtypes
:
# The output dtype is the same as x
# The output dtype is the same as x
...
@@ -2203,45 +2340,6 @@ def as_index_variable(idx):
...
@@ -2203,45 +2340,6 @@ def as_index_variable(idx):
return
idx
return
idx
def
adv_index_broadcastable_pattern
(
a
,
idx
):
"""
This function is only used to determine the broadcast pattern for
AdvancedSubtensor output variable.
For this, we make a fake ndarray and a fake idx and call use ask numpy
the output. From this, we find the output broadcast pattern.
"""
def
replace_slice
(
v
):
if
isinstance
(
v
,
gof
.
Apply
):
if
len
(
v
.
outputs
)
!=
1
:
raise
ValueError
(
"It is ambiguous which output of a multi-output Op has"
" to be fetched."
,
v
,
)
else
:
v
=
v
.
outputs
[
0
]
if
NoneConst
.
equals
(
v
):
return
None
if
isinstance
(
v
.
type
,
SliceType
):
return
slice
(
None
,
None
)
if
v
.
dtype
==
"bool"
:
return
np
.
ones
((
2
,)
*
v
.
ndim
,
v
.
dtype
)
else
:
return
np
.
zeros
((
2
,)
*
v
.
ndim
,
int
)
newidx
=
tuple
(
map
(
replace_slice
,
idx
))
# 2 - True = 1; 2 - False = 2
fakeshape
=
[
2
-
bc
for
bc
in
a
.
broadcastable
]
retshape
=
np
.
empty
(
fakeshape
)[
newidx
]
.
shape
return
tuple
([
dim
==
1
for
dim
in
retshape
])
def
check_advanced_indexing_dimensions
(
input
,
idx_list
):
def
check_advanced_indexing_dimensions
(
input
,
idx_list
):
"""
"""
This function checks if the index list in idx_list is correct.
This function checks if the index list in idx_list is correct.
...
@@ -2288,23 +2386,33 @@ def check_and_reject_bool(args_el):
...
@@ -2288,23 +2386,33 @@ def check_and_reject_bool(args_el):
class
BaseAdvancedSubtensor
(
Op
):
class
BaseAdvancedSubtensor
(
Op
):
"""
"""
Abstract base class for AdvancedSubtensor and AdvancedBooleanSubtensor.
Abstract base class for AdvancedSubtensor and AdvancedBooleanSubtensor.
Implements advanced indexing with boolean masks.
Implements advanced indexing with boolean masks.
Should be used by __getitem__ and __getslice__, as follows:
- AdvancedSubtensor()(self, *args) or
- AdvancedBooleanSubtensor()(self, *args), if args contain advanced indices
"""
"""
# Should be used by __getitem__ and __getslice__, as follows:
# AdvancedSubtensor()(self, *args) or
# AdvancedBooleanSubtensor()(self, *args),
# if args contains and advanced indexing pattern
__props__
=
()
__props__
=
()
def
make_node
(
self
,
x
,
*
index
):
def
make_node
(
self
,
x
,
*
index
):
x
=
theano
.
tensor
.
as_tensor_variable
(
x
)
x
=
theano
.
tensor
.
as_tensor_variable
(
x
)
index
=
tuple
(
map
(
as_index_variable
,
index
))
index
=
tuple
(
map
(
as_index_variable
,
index
))
bcast
=
adv_index_broadcastable_pattern
(
x
,
index
)
# We only want the broadcast information, and we don't need recursive
# `Subtensor` calls, so we create a fake symbolic shape tuple and
# identify the broadcast dimensions from the shape result of this
# entire subtensor operation.
fake_shape
=
tuple
(
theano
.
tensor
.
tensor
(
dtype
=
"int64"
,
broadcastable
=
())
if
not
bcast
else
1
for
bcast
in
x
.
broadcastable
)
bcast
=
[
getattr
(
i
,
"value"
,
i
)
==
1
for
i
in
indexed_result_shape
(
fake_shape
,
index
)
]
return
gof
.
Apply
(
return
gof
.
Apply
(
self
,
self
,
(
x
,)
+
index
,
(
x
,)
+
index
,
...
@@ -2317,8 +2425,26 @@ class BaseAdvancedSubtensor(Op):
...
@@ -2317,8 +2425,26 @@ class BaseAdvancedSubtensor(Op):
return
self
.
make_node
(
eval_points
[
0
],
*
inputs
[
1
:])
.
outputs
return
self
.
make_node
(
eval_points
[
0
],
*
inputs
[
1
:])
.
outputs
def
infer_shape
(
self
,
node
,
ishapes
):
def
infer_shape
(
self
,
node
,
ishapes
):
# Default case, we don't know
indices
=
node
.
inputs
[
1
:]
raise
theano
.
tensor
.
basic
.
ShapeError
(
"case not implemented"
)
index_shapes
=
list
(
ishapes
[
1
:])
for
i
,
idx
in
enumerate
(
indices
):
if
(
isinstance
(
idx
,
(
np
.
bool_
,
bool
))
or
getattr
(
idx
,
"dtype"
,
None
)
==
"bool"
):
raise
theano
.
tensor
.
basic
.
ShapeError
(
"Shape inference for boolean indices is not implemented"
)
# The `ishapes` entries for `SliceType`s will be None, and
# we need to give `indexed_result_shape` the actual slices.
if
isinstance
(
getattr
(
idx
,
"type"
,
None
),
SliceType
):
index_shapes
[
i
]
=
idx
res_shape
=
indexed_result_shape
(
ishapes
[
0
],
index_shapes
,
indices_are_shapes
=
True
)
assert
node
.
outputs
[
0
]
.
ndim
==
len
(
res_shape
)
return
[[
s
for
s
in
res_shape
]]
def
perform
(
self
,
node
,
inputs
,
out_
):
def
perform
(
self
,
node
,
inputs
,
out_
):
(
out
,)
=
out_
(
out
,)
=
out_
...
@@ -2348,34 +2474,10 @@ class AdvancedSubtensor(BaseAdvancedSubtensor):
...
@@ -2348,34 +2474,10 @@ class AdvancedSubtensor(BaseAdvancedSubtensor):
"""
"""
# Should be used by __getitem__ and __getslice__, as follows:
# AdvancedSubtensor()(self, *args),
# if args contains and advanced indexing pattern
def
make_node
(
self
,
x
,
*
index
):
def
make_node
(
self
,
x
,
*
index
):
check_and_reject_bool
(
index
)
check_and_reject_bool
(
index
)
return
super
(
AdvancedSubtensor
,
self
)
.
make_node
(
x
,
*
index
)
return
super
(
AdvancedSubtensor
,
self
)
.
make_node
(
x
,
*
index
)
def
infer_shape
(
self
,
node
,
ishapes
):
# Really special case
if
len
(
ishapes
)
==
3
:
xshp
,
ind1shp
,
ind2shp
=
ishapes
if
(
len
(
xshp
)
==
2
and
ind1shp
is
not
None
and
len
(
ind1shp
)
==
1
and
ind2shp
is
not
None
and
len
(
ind2shp
)
==
1
):
# if the graph is correct, we can assume ind1shp[0] and
# ind2shp[0] will have the same value.
# Try to return the one closest to the graph input.
if
node
.
inputs
[
2
]
.
owner
is
None
:
return
[
ind2shp
]
else
:
return
[
ind1shp
]
return
super
(
AdvancedSubtensor
,
self
)
.
infer_shape
(
node
,
ishapes
)
def
grad
(
self
,
inputs
,
grads
):
def
grad
(
self
,
inputs
,
grads
):
(
gz
,)
=
grads
(
gz
,)
=
grads
x
=
inputs
[
0
]
x
=
inputs
[
0
]
...
@@ -2401,10 +2503,6 @@ class AdvancedBooleanSubtensor(BaseAdvancedSubtensor):
...
@@ -2401,10 +2503,6 @@ class AdvancedBooleanSubtensor(BaseAdvancedSubtensor):
"""
"""
# Should be used by __getitem__ and __getslice__, as follows:
# AdvancedBooleanSubtensor()(self, *args),
# if args contains and advanced indexing pattern with boolean masks
def
grad
(
self
,
inputs
,
grads
):
def
grad
(
self
,
inputs
,
grads
):
(
gz
,)
=
grads
(
gz
,)
=
grads
x
=
inputs
[
0
]
x
=
inputs
[
0
]
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论