提交 b8270c45 authored 作者: Tanjay94's avatar Tanjay94 提交者: Frederic

Added make_list function into typed_list, fixed choose function and added…

Added make_list function into typed_list, fixed choose function and added infer_shape and test to it for the case where choice isn't a tuple.
上级 cddbb50e
...@@ -5109,14 +5109,19 @@ class Choose(Op): ...@@ -5109,14 +5109,19 @@ class Choose(Op):
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and self.props() == other.props()) return (type(self) == type(other) and self.props() == other.props())
def _infer_shape(self, node, shapes):
return[(shapes[0])]
def props(self): def props(self):
return self.mode return self.mode
def make_node(self, a, choices): def make_node(self, a, choices):
from theano import typed_list
a = as_tensor_variable(a) a = as_tensor_variable(a)
if isinstance(choices, (tuple, list)): if isinstance(choices, (tuple, list)):
choice = theano.typed_list.make_list(choices) choice = theano.typed_list.make_list(choices)
else: else:
self.infer_shape = self._infer_shape
choice = as_tensor_variable(choices) choice = as_tensor_variable(choices)
return Apply(self, [a, choice], [a.type()]) return Apply(self, [a, choice], [a.type()])
...@@ -5125,8 +5130,6 @@ class Choose(Op): ...@@ -5125,8 +5130,6 @@ class Choose(Op):
choice = inputs[1] choice = inputs[1]
z[0] = numpy.choose(a, choice, mode=self.mode) z[0] = numpy.choose(a, choice, mode=self.mode)
def __str__(self):
return self.__class__.__name
""" """
import theano import theano
...@@ -5139,7 +5142,7 @@ y = T.tensor3(dtype='int64') ...@@ -5139,7 +5142,7 @@ y = T.tensor3(dtype='int64')
z = T.tensor3(dtype='int64') z = T.tensor3(dtype='int64')
w = choose(x,(y,z)) w = choose(x,(y,z))
f = function([x,y,z], w) f = function([x,y,z], w)
a = np.array([0, 1]).reshape((2,1,1)) a = np.array([0, 1, 0, 1, 0, 1]).reshape((2,3,1))
c1 = np.array([1, 2, 3]).reshape((1,3,1)) c1 = np.array([1, 2, 3]).reshape((1,3,1))
c2 = np.array([-1, -2, -3, -4, -5]).reshape((1,1,5)) c2 = np.array([-1, -2, -3, -4, -5]).reshape((1,1,5))
""" """
\ No newline at end of file
...@@ -7032,6 +7032,47 @@ class T_Choose(utt.InferShapeTester): ...@@ -7032,6 +7032,47 @@ class T_Choose(utt.InferShapeTester):
n_c = numpy.choose(A, B, mode=m) n_c = numpy.choose(A, B, mode=m)
assert numpy.allclose(t_c, n_c) assert numpy.allclose(t_c, n_c)
def test_numpy_compare_tuple(self):
a = tensor.tensor3(dtype='int64')
b = tensor.tensor3(dtype='int64')
c = tensor.tensor3(dtype='int64')
A = numpy.asarray(numpy.random.rand(2, 1, 1), dtype='int64')
B = numpy.asarray(numpy.random.rand(1, 6, 1), dtype='int64')
C = numpy.asarray(numpy.random.rand(1, 1, 5), dtype='int64')
f = function([a, b, c], choose(a, (b, c)))
t_c = f(A, B, C)
n_c = numpy.choose(A, (B, C))
assert numpy.allclose(t_c, n_c)
def test_infer_shape(self):
a = tensor.matrix(dtype='int64')
b = tensor.vector(dtype='int64')
c = tensor.matrix(dtype='int64')
d = tensor.vector(dtype='int64')
A = numpy.asarray(numpy.random.rand(5, 4), dtype='int64')
B = numpy.asarray(numpy.random.rand(4), dtype='int64')
C = numpy.asarray(numpy.random.rand(7, 4), dtype='int64')
D = numpy.asarray(numpy.random.rand(4), dtype='int64')
var1 = [a, b, a, b]
var2 = [c, d, b, a]
mat1 = [A, B, A, B]
mat2 = [C, D, B, A]
for v, m, w, n in zip(var1, mat1, var2, mat2):
self._compile_and_check([v, w], # theano.function inputs
[self.op(v, w)], # theano.function outputs
# Always use not square matrix!
# inputs data
[m, n],
# Op that should be removed from the graph.
self.op_class)
""" """
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -10,7 +10,7 @@ from theano.tensor.type_other import SliceType ...@@ -10,7 +10,7 @@ from theano.tensor.type_other import SliceType
from theano.typed_list.type import TypedListType from theano.typed_list.type import TypedListType
from theano.typed_list.basic import (GetItem, Insert, from theano.typed_list.basic import (GetItem, Insert,
Append, Extend, Remove, Reverse, Append, Extend, Remove, Reverse,
Index, Count, Length) Index, Count, Length, make_list, Make_List)
from theano import sparse from theano import sparse
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
# TODO, handle the case where scipy isn't installed. # TODO, handle the case where scipy isn't installed.
...@@ -553,3 +553,32 @@ class test_length(unittest.TestCase): ...@@ -553,3 +553,32 @@ class test_length(unittest.TestCase):
x = rand_ranged_matrix(-1000, 1000, [100, 101]) x = rand_ranged_matrix(-1000, 1000, [100, 101])
self.assertTrue(f([x, x]) == 2) self.assertTrue(f([x, x]) == 2)
class T_Make_List(unittest.TestCase):
def test_wrong_shape(self):
a = T.vector()
b = T.matrix()
self.assertRaises(AssertionError, make_list, (a,b))
def correct_answer(self):
a = T.matrix()
b = T.matrix()
x = T.tensor3()
y = T.tensor3()
A = numpy.random.rand(5)
B = numpy.random.rand(7)
X = numpy.random.rand(5,6)
Y = numpy.random.rand(1,9)
c = make_list((a, b))
z = make_list((x, y))
fc = function([a, b], c)
fz = function([x, y], z)
self.assertTrue(f([A, B]) == [A, B])
self.assertTrue(f([X, Y]) == [X, Y])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论