提交 72d9e872 authored 作者: Frederic's avatar Frederic

crash fix for the Mean Op (not used by default) with the axis was a scalar.

上级 0ef4d76e
...@@ -3207,6 +3207,7 @@ def prod(input, axis=None, dtype=None, keepdims=False): ...@@ -3207,6 +3207,7 @@ def prod(input, axis=None, dtype=None, keepdims=False):
class Mean(elemwise.CAReduce): class Mean(elemwise.CAReduce):
def __init__(self, axis=None): def __init__(self, axis=None):
elemwise.CAReduce.__init__(self, scal.add, axis) elemwise.CAReduce.__init__(self, scal.add, axis)
assert self.axis is None or len(self.axis) == 1
def __str__(self): def __str__(self):
if self.axis is not None: if self.axis is not None:
...@@ -3221,7 +3222,7 @@ class Mean(elemwise.CAReduce): ...@@ -3221,7 +3222,7 @@ class Mean(elemwise.CAReduce):
def perform(self, node, inp, out): def perform(self, node, inp, out):
input, = inp input, = inp
output, = out output, = out
output[0] = numpy.mean(input, axis=self.axis) output[0] = numpy.mean(input, axis=self.axis[0])
def c_code(self, node, name, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
if self.axis is not None: if self.axis is not None:
......
...@@ -6422,14 +6422,14 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6422,14 +6422,14 @@ class TestInferShape(utt.InferShapeTester):
# #
# to be resolved as desired/appropriate # to be resolved as desired/appropriate
"""
adtens3_val = rand(3, 4, 5) adtens3_val = rand(3, 4, 5)
aiscal_val = 2 aiscal_val = 2
self._compile_and_check([adtens3],
[Mean(None)(adtens3)],
[adtens3_val], Mean)
self._compile_and_check([adtens3], self._compile_and_check([adtens3],
[Mean(aiscal_val)(adtens3)], [Mean(aiscal_val)(adtens3)],
[adtens3_val], Mean) [adtens3_val], Mean)
"""
# IncSubtensor # IncSubtensor
# Note: Is testing only for the 4-tensor below sufficient? # Note: Is testing only for the 4-tensor below sufficient?
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论