提交 95c835ed authored 作者: Tanjay94's avatar Tanjay94 提交者: Frederic

Modified choose function.

上级 2015c5a7
...@@ -5093,8 +5093,9 @@ def swapaxes(y, axis1, axis2): ...@@ -5093,8 +5093,9 @@ def swapaxes(y, axis1, axis2):
li[axis1], li[axis2] = li[axis2], li[axis1] li[axis1], li[axis2] = li[axis2], li[axis1]
return y.dimshuffle(li) return y.dimshuffle(li)
def choose(x, y, mode='raise'):
return Choose(mode)(x, y) def choose(a, choices, out=None, mode='raise'):
return Choose(mode)(a, choices, out)
class Choose(Op): class Choose(Op):
...@@ -5110,17 +5111,32 @@ class Choose(Op): ...@@ -5110,17 +5111,32 @@ class Choose(Op):
def props(self): def props(self):
return self.mode return self.mode
def make_node(self, x, y): def make_node(self, a, choices, out):
x = as_tensor_variable(x) a = as_tensor_variable(a)
y = as_tensor_variable(y) choices = as_tensor_variable(choices)
out = theano.gof.Constant(theano.gof.generic, out)
return Apply(self, [x, y], [x.type()])
def perform(self, node, inputs, (z,)): return Apply(self, [a, choices, out], [a.type()])
z[0] = numpy.choose(inputs[0], inputs[1], self.mode)
def infer_shape(self, nodes, shapes): def perform(self, node, inputs, (z, )):
return [shapes[0]] a = inputs[0]
choices = inputs[1]
out = inputs[2]
z[0] = numpy.choose(a, choices, out, self.mode)
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
"""
import theano
from theano import tensor as T
from theano import function
from theano.tensor import basic
x = T.vector(dtype='int64')
y = T.matrix(dtype='int64')
z = basic.choose(x,y)
f = function([x, y], z)
"""
\ No newline at end of file
...@@ -7014,19 +7014,20 @@ class T_Power(): ...@@ -7014,19 +7014,20 @@ class T_Power():
class T_Choose(): class T_Choose():
def test_numpy_compare(self): def test_numpy_compare(self):
a = tensor.vector(dtype='int64') a = tensor.vector(dtype='int64')
b = tensor.matrix(dtype='int64') b = tensor.matrix(dtype='int64')
A = numpy.random.rand(5) A = numpy.random.random_integers(-5, 5, (4))
B = numpy.random.rand(5, 6) B = numpy.random.random_integers(-5, 5, (4, 4))
modes = ['raise', 'wrap', 'clip'] modes = ['raise', 'wrap', 'clip']
for m in modes: for m in modes:
f = function([a, b], choose(a, b, m)) f = function([a, b], choose(a, b, mode=m))
t_c = f(A, B) t_c = f(A, B)
n_c = numpy.choose(A, B, m) n_c = numpy.choose(A, B, mode=m)
assert numpy.allclose(t_p, n_p) assert numpy.allclose(t_c, n_c)
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论