提交 cddbb50e authored 作者: Tanjay94's avatar Tanjay94 提交者: Frederic

Saving change.

上级 4885dbb9
...@@ -5114,20 +5114,19 @@ class Choose(Op): ...@@ -5114,20 +5114,19 @@ class Choose(Op):
def make_node(self, a, choices): def make_node(self, a, choices):
a = as_tensor_variable(a) a = as_tensor_variable(a)
if isinstance(choices, tuple): if isinstance(choices, (tuple, list)):
choice = as_tensor_variable(choices) choice = theano.typed_list.make_list(choices)
return Apply(self, [a, choice], [a.type()])
else: else:
choice = as_tensor_variable(choices) choice = as_tensor_variable(choices)
return Apply(self, [a, choice], [a.type()]) return Apply(self, [a, choice], [a.type()])
def perform(self, node, inputs, (z, )): def perform(self, node, inputs, (z, )):
a = inputs[0] a = inputs[0]
choice = tuple(inputs[1]) choice = inputs[1]
z[0] = numpy.choose(a, choice, mode=self.mode) z[0] = numpy.choose(a, choice, mode=self.mode)
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name
""" """
import theano import theano
...@@ -5139,7 +5138,7 @@ x = T.tensor3(dtype='int64') ...@@ -5139,7 +5138,7 @@ x = T.tensor3(dtype='int64')
y = T.tensor3(dtype='int64') y = T.tensor3(dtype='int64')
z = T.tensor3(dtype='int64') z = T.tensor3(dtype='int64')
w = choose(x,(y,z)) w = choose(x,(y,z))
f = function([x,y,z], w, on_unused_input='warn') f = function([x,y,z], w)
a = np.array([0, 1]).reshape((2,1,1)) a = np.array([0, 1]).reshape((2,1,1))
c1 = np.array([1, 2, 3]).reshape((1,3,1)) c1 = np.array([1, 2, 3]).reshape((1,3,1))
c2 = np.array([-1, -2, -3, -4, -5]).reshape((1,1,5)) c2 = np.array([-1, -2, -3, -4, -5]).reshape((1,1,5))
......
...@@ -567,3 +567,53 @@ Returns the size of a list. ...@@ -567,3 +567,53 @@ Returns the size of a list.
:param x: typed list. :param x: typed list.
""" """
class Make_List(Op):
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, a):
assert isinstance(a, (tuple, list))
a2 = []
for elem in a:
if not isinstance(elem, theano.gof.Variable):
elem = as_tensor_variable(elem)
a2.append(elem)
assert all(a2[0].type == elem.type for elem in a2)
tl = theano.typed_list.TypedListType(a2[0].type)()
return Apply(self, a2, [tl])
def perform(self, node, inputs, (out, )):
out[0] = list(inputs)
make_list = Make_List()
"""
Returns a list made from tuple's elements.
:param a: tuple.
"""
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论