提交 4e0b5023 authored 作者: Frederic's avatar Frederic

[CRASH] mean(list), fix gh-3527

上级 3ccdadfd
...@@ -2992,7 +2992,7 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, ...@@ -2992,7 +2992,7 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False,
For gpu, if you specify dtype=float32, everything will be done on the gpu. For gpu, if you specify dtype=float32, everything will be done on the gpu.
""" """
input = as_tensor_variable(input)
if op: if op:
if dtype not in (None, 'float64'): if dtype not in (None, 'float64'):
raise NotImplementedError( raise NotImplementedError(
......
...@@ -4464,6 +4464,10 @@ class T_mean(unittest.TestCase): ...@@ -4464,6 +4464,10 @@ class T_mean(unittest.TestCase):
data = rand(50) data = rand(50)
assert numpy.allclose(f(data), numpy.mean(data)) assert numpy.allclose(f(data), numpy.mean(data))
def test_list(self):
ll = [theano.shared(0.), theano.shared(2.)]
tensor.mean(ll).eval() == 1
class test_matinv(unittest.TestCase): class test_matinv(unittest.TestCase):
...@@ -6090,11 +6094,16 @@ def test_var(): ...@@ -6090,11 +6094,16 @@ def test_var():
assert numpy.allclose(numpy.var(a_val, axis=2), f(a_val)) assert numpy.allclose(numpy.var(a_val, axis=2), f(a_val))
def test_sum_overflow(): class T_sum(unittest.TestCase):
"""Ensure that overflow errors are a little bit harder to get""" def test_sum_overflow(self):
a = Tensor(dtype='int8', broadcastable=[False])() """Ensure that overflow errors are a little bit harder to get"""
f = function([a], sum(a)) a = Tensor(dtype='int8', broadcastable=[False])()
assert f([1] * 300) == 300 f = function([a], sum(a))
assert f([1] * 300) == 300
def test_list(self):
ll = [theano.shared(0.), theano.shared(2.)]
tensor.sum(ll).eval() == 2
@dec.skipif( @dec.skipif(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论