提交 433fba5c authored 作者: Tanjay94's avatar Tanjay94 提交者: Frederic

Added choose to theano to copy numpy.choose.

上级 95c835ed
...@@ -5127,9 +5127,6 @@ class Choose(Op): ...@@ -5127,9 +5127,6 @@ class Choose(Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
""" """
import theano import theano
from theano import tensor as T from theano import tensor as T
......
...@@ -7018,8 +7018,8 @@ class T_Choose(): ...@@ -7018,8 +7018,8 @@ class T_Choose():
a = tensor.vector(dtype='int64') a = tensor.vector(dtype='int64')
b = tensor.matrix(dtype='int64') b = tensor.matrix(dtype='int64')
A = numpy.random.random_integers(-5, 5, (4)) A = numpy.asarray(numpy.random.rand(4), dtype='int64')
B = numpy.random.random_integers(-5, 5, (4, 4)) B = numpy.asarray(numpy.random.rand(4, 4), dtype='int64')
modes = ['raise', 'wrap', 'clip'] modes = ['raise', 'wrap', 'clip']
...@@ -7029,6 +7029,18 @@ class T_Choose(): ...@@ -7029,6 +7029,18 @@ class T_Choose():
n_c = numpy.choose(A, B, mode=m) n_c = numpy.choose(A, B, mode=m)
assert numpy.allclose(t_c, n_c) assert numpy.allclose(t_c, n_c)
def wrong_choice_array(self):
a = tensor.matrix(dtype='int64')
b = tensor.vector(dtype='int64')
A = numpy.asarray(numpy.random.rand(4), dtype='int64')
B = numpy.asarray(numpy.random.rand(4, 4), dtype='int64')
f = function([a, b], choose(a, b))
self.assertRaise(ValueError, f, A, B)
""" """
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论