Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
8d5a8c8c
提交
8d5a8c8c
authored
8月 12, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
8月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make better use of constants in broadcast_shape_iter
上级
cfc931fa
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
90 行增加
和
27 行删除
+90
-27
extra_ops.py
aesara/tensor/extra_ops.py
+60
-26
test_extra_ops.py
tests/tensor/test_extra_ops.py
+30
-1
没有找到文件。
aesara/tensor/extra_ops.py
浏览文件 @
8d5a8c8c
from
collections.abc
import
Collection
from
functools
import
reduce
from
typing
import
Iterable
,
Tuple
,
Union
from
typing
import
Iterable
,
Set
,
Tuple
,
Union
import
numpy
as
np
import
numpy.core.numeric
...
...
@@ -14,7 +14,7 @@ from aesara.gradient import (
disconnected_type
,
grad_undefined
,
)
from
aesara.graph.basic
import
Apply
,
Variable
,
equal_computations
from
aesara.graph.basic
import
Apply
,
Constant
,
Variable
,
equal_computations
from
aesara.graph.op
import
Op
from
aesara.link.c.op
import
COp
from
aesara.link.c.params_type
import
ParamsType
...
...
@@ -1491,7 +1491,12 @@ def broadcast_shape_iter(
array_shapes
=
[
(
one_at
,)
*
(
max_dims
-
len
(
a
))
+
tuple
(
one_at
if
getattr
(
sh
,
"value"
,
sh
)
==
1
else
sh
for
sh
in
a
)
+
tuple
(
one_at
if
getattr
(
sh
,
"value"
,
sh
)
==
1
else
(
aes
.
as_scalar
(
sh
)
if
not
isinstance
(
sh
,
Variable
)
else
sh
)
for
sh
in
a
)
for
a
in
arrays
]
else
:
...
...
@@ -1523,32 +1528,61 @@ def broadcast_shape_iter(
else
:
# More than one shape might not be broadcastable in this dimension
all_dims_equal
=
all
(
# TODO FIXME: This is a largely deficient means of comparing graphs
# (and especially shapes)
equal_computations
([
maybe_non_bcast_shapes
[
0
]],
[
dim
])
for
dim
in
maybe_non_bcast_shapes
[
1
:]
)
nonconst_nb_shapes
:
Set
[
int
]
=
set
()
const_nb_shapes
:
Set
[
Variable
]
=
set
()
for
shape
in
maybe_non_bcast_shapes
:
if
isinstance
(
shape
,
Constant
):
const_nb_shapes
.
add
(
shape
.
value
.
item
())
else
:
nonconst_nb_shapes
.
add
(
shape
)
if
all_dims_equal
:
result_dims
.
append
(
maybe_non_bcast_shapes
[
0
])
continue
if
len
(
const_nb_shapes
)
>
1
:
raise
ValueError
(
"Could not broadcast dimensions"
)
elif
len
(
const_nb_shapes
)
==
1
:
(
const_nb_shape
,)
=
const_nb_shapes
non_bcast_vec
=
[
aes
.
switch
(
aes
.
eq
(
nbv
,
1
),
-
one_at
,
nbv
)
for
nbv
in
maybe_non_bcast_shapes
]
dim_max
=
aes
.
abs
(
reduce
(
aes
.
scalar_maximum
,
non_bcast_vec
))
assert
const_nb_shape
!=
1
assert_dim
=
Assert
(
"Could not broadcast dimensions"
)
assert_cond
=
reduce
(
aes
.
and_
,
(
aes
.
or_
(
aes
.
eq
(
nbv
,
-
one_at
),
aes
.
eq
(
nbv
,
dim_max
))
for
nbv
in
non_bcast_vec
),
)
bcast_dim
=
assert_dim
(
dim_max
,
assert_cond
)
const_nt_shape_var
=
aesara
.
scalar
.
ScalarConstant
(
aesara
.
scalar
.
int64
,
const_nb_shape
)
if
len
(
nonconst_nb_shapes
)
>
0
:
assert_dim
=
Assert
(
"Could not broadcast dimensions"
)
assert_cond
=
reduce
(
aes
.
and_
,
(
aes
.
eq
(
nbv
,
const_nt_shape_var
)
for
nbv
in
nonconst_nb_shapes
),
)
bcast_dim
=
assert_dim
(
const_nt_shape_var
,
assert_cond
)
else
:
bcast_dim
=
const_nt_shape_var
else
:
all_dims_equal
=
all
(
# TODO FIXME: This is a largely deficient, and expensive, means
# of comparing graphs (and especially shapes)
equal_computations
([
maybe_non_bcast_shapes
[
0
]],
[
dim
])
for
dim
in
maybe_non_bcast_shapes
[
1
:]
)
if
all_dims_equal
:
result_dims
.
append
(
maybe_non_bcast_shapes
[
0
])
continue
non_bcast_vec
=
[
aes
.
switch
(
aes
.
eq
(
nbv
,
1
),
-
one_at
,
nbv
)
for
nbv
in
maybe_non_bcast_shapes
]
dim_max
=
aes
.
abs
(
reduce
(
aes
.
scalar_maximum
,
non_bcast_vec
))
assert_dim
=
Assert
(
"Could not broadcast dimensions"
)
assert_cond
=
reduce
(
aes
.
and_
,
(
aes
.
or_
(
aes
.
eq
(
nbv
,
-
one_at
),
aes
.
eq
(
nbv
,
dim_max
))
for
nbv
in
non_bcast_vec
),
)
bcast_dim
=
assert_dim
(
dim_max
,
assert_cond
)
result_dims
.
append
(
bcast_dim
)
...
...
tests/tensor/test_extra_ops.py
浏览文件 @
8d5a8c8c
...
...
@@ -8,7 +8,7 @@ from aesara import function
from
aesara
import
tensor
as
at
from
aesara.compile.mode
import
Mode
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
applys_between
from
aesara.graph.basic
import
Constant
,
applys_between
from
aesara.graph.optdb
import
OptimizationQuery
from
aesara.raise_op
import
Assert
from
aesara.tensor.elemwise
import
DimShuffle
...
...
@@ -1143,6 +1143,35 @@ def test_broadcast_shape_basic():
assert
isinstance
(
b_at
[
-
1
]
.
owner
.
op
,
Assert
)
def
test_broadcast_shape_constants
():
"""Make sure `broadcast_shape` uses constants when it can."""
x1_shp_at
=
iscalar
(
"x1"
)
y2_shp_at
=
iscalar
(
"y2"
)
b_at
=
broadcast_shape
((
x1_shp_at
,
2
),
(
3
,
y2_shp_at
),
arrays_are_shapes
=
True
)
assert
len
(
b_at
)
==
2
assert
isinstance
(
b_at
[
0
]
.
owner
.
op
,
Assert
)
assert
b_at
[
0
]
.
owner
.
inputs
[
0
]
.
value
.
item
()
==
3
assert
isinstance
(
b_at
[
1
]
.
owner
.
op
,
Assert
)
assert
b_at
[
1
]
.
owner
.
inputs
[
0
]
.
value
.
item
()
==
2
b_at
=
broadcast_shape
((
1
,
2
),
(
3
,
2
),
arrays_are_shapes
=
True
)
assert
len
(
b_at
)
==
2
assert
all
(
isinstance
(
x
,
Constant
)
for
x
in
b_at
)
assert
b_at
[
0
]
.
value
.
item
()
==
3
assert
b_at
[
1
]
.
value
.
item
()
==
2
b_at
=
broadcast_shape
((
1
,),
(
1
,
1
),
arrays_are_shapes
=
True
)
assert
len
(
b_at
)
==
2
assert
all
(
isinstance
(
x
,
Constant
)
for
x
in
b_at
)
assert
b_at
[
0
]
.
value
.
item
()
==
1
assert
b_at
[
1
]
.
value
.
item
()
==
1
b_at
=
broadcast_shape
((
1
,),
(
1
,),
arrays_are_shapes
=
True
)
assert
len
(
b_at
)
==
1
assert
all
(
isinstance
(
x
,
Constant
)
for
x
in
b_at
)
assert
b_at
[
0
]
.
value
.
item
()
==
1
@pytest.mark.parametrize
(
(
"s1_vals"
,
"s2_vals"
,
"exp_res"
),
[
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论