提交 2015c5a7 authored 作者: Tanjay94's avatar Tanjay94 提交者: Frederic

Added test to choose.

上级 59a99da3
...@@ -5124,19 +5124,3 @@ class Choose(Op): ...@@ -5124,19 +5124,3 @@ class Choose(Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
...@@ -46,7 +46,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -46,7 +46,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
itensor3, Tile, switch, Diagonal, Diag, itensor3, Tile, switch, Diagonal, Diag,
nonzero, flatnonzero, nonzero_values, nonzero, flatnonzero, nonzero_values,
stacklists, DimShuffle, hessian, ptp, power, stacklists, DimShuffle, hessian, ptp, power,
swapaxes swapaxes, choose
) )
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -6984,6 +6984,7 @@ class T_swapaxes(unittest.TestCase): ...@@ -6984,6 +6984,7 @@ class T_swapaxes(unittest.TestCase):
t_s = fn(a) t_s = fn(a)
assert numpy.allclose(n_s, t_s) assert numpy.allclose(n_s, t_s)
class T_Power(): class T_Power():
def test_numpy_compare(self): def test_numpy_compare(self):
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
...@@ -7010,16 +7011,22 @@ class T_Power(): ...@@ -7010,16 +7011,22 @@ class T_Power():
f = function([x], z) f = function([x], z)
self.assertRaise(ValueError, f, [1, 2, 3, 4]) self.assertRaise(ValueError, f, [1, 2, 3, 4])
class T_Choose():
def test_numpy_compare(self): def test_numpy_compare(self):
rng = numpy.random.RandomState(utt.fetch_seed()) a = tensor.vector(dtype='int64')
A = tensor.matrix("A", dtype=theano.config.floatX) b = tensor.matrix(dtype='int64')
Q = power(A, 2)
fn = function([A], [Q])
a = rng.rand(4, 4).astype(theano.config.floatX)
n_p = numpy.power(a, 2) A = numpy.random.rand(5)
t_p = fn(a) B = numpy.random.rand(5, 6)
assert numpy.allclose(n_s, t_s)
modes = ['raise', 'wrap', 'clip']
for m in modes:
f = function([a, b], choose(a, b, m))
t_c = f(A, B)
n_c = numpy.choose(A, B, m)
assert numpy.allclose(t_p, n_p)
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论