提交 a36a2b6b authored 作者: Tanjay94's avatar Tanjay94 提交者: Frederic

Added options to have a tuple of choices.

上级 74c89acd
...@@ -5117,15 +5117,38 @@ class Choose(Op): ...@@ -5117,15 +5117,38 @@ class Choose(Op):
def make_node(self, a, choices): def make_node(self, a, choices):
a = as_tensor_variable(a) a = as_tensor_variable(a)
choices = as_tensor_variable(choices) if isinstance(choices, tuple):
choices1 = as_tensor_variable(choices[0])
return Apply(self, [a, choices], [a.type()]) choices2 = as_tensor_variable(choices[1])
return Apply(self, [a, choices1, choices2], [a.type()])
else:
choices = as_tensor_variable(choices)
return Apply(self, [a, choices], [a.type()])
def perform(self, node, inputs, (z, )): def perform(self, node, inputs, (z, )):
a = inputs[0] a = inputs[0]
choices = inputs[1] if len(inputs)>2:
choices1 = inputs[1]
z[0] = numpy.choose(a, choices, mode=self.mode) choices2 = inputs[2]
z[0] = numpy.choose(a, (choices1, choices2), mode=self.mode)
else:
choices = inputs[1]
z[0] = numpy.choose(a, (choices1, choices2), mode=self.mode)
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
"""
import theano
from theano import tensor as T
from theano import function
from theano.tensor.basic import choose
import numpy as np
x = T.tensor3(dtype='int64')
y = T.tensor3(dtype='int64')
z = T.tensor3(dtype='int64')
w = choose(x,(y,z))
f = function([x,y,z], w)
a = np.array([0, 1]).reshape((2,1,1))
c1 = np.array([1, 2, 3]).reshape((1,3,1))
c2 = np.array([-1, -2, -3, -4, -5]).reshape((1,1,5))
"""
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论