提交 74c89acd authored 作者: Tanjay94's avatar Tanjay94 提交者: Frederic

Added infer_shape test for choose function.

上级 89c8b664
......@@ -46,7 +46,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
itensor3, Tile, switch, Diagonal, Diag,
nonzero, flatnonzero, nonzero_values,
stacklists, DimShuffle, hessian, ptp, power,
swapaxes, choose
swapaxes, choose, Choose
)
from theano.tests import unittest_tools as utt
......@@ -7012,7 +7012,10 @@ class T_Power():
self.assertRaise(ValueError, f, [1, 2, 3, 4])
class T_Choose():
class T_Choose(utt.InferShapeTester):
op = staticmethod(choose)
op_class = Choose
def test_numpy_compare(self):
a = tensor.vector(dtype='int64')
......@@ -7036,25 +7039,24 @@ class T_Choose():
c = tensor.matrix(dtype='int64')
d = tensor.vector(dtype='int64')
A = numpy.asarray(numpy.random.rand(4, 4), dtype='int64')
A = numpy.asarray(numpy.random.rand(5, 4), dtype='int64')
B = numpy.asarray(numpy.random.rand(4), dtype='int64')
C = numpy.asarray(numpy.random.rand(4, 4), dtype='int64')
C = numpy.asarray(numpy.random.rand(7, 4), dtype='int64')
D = numpy.asarray(numpy.random.rand(4), dtype='int64')
fa = function([a, c], choose(a, c))
fb = function([b, d], choose(b, d))
fc = function([a, b], choose(a, b))
fd = function([b, a], choose(b, a))
t_ca = fa(A, C)
t_cb = fb(B, D)
t_cc = fc(A, B)
t_cd = fd(B, A)
assert numpy.allclose(A.shape, t_ca.shape)
assert numpy.allclose(B.shape, t_cb.shape)
assert numpy.allclose(A.shape, t_cc.shape)
assert numpy.allclose(B.shape, t_cd.shape)
var1 = [a, b, a, b]
var2 = [c, d, b, a]
mat1 = [A, B, A, B]
mat2 = [C, D, B, A]
for v, m, w, n in zip(var1, mat1, var2, mat2):
self._compile_and_check([v, w], # theano.function inputs
[self.op(v, w)], # theano.function outputs
# Always use not square matrix!
# inputs data
[m, n],
# Op that should be removed from the graph.
self.op_class)
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论