Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
6b6ed43a
提交
6b6ed43a
authored
12月 30, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
1月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor aesara.tensor.basic_opt.local_fill_to_alloc
上级
3bd247e4
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
61 行增加
和
48 行删除
+61
-48
basic_opt.py
aesara/tensor/basic_opt.py
+32
-40
test_basic_opt.py
tests/tensor/test_basic_opt.py
+29
-8
没有找到文件。
aesara/tensor/basic_opt.py
浏览文件 @
6b6ed43a
...
@@ -1652,50 +1652,42 @@ def local_fill_sink(fgraph, node):
...
@@ -1652,50 +1652,42 @@ def local_fill_sink(fgraph, node):
@register_specialize
@register_specialize
@register_stabilize
@register_stabilize
# @register_canonicalize # We make full pass after the canonizer phase.
@local_optimizer
([
fill
])
@local_optimizer
([
fill
])
def
local_fill_to_alloc
(
fgraph
,
node
):
def
local_fill_to_alloc
(
fgraph
,
node
):
"""fill(s,v) -> alloc(v, shape(s))
r"""Remove `fill`\s or replace them with `Alloc`\s.
This is an important optimization because with the shape_to_shape_i
`Alloc`\s are preferable because they replace explicit tensor dependencies
optimization, the dependency on 's' is often removed.
with their dependencies on those tensors' shapes, and sometimes those
shapes can be computed without needing to compute the tensors themselves.
"""
if
node
.
op
==
fill
:
XXX: This rewrite can produce inconsistent results, so do *not* consider
r
,
v
=
node
.
inputs
making it a canonicalization until those inconsistencies are
if
v
.
type
==
node
.
outputs
[
0
]
.
type
:
resolved/justified.
# this is a useless fill, erase it.
"""
rval
=
[
v
]
shape_ref
,
values_ref
=
node
.
inputs
elif
v
.
type
.
broadcastable
==
node
.
outputs
[
0
]
.
type
.
broadcastable
:
out_type
=
node
.
outputs
[
0
]
.
type
# this is a cast
rval
=
[
cast
(
v
,
node
.
outputs
[
0
]
.
type
.
dtype
)]
if
values_ref
.
type
.
broadcastable
==
out_type
.
broadcastable
:
elif
r
.
type
.
broadcastable
==
node
.
outputs
[
0
]
.
type
.
broadcastable
:
# The assumption here is that `values_ref` already has the same shape
# we are broadcasting v somehow, but not r
# as `shape_ref`, so a `fill`/`Alloc` is unnecessary.
o
=
broadcast_like
(
v
,
r
,
fgraph
,
dtype
=
v
.
dtype
)
# XXX FIXME TODO: The only way this can be determined is if one
# absolutely knows that the shapes of `shape_ref` and `values_ref` are
# equal.
# This is an old rewrite, and it's only a
# "specialization/stabilization", so we're going to leave it be for
# now.
return
[
values_ref
]
if
shape_ref
.
type
.
broadcastable
==
out_type
.
broadcastable
:
# In this case, we assume that some broadcasting is needed (otherwise
# the condition above would've been true), so we replace the `fill`
# with an `Alloc`.
o
=
broadcast_like
(
values_ref
,
shape_ref
,
fgraph
,
dtype
=
values_ref
.
dtype
)
copy_stack_trace
(
node
.
outputs
[
0
],
o
)
copy_stack_trace
(
node
.
outputs
[
0
],
o
)
rval
=
[
o
]
return
[
o
]
else
:
# we are broadcasting both v and r,
# the output shape must be computed
#
# TODO: implement this case (including a test!)
#
# I think the strategy should be to extend the shorter
# shape vector with 1s (how?) and then take the
# elementwise max of the two. - how to flag an error of
# shape mismatch where broadcasting should be illegal?
return
return
# TODO: cut out un-necessary dimshuffles of v
assert
rval
[
0
]
.
type
==
node
.
outputs
[
0
]
.
type
,
(
"rval"
,
rval
[
0
]
.
type
,
"orig"
,
node
.
outputs
[
0
]
.
type
,
"node"
,
node
,
)
# aesara.printing.debugprint(node.outputs[0], file='str'))
return
rval
# Register this after stabilize at 1.5 to make sure stabilize don't
# Register this after stabilize at 1.5 to make sure stabilize don't
...
...
tests/tensor/test_basic_opt.py
浏览文件 @
6b6ed43a
...
@@ -1292,12 +1292,10 @@ def test_local_fill_useless():
...
@@ -1292,12 +1292,10 @@ def test_local_fill_useless():
x
=
dvector
()
x
=
dvector
()
y
=
dvector
()
y
=
dvector
()
z
=
lvector
()
z
=
lvector
()
m
=
dmatrix
()
x_
=
np
.
random
.
random
((
5
,))
x_
=
np
.
random
.
random
((
5
,))
y_
=
np
.
random
.
random
((
5
,))
y_
=
np
.
random
.
random
((
5
,))
z_
=
(
np
.
random
.
random
((
5
,))
*
5
)
.
astype
(
"int64"
)
z_
=
(
np
.
random
.
random
((
5
,))
*
5
)
.
astype
(
"int64"
)
m_
=
np
.
random
.
random
((
5
,
5
))
# basic case
# basic case
f
=
function
([
x
],
at
.
fill
(
x
,
x
)
*
2
,
mode
=
mode_opt
)
f
=
function
([
x
],
at
.
fill
(
x
,
x
)
*
2
,
mode
=
mode_opt
)
...
@@ -1329,12 +1327,35 @@ def test_local_fill_useless():
...
@@ -1329,12 +1327,35 @@ def test_local_fill_useless():
assert
[
node
.
op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
==
[
mul
]
assert
[
node
.
op
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
==
[
mul
]
f
(
x_
,
y_
)
f
(
x_
,
y_
)
# Test with different number of dimensions
# The fill is not useless, so it should stay
def
test_local_fill_to_alloc
():
f
=
function
([
m
,
x
],
at
.
fill
(
m
,
x
)
*
2
,
mode
=
mode_opt
)
x
=
dvector
()
ops
=
[
node
.
op
.
__class__
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
m
=
dmatrix
()
assert
Alloc
in
ops
f
(
m_
,
x_
)
x_
=
np
.
random
.
random
((
5
,))
m_
=
np
.
random
.
random
((
5
,
5
))
y
=
at
.
fill
(
m
,
x
)
mode
=
mode_opt
.
including
(
"stabilize"
,
"local_fill_to_alloc"
)
.
excluding
(
"useless"
,
"local_useless_fill"
)
f
=
function
([
m
,
x
],
y
,
mode
=
mode
)
assert
Alloc
in
[
node
.
op
.
__class__
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
res
=
f
(
m_
,
x_
)
exp_res
=
np
.
broadcast_to
(
x_
,
m_
.
shape
)
assert
np
.
array_equal
(
res
,
exp_res
)
y
=
at
.
fill
(
x
,
m
)
f
=
function
([
m
,
x
],
y
,
mode
=
mode
)
assert
Alloc
not
in
[
node
.
op
.
__class__
for
node
in
f
.
maker
.
fgraph
.
toposort
()]
res
=
f
(
m_
,
x_
)
assert
np
.
array_equal
(
res
,
m_
)
class
TestLocalCanonicalizeAlloc
:
class
TestLocalCanonicalizeAlloc
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论