提交 5b5d1d9f authored 作者: David Warde-Farley's avatar David Warde-Farley

Simplify roll graph in case of 1-d input.

In the case of a default axis argument and a 1-d tensor, don't create flatten and reshape nodes that need to be optimized out when it's simple enough not to do so.
上级 5b3b5418
......@@ -4313,8 +4313,11 @@ def roll(x, shift, axis=None):
Output tensor, with the same shape as `x`.
"""
if axis is None:
if x.ndim > 1:
y = x.flatten()
return roll(y, shift, axis=0).reshape(x.shape)
else:
axis = 0
# A slice of all elements in a dimension ':'
allslice = slice(None)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论