提交 c855a6d8 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Avoid creating useless squeezes and expand_dims

上级 f8c0c4df
......@@ -603,6 +603,10 @@ def squeeze(x, axis=None):
except np.AxisError:
raise np.AxisError(axis, ndim=_x.ndim)
if not axis:
# Nothing to do
return _x
return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis])
......
......@@ -868,7 +868,8 @@ def shape_padleft(t, n_ones=1):
"""
_t = at.as_tensor_variable(t)
if n_ones == 0:
return _t
pattern = ["x"] * n_ones + list(range(_t.type.ndim))
return _t.dimshuffle(pattern)
......@@ -884,7 +885,8 @@ def shape_padright(t, n_ones=1):
"""
_t = at.as_tensor_variable(t)
if n_ones == 0:
return _t
pattern = list(range(_t.type.ndim)) + ["x"] * n_ones
return _t.dimshuffle(pattern)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论