提交 e200cb5c authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Rewrite local_merge_alloc to remove a non-strict zip

上级 4a3d8c7d
...@@ -1207,25 +1207,23 @@ def local_merge_alloc(fgraph, node): ...@@ -1207,25 +1207,23 @@ def local_merge_alloc(fgraph, node):
inputs_inner = node.inputs[0].owner.inputs inputs_inner = node.inputs[0].owner.inputs
dims_outer = inputs_outer[1:] dims_outer = inputs_outer[1:]
dims_inner = inputs_inner[1:] dims_inner = inputs_inner[1:]
dims_outer_rev = dims_outer[::-1] assert len(dims_inner) <= len(dims_outer)
dims_inner_rev = dims_inner[::-1]
# check if the pattern of broadcasting is matched, in the reversed ordering. # check if the pattern of broadcasting is matched, in the reversed ordering.
# The reverse ordering is needed when an Alloc add an implicit new # The reverse ordering is needed when an Alloc add an implicit new
# broadcasted dimensions to its inputs[0]. Eg: # broadcasted dimensions to its inputs[0]. Eg:
# Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
i = 0 for i, dim_inner in enumerate(reversed(dims_inner)):
for dim_inner, dim_outer in zip(dims_inner_rev, dims_outer_rev, strict=False): dim_outer = dims_outer[-1 - i]
if dim_inner != dim_outer: if dim_inner == dim_outer:
continue
if isinstance(dim_inner, Constant) and dim_inner.data == 1: if isinstance(dim_inner, Constant) and dim_inner.data == 1:
pass continue
else:
dims_outer[-1 - i] = Assert( dims_outer[-1 - i] = Assert(
"You have a shape error in your graph. To see a better" "You have a shape error in your graph. To see a better"
" error message and a stack trace of where in your code" " error message and a stack trace of where in your code"
" the error is created, use the PyTensor flags" " the error is created, use the PyTensor flags"
" optimizer=None or optimizer=fast_compile." " optimizer=None or optimizer=fast_compile."
)(dim_outer, eq(dim_outer, dim_inner)) )(dim_outer, eq(dim_outer, dim_inner))
i += 1
return [alloc(inputs_inner[0], *dims_outer)] return [alloc(inputs_inner[0], *dims_outer)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论