提交 4885dbb9 authored 作者: Tanjay94's avatar Tanjay94 提交者: Frederic

Modified Choose function tu support tensor.

上级 a36a2b6b
...@@ -5112,31 +5112,23 @@ class Choose(Op): ...@@ -5112,31 +5112,23 @@ class Choose(Op):
def props(self): def props(self):
return self.mode return self.mode
def infer_shape(self, nodes, shapes):
return [(shapes[0])]
def make_node(self, a, choices): def make_node(self, a, choices):
a = as_tensor_variable(a) a = as_tensor_variable(a)
if isinstance(choices, tuple): if isinstance(choices, tuple):
choices1 = as_tensor_variable(choices[0]) choice = as_tensor_variable(choices)
choices2 = as_tensor_variable(choices[1]) return Apply(self, [a, choice], [a.type()])
return Apply(self, [a, choices1, choices2], [a.type()])
else: else:
choices = as_tensor_variable(choices) choice = as_tensor_variable(choices)
return Apply(self, [a, choices], [a.type()]) return Apply(self, [a, choice], [a.type()])
def perform(self, node, inputs, (z, )): def perform(self, node, inputs, (z, )):
a = inputs[0] a = inputs[0]
if len(inputs)>2: choice = tuple(inputs[1])
choices1 = inputs[1] z[0] = numpy.choose(a, choice, mode=self.mode)
choices2 = inputs[2]
z[0] = numpy.choose(a, (choices1, choices2), mode=self.mode)
else:
choices = inputs[1]
z[0] = numpy.choose(a, (choices1, choices2), mode=self.mode)
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
""" """
import theano import theano
from theano import tensor as T from theano import tensor as T
...@@ -5147,7 +5139,7 @@ x = T.tensor3(dtype='int64') ...@@ -5147,7 +5139,7 @@ x = T.tensor3(dtype='int64')
y = T.tensor3(dtype='int64') y = T.tensor3(dtype='int64')
z = T.tensor3(dtype='int64') z = T.tensor3(dtype='int64')
w = choose(x,(y,z)) w = choose(x,(y,z))
f = function([x,y,z], w) f = function([x,y,z], w, on_unused_input='warn')
a = np.array([0, 1]).reshape((2,1,1)) a = np.array([0, 1]).reshape((2,1,1))
c1 = np.array([1, 2, 3]).reshape((1,3,1)) c1 = np.array([1, 2, 3]).reshape((1,3,1))
c2 = np.array([-1, -2, -3, -4, -5]).reshape((1,1,5)) c2 = np.array([-1, -2, -3, -4, -5]).reshape((1,1,5))
......
...@@ -7032,32 +7032,6 @@ class T_Choose(utt.InferShapeTester): ...@@ -7032,32 +7032,6 @@ class T_Choose(utt.InferShapeTester):
n_c = numpy.choose(A, B, mode=m) n_c = numpy.choose(A, B, mode=m)
assert numpy.allclose(t_c, n_c) assert numpy.allclose(t_c, n_c)
def test_infer_shape(self):
a = tensor.matrix(dtype='int64')
b = tensor.vector(dtype='int64')
c = tensor.matrix(dtype='int64')
d = tensor.vector(dtype='int64')
A = numpy.asarray(numpy.random.rand(5, 4), dtype='int64')
B = numpy.asarray(numpy.random.rand(4), dtype='int64')
C = numpy.asarray(numpy.random.rand(7, 4), dtype='int64')
D = numpy.asarray(numpy.random.rand(4), dtype='int64')
var1 = [a, b, a, b]
var2 = [c, d, b, a]
mat1 = [A, B, A, B]
mat2 = [C, D, B, A]
for v, m, w, n in zip(var1, mat1, var2, mat2):
self._compile_and_check([v, w], # theano.function inputs
[self.op(v, w)], # theano.function outputs
# Always use not square matrix!
# inputs data
[m, n],
# Op that should be removed from the graph.
self.op_class)
""" """
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论