提交 89c8b664 authored 作者: Tanjay94's avatar Tanjay94 提交者: Frederic

Added infer_shape test for choose function.

上级 4e9d23c5
...@@ -11,7 +11,7 @@ from itertools import izip ...@@ -11,7 +11,7 @@ from itertools import izip
# Import builtin min to be able to use it after importing the tensor version. # Import builtin min to be able to use it after importing the tensor version.
import __builtin__ import __builtin__
builtin_min = __builtin__.min builtin_min = __builtin__.min
from nose.tools import assert_raises
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr from nose.plugins.attrib import attr
import numpy import numpy
...@@ -7029,16 +7029,32 @@ class T_Choose(): ...@@ -7029,16 +7029,32 @@ class T_Choose():
n_c = numpy.choose(A, B, mode=m) n_c = numpy.choose(A, B, mode=m)
assert numpy.allclose(t_c, n_c) assert numpy.allclose(t_c, n_c)
def wrong_choice_array(self): def test_infer_shape(self):
a = tensor.matrix(dtype='int64') a = tensor.matrix(dtype='int64')
b = tensor.vector(dtype='int64') b = tensor.vector(dtype='int64')
c = tensor.matrix(dtype='int64')
A = numpy.asarray(numpy.random.rand(4), dtype='int64') d = tensor.vector(dtype='int64')
B = numpy.asarray(numpy.random.rand(4, 4), dtype='int64')
A = numpy.asarray(numpy.random.rand(4, 4), dtype='int64')
f = function([a, b], choose(a, b)) B = numpy.asarray(numpy.random.rand(4), dtype='int64')
self.assertRaise(ValueError, f, A, B) C = numpy.asarray(numpy.random.rand(4, 4), dtype='int64')
D = numpy.asarray(numpy.random.rand(4), dtype='int64')
fa = function([a, c], choose(a, c))
fb = function([b, d], choose(b, d))
fc = function([a, b], choose(a, b))
fd = function([b, a], choose(b, a))
t_ca = fa(A, C)
t_cb = fb(B, D)
t_cc = fc(A, B)
t_cd = fd(B, A)
assert numpy.allclose(A.shape, t_ca.shape)
assert numpy.allclose(B.shape, t_cb.shape)
assert numpy.allclose(A.shape, t_cc.shape)
assert numpy.allclose(B.shape, t_cd.shape)
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论