提交 4b8320c7 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed bug where T.mean(ndarray) crashes

上级 1a6a0dbd
...@@ -2624,7 +2624,7 @@ def mean(input, axis = None, op = False): ...@@ -2624,7 +2624,7 @@ def mean(input, axis = None, op = False):
axis = [axis] axis = [axis]
for i in axis: for i in axis:
s = s / shp[i] s = s / shp[i]
if input.dtype.startswith('float'): if str(input.dtype).startswith('float'):
assert input.dtype == s.dtype assert input.dtype == s.dtype
return s return s
......
...@@ -1021,6 +1021,8 @@ class CAReduce(Op): ...@@ -1021,6 +1021,8 @@ class CAReduce(Op):
return input_dtype return input_dtype
def make_node(self, input): def make_node(self, input):
input = as_tensor_variable(input)
if self.axis is not None: if self.axis is not None:
for axis in self.axis: for axis in self.axis:
if axis >= input.type.ndim or (axis<0 and abs(axis)>input.type.ndim): if axis >= input.type.ndim or (axis<0 and abs(axis)>input.type.ndim):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论