提交 e53c4879 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add axis argument to squeeze

上级 91e20cc4
......@@ -358,6 +358,22 @@ class TestSqueeze(utt.InferShapeTester):
assert tested.shape == expected.shape
assert np.allclose(tested, expected)
def test_axis(self):
variable = tensor.TensorType(theano.config.floatX, [False, True, False])()
res = squeeze(variable, axis=1)
assert res.broadcastable == (False, False)
variable = tensor.TensorType(theano.config.floatX, [False, True, False])()
res = squeeze(variable, axis=(1,))
assert res.broadcastable == (False, False)
variable = tensor.TensorType(theano.config.floatX, [False, True, False, True])()
res = squeeze(variable, axis=(1, 3))
assert res.broadcastable == (False, False)
class TestCompress(utt.InferShapeTester):
axis_list = [None, -1, 0, 0, 0, 1]
......
import numpy as np
import theano
from collections.abc import Collection
from theano.tensor import basic
from theano.tensor import nlinalg # noqa
from theano import gof, scalar
......@@ -575,7 +577,7 @@ def bincount(x, weights=None, minlength=None, assert_nonneg=False):
return out
def squeeze(x):
def squeeze(x, axis=None):
"""
Remove broadcastable dimensions from the shape of an array.
......@@ -590,13 +592,26 @@ def squeeze(x):
x
Input data, tensor variable.
axis : None or int or tuple of ints, optional
Selects a subset of the single-dimensional entries in the
shape. If an axis is selected with shape entry greater than
one, an error is raised.
Returns
-------
object
`x` without its broadcastable dimensions.
"""
view = x.dimshuffle([i for i in range(x.ndim) if not x.broadcastable[i]])
if axis is None:
axis = range(x.ndim)
elif not isinstance(axis, Collection):
axis = (axis,)
view = x.dimshuffle(
[i for i in range(x.ndim) if not x.broadcastable[i] or i not in axis]
)
return view
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论