提交 44fa4acc authored 作者: Frederic's avatar Frederic

Fix a test(bad syntax and fix the following problem). Add basic mean test.

上级 26b235cf
...@@ -2622,7 +2622,7 @@ def mean(input, axis = None, op = False): ...@@ -2622,7 +2622,7 @@ def mean(input, axis = None, op = False):
if input.dtype == 'float32': if input.dtype == 'float32':
shp = cast(shp, 'float32') shp = cast(shp, 'float32')
if axis is None: if axis is None:
axis = range(input.type.ndim) axis = range(input.ndim)
elif isinstance(axis, int): elif isinstance(axis, int):
axis = [axis] axis = [axis]
for i in axis: for i in axis:
......
...@@ -2977,10 +2977,17 @@ class T_divimpl(unittest.TestCase): ...@@ -2977,10 +2977,17 @@ class T_divimpl(unittest.TestCase):
class T_mean(unittest.TestCase): class T_mean(unittest.TestCase):
def test_regression_mean_of_ndarray_failure(self): def test_regression_mean_of_ndarray_failure(self):
try: try:
T.mean(numpy.zeros(1)) theano.tensor.mean(numpy.zeros(1))
except AttributeError: except AttributeError:
self.fail() self.fail()
def test0(self):
#Simple test...
x = theano.tensor.vector()
f = theano.function([x],theano.tensor.mean(x))
data = numpy.asarray(numpy.random.rand(50), dtype=config.floatX)
assert f(data) == numpy.mean(data)
# class T_abs(unittest.TestCase): # class T_abs(unittest.TestCase):
# def test_impl(self): # def test_impl(self):
# t = as_tensor_variable(1.0) # t = as_tensor_variable(1.0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论