提交 88157f68 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #1539 from jlowin/patch-1

Add case where 'axis' is a numpy scalar array
...@@ -1337,9 +1337,12 @@ class MaxAndArgmax(Op): ...@@ -1337,9 +1337,12 @@ class MaxAndArgmax(Op):
def make_node(self, x, axis=None): def make_node(self, x, axis=None):
x = _as_tensor_variable(x) x = _as_tensor_variable(x)
if isinstance(axis, int): if isinstance(axis, (int, numpy.integer)):
axis = [axis] axis = [axis]
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
axis = [int(axis)]
elif isinstance(axis, (tuple, list)): elif isinstance(axis, (tuple, list)):
axis = [int(a) for a in axis]
if len(axis) != 1: if len(axis) != 1:
axis = list(axis) axis = list(axis)
for idx in range(len(axis)): for idx in range(len(axis)):
...@@ -1497,8 +1500,12 @@ def makeKeepDims(x, y, axis): ...@@ -1497,8 +1500,12 @@ def makeKeepDims(x, y, axis):
if axis is None: if axis is None:
axis = range(x.type.ndim) axis = range(x.type.ndim)
elif isinstance(axis, int): elif isinstance(axis, (int, numpy.integer)):
axis = [axis] axis = [axis]
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
axis = [int(axis)]
else:
axis = [int(a) for a in axis]
newaxis = [] newaxis = []
for a in axis: for a in axis:
if not isinstance(a, int): if not isinstance(a, int):
...@@ -2761,8 +2768,12 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, ...@@ -2761,8 +2768,12 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False,
if axis is None: if axis is None:
axis = range(input.ndim) axis = range(input.ndim)
elif isinstance(axis, int): elif isinstance(axis, (int, numpy.integer)):
axis = [axis] axis = [axis]
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
axis = [int(axis)]
else:
axis = [int(a) for a in axis]
# This sequential division will possibly be optimized by Theano: # This sequential division will possibly be optimized by Theano:
for i in axis: for i in axis:
...@@ -2793,8 +2804,12 @@ def var(input, axis=None, keepdims=False): ...@@ -2793,8 +2804,12 @@ def var(input, axis=None, keepdims=False):
input_ndim = input.type.ndim input_ndim = input.type.ndim
if axis is None: if axis is None:
axis = range(input_ndim) axis = range(input_ndim)
if isinstance(axis, int): elif isinstance(axis, (int, numpy.integer)):
axis = [axis] axis = [axis]
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
axis = [int(axis)]
else:
axis = [int(a) for a in axis]
# compute the axis-wise mean # compute the axis-wise mean
mean_input = mean(input, axis, keepdims=True) mean_input = mean(input, axis, keepdims=True)
......
...@@ -1250,8 +1250,10 @@ class CAReduce(Op): ...@@ -1250,8 +1250,10 @@ class CAReduce(Op):
# See <http://projects.scipy.org/numpy/ticket/2235>. # See <http://projects.scipy.org/numpy/ticket/2235>.
elif isinstance(axis, (int, numpy.integer)): elif isinstance(axis, (int, numpy.integer)):
self.axis = (axis,) self.axis = (axis,)
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
self.axis = (int(axis),)
else: else:
self.axis = list(set(axis)) self.axis = list(set(int(a) for a in axis))
self.axis.sort() self.axis.sort()
self.axis = tuple(self.axis) self.axis = tuple(self.axis)
......
...@@ -98,6 +98,43 @@ class test_DimShuffle(unittest_tools.InferShapeTester): ...@@ -98,6 +98,43 @@ class test_DimShuffle(unittest_tools.InferShapeTester):
y = x.dimshuffle(('x',) * (numpy.MAXDIMS + 1)) y = x.dimshuffle(('x',) * (numpy.MAXDIMS + 1))
self.assertRaises(ValueError, y.eval, {x: 0}) self.assertRaises(ValueError, y.eval, {x: 0})
class test_reduce_axes(unittest.TestCase):
def test_sum_axes(self):
axes = [None, 0, 1, [0, 1], numpy.array(1), [numpy.array(0), numpy.array(1)]]
for a in axes:
x = tensor.matrix()
m = x.sum(a)
def test_mean_axes(self):
axes = [None, 0, 1, [0, 1], numpy.array(1), [numpy.array(0), numpy.array(1)]]
for a in axes:
x = tensor.matrix()
m = x.mean(a)
def test_max_axes(self):
axes = [None, 0, 1, [0, 1], numpy.array(1), [numpy.array(0), numpy.array(1)]]
for a in axes:
x = tensor.matrix()
m = x.max(a)
def test_min_axes(self):
axes = [None, 0, 1, [0, 1], numpy.array(1), [numpy.array(0), numpy.array(1)]]
for a in axes:
x = tensor.matrix()
m = x.min(a)
def test_argmax_axes(self):
axes = [None, 0, 1, [0, 1], numpy.array(1), [numpy.array(0), numpy.array(1)]]
for a in axes:
x = tensor.matrix()
m = x.argmax(a)
def test_var_axes(self):
axes = [None, 0, 1, [0, 1], numpy.array(1), [numpy.array(0), numpy.array(1)]]
for a in axes:
x = tensor.matrix()
m = x.var(a)
class test_Broadcast(unittest.TestCase): class test_Broadcast(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论