提交 5961fe59 authored 作者: Tanjay94's avatar Tanjay94 提交者: Frederic

Added infer_shape to choose function and deleted the out input.

上级 4ff01916
......@@ -5095,7 +5095,8 @@ def swapaxes(y, axis1, axis2):
def choose(a, choices, out=None, mode='raise'):
return Choose(mode)(a, choices, out)
assert out is None
return Choose(mode)(a, choices)
class Choose(Op):
......@@ -5111,29 +5112,20 @@ class Choose(Op):
def props(self):
return self.mode
def make_node(self, a, choices, out):
def infer_shape(self, nodes, shapes):
return [(shapes[0])]
def make_node(self, a, choices):
a = as_tensor_variable(a)
choices = as_tensor_variable(choices)
out = theano.gof.Constant(theano.gof.generic, out)
return Apply(self, [a, choices, out], [a.type()])
return Apply(self, [a, choices], [a.type()])
def perform(self, node, inputs, (z, )):
a = inputs[0]
choices = inputs[1]
out = inputs[2]
z[0] = numpy.choose(a, choices, out, self.mode)
z[0] = numpy.choose(a, choices, mode=self.mode)
def __str__(self):
return self.__class__.__name__
"""
import theano
from theano import tensor as T
from theano import function
from theano.tensor import basic
x = T.vector(dtype='int64')
y = T.matrix(dtype='int64')
z = basic.choose(x,y)
f = function([x, y], z)
"""
\ No newline at end of file
......@@ -7029,7 +7029,6 @@ class T_Choose():
n_c = numpy.choose(A, B, mode=m)
assert numpy.allclose(t_c, n_c)
def wrong_choice_array(self):
a = tensor.matrix(dtype='int64')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论