提交 4b8320c7 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed bug where T.mean(ndarray) crashes

上级 1a6a0dbd
......@@ -2624,7 +2624,7 @@ def mean(input, axis = None, op = False):
axis = [axis]
for i in axis:
s = s / shp[i]
if input.dtype.startswith('float'):
if str(input.dtype).startswith('float'):
assert input.dtype == s.dtype
return s
......
......@@ -1021,6 +1021,8 @@ class CAReduce(Op):
return input_dtype
def make_node(self, input):
input = as_tensor_variable(input)
if self.axis is not None:
for axis in self.axis:
if axis >= input.type.ndim or (axis<0 and abs(axis)>input.type.ndim):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论