提交 c7504999 authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

identify and fix more reductions that don't support scalar array axes

上级 8c8ee26e
...@@ -1337,9 +1337,12 @@ class MaxAndArgmax(Op): ...@@ -1337,9 +1337,12 @@ class MaxAndArgmax(Op):
def make_node(self, x, axis=None): def make_node(self, x, axis=None):
x = _as_tensor_variable(x) x = _as_tensor_variable(x)
if isinstance(axis, int): if isinstance(axis, (int, numpy.integer)):
axis = [axis] axis = [axis]
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
axis = [int(axis)]
elif isinstance(axis, (tuple, list)): elif isinstance(axis, (tuple, list)):
axis = [int(a) for a in axis]
if len(axis) != 1: if len(axis) != 1:
axis = list(axis) axis = list(axis)
for idx in range(len(axis)): for idx in range(len(axis)):
...@@ -1497,8 +1500,12 @@ def makeKeepDims(x, y, axis): ...@@ -1497,8 +1500,12 @@ def makeKeepDims(x, y, axis):
if axis is None: if axis is None:
axis = range(x.type.ndim) axis = range(x.type.ndim)
elif isinstance(axis, int): elif isinstance(axis, (int, numpy.integer)):
axis = [axis] axis = [axis]
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
axis = [int(axis)]
else:
axis = [int(a) for a in axis]
newaxis = [] newaxis = []
for a in axis: for a in axis:
if not isinstance(a, int): if not isinstance(a, int):
...@@ -2761,8 +2768,12 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, ...@@ -2761,8 +2768,12 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False,
if axis is None: if axis is None:
axis = range(input.ndim) axis = range(input.ndim)
elif isinstance(axis, int): elif isinstance(axis, (int, numpy.integer)):
axis = [axis] axis = [axis]
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
axis = [int(axis)]
else:
axis = [int(a) for a in axis]
# This sequential division will possibly be optimized by Theano: # This sequential division will possibly be optimized by Theano:
for i in axis: for i in axis:
...@@ -2793,8 +2804,12 @@ def var(input, axis=None, keepdims=False): ...@@ -2793,8 +2804,12 @@ def var(input, axis=None, keepdims=False):
input_ndim = input.type.ndim input_ndim = input.type.ndim
if axis is None: if axis is None:
axis = range(input_ndim) axis = range(input_ndim)
if isinstance(axis, int): elif isinstance(axis, (int, numpy.integer)):
axis = [axis] axis = [axis]
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
axis = [int(axis)]
else:
axis = [int(a) for a in axis]
# compute the axis-wise mean # compute the axis-wise mean
mean_input = mean(input, axis, keepdims=True) mean_input = mean(input, axis, keepdims=True)
......
...@@ -1253,7 +1253,7 @@ class CAReduce(Op): ...@@ -1253,7 +1253,7 @@ class CAReduce(Op):
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0: elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
self.axis = (int(axis),) self.axis = (int(axis),)
else: else:
self.axis = list(set(axis)) self.axis = list(set(int(a) for a in axis))
self.axis.sort() self.axis.sort()
self.axis = tuple(self.axis) self.axis = tuple(self.axis)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论