Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9df54a54
提交
9df54a54
authored
6月 06, 2023
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
7月 13, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement shape inference for boolean advanced indexing
上级
f6407da2
显示空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
158 行增加
和
39 行删除
+158
-39
extra_ops.py
pytensor/tensor/extra_ops.py
+38
-15
subtensor.py
pytensor/tensor/subtensor.py
+41
-20
test_extra_ops.py
tests/tensor/test_extra_ops.py
+9
-1
test_subtensor.py
tests/tensor/test_subtensor.py
+70
-3
没有找到文件。
pytensor/tensor/extra_ops.py
浏览文件 @
9df54a54
...
@@ -21,14 +21,18 @@ from pytensor.misc.safe_asarray import _asarray
...
@@ -21,14 +21,18 @@ from pytensor.misc.safe_asarray import _asarray
from
pytensor.raise_op
import
Assert
from
pytensor.raise_op
import
Assert
from
pytensor.scalar
import
int32
as
int_t
from
pytensor.scalar
import
int32
as
int_t
from
pytensor.scalar
import
upcast
from
pytensor.scalar
import
upcast
from
pytensor.tensor
import
as_tensor_variable
from
pytensor.tensor
import
basic
as
at
from
pytensor.tensor
import
basic
as
at
from
pytensor.tensor
import
get_vector_length
from
pytensor.tensor
import
get_vector_length
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.math
import
abs
as
a
t_abs
from
pytensor.tensor.math
import
abs
as
p
t_abs
from
pytensor.tensor.math
import
all
as
pt_all
from
pytensor.tensor.math
import
all
as
pt_all
from
pytensor.tensor.math
import
eq
as
pt_eq
from
pytensor.tensor.math
import
eq
as
pt_eq
from
pytensor.tensor.math
import
ge
,
lt
,
maximum
,
minimum
,
prod
from
pytensor.tensor.math
import
ge
,
lt
from
pytensor.tensor.math
import
max
as
pt_max
from
pytensor.tensor.math
import
maximum
,
minimum
,
prod
from
pytensor.tensor.math
import
sum
as
at_sum
from
pytensor.tensor.math
import
sum
as
at_sum
from
pytensor.tensor.math
import
switch
from
pytensor.tensor.subtensor
import
advanced_inc_subtensor1
,
set_subtensor
from
pytensor.tensor.subtensor
import
advanced_inc_subtensor1
,
set_subtensor
from
pytensor.tensor.type
import
TensorType
,
dvector
,
int_dtypes
,
integer_dtypes
,
vector
from
pytensor.tensor.type
import
TensorType
,
dvector
,
int_dtypes
,
integer_dtypes
,
vector
from
pytensor.tensor.var
import
TensorVariable
from
pytensor.tensor.var
import
TensorVariable
...
@@ -1063,7 +1067,7 @@ class FillDiagonalOffset(Op):
...
@@ -1063,7 +1067,7 @@ class FillDiagonalOffset(Op):
# only valid for matrices
# only valid for matrices
wr_a
=
fill_diagonal_offset
(
grad
,
0
,
offset
)
wr_a
=
fill_diagonal_offset
(
grad
,
0
,
offset
)
offset_abs
=
a
t_abs
(
offset
)
offset_abs
=
p
t_abs
(
offset
)
pos_offset_flag
=
ge
(
offset
,
0
)
pos_offset_flag
=
ge
(
offset
,
0
)
neg_offset_flag
=
lt
(
offset
,
0
)
neg_offset_flag
=
lt
(
offset
,
0
)
min_wh
=
minimum
(
width
,
height
)
min_wh
=
minimum
(
width
,
height
)
...
@@ -1442,6 +1446,7 @@ _broadcast_assert = Assert(
...
@@ -1442,6 +1446,7 @@ _broadcast_assert = Assert(
"axes that have a statically known length 1. Use `specify_broadcastable` to "
"axes that have a statically known length 1. Use `specify_broadcastable` to "
"inform PyTensor of a known shape."
"inform PyTensor of a known shape."
)
)
_runtime_broadcast_assert
=
Assert
(
"Could not broadcast dimensions."
)
def
broadcast_shape
(
*
arrays
,
**
kwargs
)
->
Tuple
[
aes
.
ScalarVariable
,
...
]:
def
broadcast_shape
(
*
arrays
,
**
kwargs
)
->
Tuple
[
aes
.
ScalarVariable
,
...
]:
...
@@ -1465,6 +1470,7 @@ def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]:
...
@@ -1465,6 +1470,7 @@ def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]:
def
broadcast_shape_iter
(
def
broadcast_shape_iter
(
arrays
:
Iterable
[
Union
[
TensorVariable
,
Tuple
[
TensorVariable
,
...
]]],
arrays
:
Iterable
[
Union
[
TensorVariable
,
Tuple
[
TensorVariable
,
...
]]],
arrays_are_shapes
:
bool
=
False
,
arrays_are_shapes
:
bool
=
False
,
allow_runtime_broadcast
:
bool
=
False
,
)
->
Tuple
[
aes
.
ScalarVariable
,
...
]:
)
->
Tuple
[
aes
.
ScalarVariable
,
...
]:
r"""Compute the shape resulting from broadcasting arrays.
r"""Compute the shape resulting from broadcasting arrays.
...
@@ -1480,22 +1486,24 @@ def broadcast_shape_iter(
...
@@ -1480,22 +1486,24 @@ def broadcast_shape_iter(
arrays
arrays
An iterable of tensors, or a tuple of shapes (as tuples),
An iterable of tensors, or a tuple of shapes (as tuples),
for which the broadcast shape is computed.
for which the broadcast shape is computed.
arrays_are_shapes
arrays_are_shapes
: bool, default False
Indicates whether or not the `arrays` contains shape tuples.
Indicates whether or not the `arrays` contains shape tuples.
If you use this approach, make sure that the broadcastable dimensions
If you use this approach, make sure that the broadcastable dimensions
are (scalar) constants with the value ``1``--or simply the integer
are (scalar) constants with the value ``1``--or simply the integer
``1``.
``1``. This is not revelant if `allow_runtime_broadcast` is True.
allow_runtime_broadcast: bool, default False
Whether to allow non-statically known broadcast on the shape computation.
"""
"""
one
_at
=
pytensor
.
scalar
.
ScalarConstant
(
pytensor
.
scalar
.
int64
,
1
)
one
=
pytensor
.
scalar
.
ScalarConstant
(
pytensor
.
scalar
.
int64
,
1
)
if
arrays_are_shapes
:
if
arrays_are_shapes
:
max_dims
=
max
(
len
(
a
)
for
a
in
arrays
)
max_dims
=
max
(
len
(
a
)
for
a
in
arrays
)
array_shapes
=
[
array_shapes
=
[
(
one
_at
,)
*
(
max_dims
-
len
(
a
))
(
one
,)
*
(
max_dims
-
len
(
a
))
+
tuple
(
+
tuple
(
one
_at
one
if
sh
==
1
or
isinstance
(
sh
,
Constant
)
and
sh
.
value
==
1
if
sh
==
1
or
isinstance
(
sh
,
Constant
)
and
sh
.
value
==
1
else
(
aes
.
as_scalar
(
sh
)
if
not
isinstance
(
sh
,
Variable
)
else
sh
)
else
(
aes
.
as_scalar
(
sh
)
if
not
isinstance
(
sh
,
Variable
)
else
sh
)
for
sh
in
a
for
sh
in
a
...
@@ -1508,10 +1516,8 @@ def broadcast_shape_iter(
...
@@ -1508,10 +1516,8 @@ def broadcast_shape_iter(
_arrays
=
tuple
(
at
.
as_tensor_variable
(
a
)
for
a
in
arrays
)
_arrays
=
tuple
(
at
.
as_tensor_variable
(
a
)
for
a
in
arrays
)
array_shapes
=
[
array_shapes
=
[
(
one_at
,)
*
(
max_dims
-
a
.
ndim
)
(
one
,)
*
(
max_dims
-
a
.
ndim
)
+
tuple
(
+
tuple
(
one
if
t_sh
==
1
else
sh
for
sh
,
t_sh
in
zip
(
a
.
shape
,
a
.
type
.
shape
))
one_at
if
t_sh
==
1
else
sh
for
sh
,
t_sh
in
zip
(
a
.
shape
,
a
.
type
.
shape
)
)
for
a
in
_arrays
for
a
in
_arrays
]
]
...
@@ -1520,11 +1526,11 @@ def broadcast_shape_iter(
...
@@ -1520,11 +1526,11 @@ def broadcast_shape_iter(
for
dim_shapes
in
zip
(
*
array_shapes
):
for
dim_shapes
in
zip
(
*
array_shapes
):
# Get the shapes in this dimension that are not broadcastable
# Get the shapes in this dimension that are not broadcastable
# (i.e. not symbolically known to be broadcastable)
# (i.e. not symbolically known to be broadcastable)
non_bcast_shapes
=
[
shape
for
shape
in
dim_shapes
if
shape
!=
one
_at
]
non_bcast_shapes
=
[
shape
for
shape
in
dim_shapes
if
shape
!=
one
]
if
len
(
non_bcast_shapes
)
==
0
:
if
len
(
non_bcast_shapes
)
==
0
:
# Every shape was broadcastable in this dimension
# Every shape was broadcastable in this dimension
result_dims
.
append
(
one
_at
)
result_dims
.
append
(
one
)
elif
len
(
non_bcast_shapes
)
==
1
:
elif
len
(
non_bcast_shapes
)
==
1
:
# Only one shape might not be broadcastable in this dimension
# Only one shape might not be broadcastable in this dimension
result_dims
.
extend
(
non_bcast_shapes
)
result_dims
.
extend
(
non_bcast_shapes
)
...
@@ -1554,9 +1560,26 @@ def broadcast_shape_iter(
...
@@ -1554,9 +1560,26 @@ def broadcast_shape_iter(
result_dims
.
append
(
first_length
)
result_dims
.
append
(
first_length
)
continue
continue
if
not
allow_runtime_broadcast
:
# Add assert that all remaining shapes are equal
# Add assert that all remaining shapes are equal
condition
=
pt_all
([
pt_eq
(
first_length
,
other
)
for
other
in
other_lengths
])
condition
=
pt_all
(
[
pt_eq
(
first_length
,
other
)
for
other
in
other_lengths
]
)
result_dims
.
append
(
_broadcast_assert
(
first_length
,
condition
))
result_dims
.
append
(
_broadcast_assert
(
first_length
,
condition
))
else
:
lengths
=
as_tensor_variable
((
first_length
,
*
other_lengths
))
runtime_broadcastable
=
pt_eq
(
lengths
,
one
)
result_dim
=
pt_abs
(
pt_max
(
switch
(
runtime_broadcastable
,
-
one
,
lengths
))
)
condition
=
pt_all
(
switch
(
~
runtime_broadcastable
,
pt_eq
(
lengths
,
result_dim
),
np
.
array
(
True
),
)
)
result_dims
.
append
(
_runtime_broadcast_assert
(
result_dim
,
condition
))
return
tuple
(
result_dims
)
return
tuple
(
result_dims
)
...
...
pytensor/tensor/subtensor.py
浏览文件 @
9df54a54
...
@@ -20,15 +20,11 @@ from pytensor.misc.safe_asarray import _asarray
...
@@ -20,15 +20,11 @@ from pytensor.misc.safe_asarray import _asarray
from
pytensor.printing
import
Printer
,
pprint
,
set_precedence
from
pytensor.printing
import
Printer
,
pprint
,
set_precedence
from
pytensor.scalar.basic
import
ScalarConstant
from
pytensor.scalar.basic
import
ScalarConstant
from
pytensor.tensor
import
_get_vector_length
,
as_tensor_variable
,
get_vector_length
from
pytensor.tensor
import
_get_vector_length
,
as_tensor_variable
,
get_vector_length
from
pytensor.tensor.basic
import
alloc
,
get_underlying_scalar_constant_value
from
pytensor.tensor.basic
import
alloc
,
get_underlying_scalar_constant_value
,
nonzero
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.exceptions
import
(
from
pytensor.tensor.exceptions
import
AdvancedIndexingError
,
NotScalarConstantError
AdvancedIndexingError
,
NotScalarConstantError
,
ShapeError
,
)
from
pytensor.tensor.math
import
clip
from
pytensor.tensor.math
import
clip
from
pytensor.tensor.shape
import
Reshape
,
specify_broadcastable
from
pytensor.tensor.shape
import
Reshape
,
s
hape_i
,
s
pecify_broadcastable
from
pytensor.tensor.type
import
(
from
pytensor.tensor.type
import
(
TensorType
,
TensorType
,
bscalar
,
bscalar
,
...
@@ -510,7 +506,11 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
...
@@ -510,7 +506,11 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
from
pytensor.tensor.extra_ops
import
broadcast_shape
from
pytensor.tensor.extra_ops
import
broadcast_shape
res_shape
+=
broadcast_shape
(
res_shape
+=
broadcast_shape
(
*
grp_indices
,
arrays_are_shapes
=
indices_are_shapes
*
grp_indices
,
arrays_are_shapes
=
indices_are_shapes
,
# The AdvancedIndexing Op relies on the Numpy implementation which allows runtime broadcasting.
# As long as that is true, the shape inference has to respect that this is not an error.
allow_runtime_broadcast
=
True
,
)
)
res_shape
+=
tuple
(
array_shape
[
dim
]
for
dim
in
remaining_dims
)
res_shape
+=
tuple
(
array_shape
[
dim
]
for
dim
in
remaining_dims
)
...
@@ -2584,26 +2584,47 @@ class AdvancedSubtensor(Op):
...
@@ -2584,26 +2584,47 @@ class AdvancedSubtensor(Op):
return
self
.
make_node
(
eval_points
[
0
],
*
inputs
[
1
:])
.
outputs
return
self
.
make_node
(
eval_points
[
0
],
*
inputs
[
1
:])
.
outputs
def
infer_shape
(
self
,
fgraph
,
node
,
ishapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
ishapes
):
indices
=
node
.
inputs
[
1
:]
def
is_bool_index
(
idx
):
index_shapes
=
list
(
ishapes
[
1
:])
return
(
for
i
,
idx
in
enumerate
(
indices
):
if
(
isinstance
(
idx
,
(
np
.
bool_
,
bool
))
isinstance
(
idx
,
(
np
.
bool_
,
bool
))
or
getattr
(
idx
,
"dtype"
,
None
)
==
"bool"
or
getattr
(
idx
,
"dtype"
,
None
)
==
"bool"
):
)
raise
ShapeError
(
"Shape inference for boolean indices is not implemented"
indices
=
node
.
inputs
[
1
:]
index_shapes
=
[]
for
idx
,
ishape
in
zip
(
indices
,
ishapes
[
1
:]):
# Mixed bool indexes are converted to nonzero entries
if
is_bool_index
(
idx
):
index_shapes
.
extend
(
(
shape_i
(
nz_dim
,
0
,
fgraph
=
fgraph
),)
for
nz_dim
in
nonzero
(
idx
)
)
)
# The `ishapes` entries for `SliceType`s will be None, and
# The `ishapes` entries for `SliceType`s will be None, and
# we need to give `indexed_result_shape` the actual slices.
# we need to give `indexed_result_shape` the actual slices.
if
isinstance
(
getattr
(
idx
,
"type"
,
None
),
SliceType
):
elif
isinstance
(
getattr
(
idx
,
"type"
,
None
),
SliceType
):
index_shapes
[
i
]
=
idx
index_shapes
.
append
(
idx
)
else
:
index_shapes
.
append
(
ishape
)
res_shape
=
indexed_result_shape
(
res_shape
=
list
(
i
shapes
[
0
],
index_shapes
,
indices_are_shapes
=
True
i
ndexed_result_shape
(
ishapes
[
0
],
index_shapes
,
indices_are_shapes
=
True
)
)
)
adv_indices
=
[
idx
for
idx
in
indices
if
not
is_basic_idx
(
idx
)]
bool_indices
=
[
idx
for
idx
in
adv_indices
if
is_bool_index
(
idx
)]
# Special logic when the only advanced index group is of bool type.
# We can replace the nonzeros by a sum of the whole bool variable.
if
len
(
bool_indices
)
==
1
and
len
(
adv_indices
)
==
1
:
[
bool_index
]
=
bool_indices
# Find the output dim associated with the bool index group
# Because there are no more advanced index groups, there is exactly
# one output dim per index variable up to the bool group.
# Note: Scalar integer indexing counts as advanced indexing.
start_dim
=
indices
.
index
(
bool_index
)
res_shape
[
start_dim
]
=
bool_index
.
sum
()
assert
node
.
outputs
[
0
]
.
ndim
==
len
(
res_shape
)
assert
node
.
outputs
[
0
]
.
ndim
==
len
(
res_shape
)
return
[
list
(
res_shape
)
]
return
[
res_shape
]
def
perform
(
self
,
node
,
inputs
,
out_
):
def
perform
(
self
,
node
,
inputs
,
out_
):
(
out
,)
=
out_
(
out
,)
=
out_
...
...
tests/tensor/test_extra_ops.py
浏览文件 @
9df54a54
...
@@ -1087,9 +1087,17 @@ def test_broadcast_shape_basic():
...
@@ -1087,9 +1087,17 @@ def test_broadcast_shape_basic():
assert
any
(
assert
any
(
isinstance
(
node
.
op
,
Assert
)
for
node
in
applys_between
([
x_at
,
y_at
],
b_at
)
isinstance
(
node
.
op
,
Assert
)
for
node
in
applys_between
([
x_at
,
y_at
],
b_at
)
)
)
# This should fail because it would need dynamic broadcasting
with
pytest
.
raises
(
AssertionError
):
with
pytest
.
raises
(
AssertionError
):
assert
np
.
array_equal
([
z
.
eval
()
for
z
in
b_at
],
b
.
shape
)
assert
np
.
array_equal
([
z
.
eval
()
for
z
in
b_at
],
b
.
shape
)
# But fine if we allow_runtime_broadcast
b_at
=
broadcast_shape
(
shape_tuple
(
x_at
,
use_bcast
=
False
),
shape_tuple
(
y_at
,
use_bcast
=
False
),
arrays_are_shapes
=
True
,
allow_runtime_broadcast
=
True
,
)
assert
np
.
array_equal
([
z
.
eval
()
for
z
in
b_at
],
b
.
shape
)
# Or if static bcast is known
b_at
=
broadcast_shape
(
shape_tuple
(
x_at
),
shape_tuple
(
y_at
),
arrays_are_shapes
=
True
)
b_at
=
broadcast_shape
(
shape_tuple
(
x_at
),
shape_tuple
(
y_at
),
arrays_are_shapes
=
True
)
assert
np
.
array_equal
([
z
.
eval
()
for
z
in
b_at
],
b
.
shape
)
assert
np
.
array_equal
([
z
.
eval
()
for
z
in
b_at
],
b
.
shape
)
...
...
tests/tensor/test_subtensor.py
浏览文件 @
9df54a54
...
@@ -63,6 +63,7 @@ from pytensor.tensor.type import (
...
@@ -63,6 +63,7 @@ from pytensor.tensor.type import (
tensor
,
tensor
,
tensor3
,
tensor3
,
tensor4
,
tensor4
,
tensor5
,
vector
,
vector
,
)
)
from
pytensor.tensor.type_other
import
NoneConst
,
SliceConstant
,
make_slice
,
slicetype
from
pytensor.tensor.type_other
import
NoneConst
,
SliceConstant
,
make_slice
,
slicetype
...
@@ -2150,6 +2151,12 @@ class TestAdvancedSubtensor:
...
@@ -2150,6 +2151,12 @@ class TestAdvancedSubtensor:
class
TestInferShape
(
utt
.
InferShapeTester
):
class
TestInferShape
(
utt
.
InferShapeTester
):
@staticmethod
def
random_bool_mask
(
shape
,
rng
=
None
):
if
rng
is
None
:
rng
=
np
.
random
.
default_rng
()
return
rng
.
binomial
(
n
=
1
,
p
=
0.5
,
size
=
shape
)
.
astype
(
bool
)
def
test_IncSubtensor
(
self
):
def
test_IncSubtensor
(
self
):
admat
=
dmatrix
()
admat
=
dmatrix
()
bdmat
=
dmatrix
()
bdmat
=
dmatrix
()
...
@@ -2439,25 +2446,85 @@ class TestInferShape(utt.InferShapeTester):
...
@@ -2439,25 +2446,85 @@ class TestInferShape(utt.InferShapeTester):
n
=
dmatrix
()
n
=
dmatrix
()
n_val
=
np
.
arange
(
6
)
.
reshape
((
2
,
3
))
n_val
=
np
.
arange
(
6
)
.
reshape
((
2
,
3
))
#
infer_shape is not implemented, but it should not crash
#
Shape inference requires runtime broadcasting between the nonzero() shapes
self
.
_compile_and_check
(
self
.
_compile_and_check
(
[
n
],
[
n
],
[
n
[
n
[:,
0
]
>
2
,
n
[
0
,
:]
>
2
]],
[
n
[
n
[:,
0
]
>
2
,
n
[
0
,
:]
>
2
]],
[
n_val
],
[
n_val
],
AdvancedSubtensor
,
AdvancedSubtensor
,
check_topo
=
False
,
)
)
self
.
_compile_and_check
(
self
.
_compile_and_check
(
[
n
],
[
n
],
[
n
[
n
[:,
0
]
>
2
]],
[
n
[
n
[:,
0
]
>
2
]],
[
n_val
],
[
n_val
],
AdvancedSubtensor
,
AdvancedSubtensor
,
check_topo
=
False
,
)
self
.
_compile_and_check
(
[
n
],
[
n
[:,
np
.
array
([
True
,
False
,
True
])]],
[
n_val
],
AdvancedSubtensor
,
)
self
.
_compile_and_check
(
[
n
],
[
n
[
np
.
array
([
False
,
False
]),
1
:]],
[
n_val
],
AdvancedSubtensor
,
)
self
.
_compile_and_check
(
[
n
],
[
n
[
np
.
array
([
True
,
True
]),
0
]],
[
n_val
],
AdvancedSubtensor
,
)
self
.
_compile_and_check
(
[
n
],
[
n
[
self
.
random_bool_mask
(
n_val
.
shape
)]],
[
n_val
],
AdvancedSubtensor
,
)
self
.
_compile_and_check
(
[
n
],
[
n
[
None
,
self
.
random_bool_mask
(
n_val
.
shape
),
None
]],
[
n_val
],
AdvancedSubtensor
,
)
self
.
_compile_and_check
(
[
n
],
[
n
[
slice
(
5
,
None
),
self
.
random_bool_mask
(
n_val
.
shape
[
1
])]],
[
n_val
],
AdvancedSubtensor
,
)
)
abs_res
=
n
[
~
isinf
(
n
)]
abs_res
=
n
[
~
isinf
(
n
)]
assert
abs_res
.
type
.
shape
==
(
None
,)
assert
abs_res
.
type
.
shape
==
(
None
,)
def
test_AdvancedSubtensor_bool_mixed
(
self
):
n
=
tensor5
(
"x"
,
dtype
=
"float64"
)
shape
=
(
18
,
3
,
4
,
5
,
6
)
n_val
=
np
.
arange
(
np
.
prod
(
shape
))
.
reshape
(
shape
)
self
.
_compile_and_check
(
[
n
],
# Consecutive advanced index
[
n
[
1
:,
self
.
random_bool_mask
((
3
,
4
)),
0
,
1
:]],
[
n_val
],
AdvancedSubtensor
,
)
self
.
_compile_and_check
(
[
n
],
# Non-consecutive advanced index
[
n
[
1
:,
self
.
random_bool_mask
((
3
,
4
)),
1
:,
0
]],
[
n_val
],
AdvancedSubtensor
,
)
self
.
_compile_and_check
(
[
n
],
# Non-consecutive advanced index
[
n
[
1
:,
self
.
random_bool_mask
((
3
,)),
1
:,
None
,
np
.
zeros
((
6
,
1
),
dtype
=
int
)]],
[
n_val
],
AdvancedSubtensor
,
)
@config.change_flags
(
compute_test_value
=
"raise"
)
@config.change_flags
(
compute_test_value
=
"raise"
)
def
test_basic_shape
():
def
test_basic_shape
():
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论