提交 6e470998 authored 作者: Frederic's avatar Frederic

Small update to new MakeList op.

上级 d8fbe1e9
...@@ -569,7 +569,7 @@ Returns the size of a list. ...@@ -569,7 +569,7 @@ Returns the size of a list.
""" """
class Make_List(Op): class MakeList(Op):
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
...@@ -584,7 +584,9 @@ class Make_List(Op): ...@@ -584,7 +584,9 @@ class Make_List(Op):
if not isinstance(elem, theano.gof.Variable): if not isinstance(elem, theano.gof.Variable):
elem = as_tensor_variable(elem) elem = as_tensor_variable(elem)
a2.append(elem) a2.append(elem)
assert all(a2[0].type == elem.type for elem in a2) if not all(a2[0].type == elem.type for elem in a2):
raise TypeError(
"MakeList need all input variable to be of the same type.")
tl = theano.typed_list.TypedListType(a2[0].type)() tl = theano.typed_list.TypedListType(a2[0].type)()
return Apply(self, a2, [tl]) return Apply(self, a2, [tl])
...@@ -592,9 +594,11 @@ class Make_List(Op): ...@@ -592,9 +594,11 @@ class Make_List(Op):
def perform(self, node, inputs, (out, )): def perform(self, node, inputs, (out, )):
out[0] = list(inputs) out[0] = list(inputs)
make_list = Make_List() make_list = MakeList()
""" """
Returns a list made from tuple's elements. Build a Python list from those Theano variable.
:param a: tuple. :param a: tuple/list of Theano variable
:note: All Theano variable must have the same type.
""" """
...@@ -10,7 +10,7 @@ from theano.tensor.type_other import SliceType ...@@ -10,7 +10,7 @@ from theano.tensor.type_other import SliceType
from theano.typed_list.type import TypedListType from theano.typed_list.type import TypedListType
from theano.typed_list.basic import (GetItem, Insert, from theano.typed_list.basic import (GetItem, Insert,
Append, Extend, Remove, Reverse, Append, Extend, Remove, Reverse,
Index, Count, Length, make_list, Make_List) Index, Count, Length, make_list, MakeList)
from theano import sparse from theano import sparse
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
# TODO, handle the case where scipy isn't installed. # TODO, handle the case where scipy isn't installed.
...@@ -555,13 +555,13 @@ class test_length(unittest.TestCase): ...@@ -555,13 +555,13 @@ class test_length(unittest.TestCase):
self.assertTrue(f([x, x]) == 2) self.assertTrue(f([x, x]) == 2)
class T_Make_List(unittest.TestCase): class T_MakeList(unittest.TestCase):
def test_wrong_shape(self): def test_wrong_shape(self):
a = T.vector() a = T.vector()
b = T.matrix() b = T.matrix()
self.assertRaises(AssertionError, make_list, (a,b)) self.assertRaises(TypeError, make_list, (a,b))
def correct_answer(self): def correct_answer(self):
a = T.matrix() a = T.matrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论