提交 74c89acd authored 作者: Tanjay94's avatar Tanjay94 提交者: Frederic

Added infer_shape test for choose function.

上级 89c8b664
...@@ -46,7 +46,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -46,7 +46,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
itensor3, Tile, switch, Diagonal, Diag, itensor3, Tile, switch, Diagonal, Diag,
nonzero, flatnonzero, nonzero_values, nonzero, flatnonzero, nonzero_values,
stacklists, DimShuffle, hessian, ptp, power, stacklists, DimShuffle, hessian, ptp, power,
swapaxes, choose swapaxes, choose, Choose
) )
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -7012,7 +7012,10 @@ class T_Power(): ...@@ -7012,7 +7012,10 @@ class T_Power():
self.assertRaise(ValueError, f, [1, 2, 3, 4]) self.assertRaise(ValueError, f, [1, 2, 3, 4])
class T_Choose(): class T_Choose(utt.InferShapeTester):
op = staticmethod(choose)
op_class = Choose
def test_numpy_compare(self): def test_numpy_compare(self):
a = tensor.vector(dtype='int64') a = tensor.vector(dtype='int64')
...@@ -7036,25 +7039,24 @@ class T_Choose(): ...@@ -7036,25 +7039,24 @@ class T_Choose():
c = tensor.matrix(dtype='int64') c = tensor.matrix(dtype='int64')
d = tensor.vector(dtype='int64') d = tensor.vector(dtype='int64')
A = numpy.asarray(numpy.random.rand(4, 4), dtype='int64') A = numpy.asarray(numpy.random.rand(5, 4), dtype='int64')
B = numpy.asarray(numpy.random.rand(4), dtype='int64') B = numpy.asarray(numpy.random.rand(4), dtype='int64')
C = numpy.asarray(numpy.random.rand(4, 4), dtype='int64') C = numpy.asarray(numpy.random.rand(7, 4), dtype='int64')
D = numpy.asarray(numpy.random.rand(4), dtype='int64') D = numpy.asarray(numpy.random.rand(4), dtype='int64')
fa = function([a, c], choose(a, c)) var1 = [a, b, a, b]
fb = function([b, d], choose(b, d)) var2 = [c, d, b, a]
fc = function([a, b], choose(a, b)) mat1 = [A, B, A, B]
fd = function([b, a], choose(b, a)) mat2 = [C, D, B, A]
t_ca = fa(A, C) for v, m, w, n in zip(var1, mat1, var2, mat2):
t_cb = fb(B, D) self._compile_and_check([v, w], # theano.function inputs
t_cc = fc(A, B) [self.op(v, w)], # theano.function outputs
t_cd = fd(B, A) # Always use not square matrix!
# inputs data
assert numpy.allclose(A.shape, t_ca.shape) [m, n],
assert numpy.allclose(B.shape, t_cb.shape) # Op that should be removed from the graph.
assert numpy.allclose(A.shape, t_cc.shape) self.op_class)
assert numpy.allclose(B.shape, t_cd.shape)
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论