提交 dd018d55 authored 作者: Tanjay94's avatar Tanjay94 提交者: Frederic

Added infer_shape and its test for tuple argument.

上级 c590bbc6
......@@ -5109,8 +5109,14 @@ class Choose(Op):
def __eq__(self, other):
return (type(self) == type(other) and self.props() == other.props())
def _infer_shape(self, node, shapes):
return[(shapes[0])]
def infer_shape(self, node, shapes):
if isinstance(node.inputs[1], tuple):
shape = shapes[0]
for i in range(len(shapes[0])-1):
shape[i] = shapes[1][i]
return [(shape)]
else:
return[(shapes[0])]
def props(self):
return self.mode
......@@ -5121,7 +5127,6 @@ class Choose(Op):
if isinstance(choices, (tuple, list)):
choice = theano.typed_list.make_list(choices)
else:
self.infer_shape = self._infer_shape
choice = as_tensor_variable(choices)
return Apply(self, [a, choice], [a.type()])
......@@ -5129,3 +5134,25 @@ class Choose(Op):
a = inputs[0]
choice = inputs[1]
z[0] = numpy.choose(a, choice, mode=self.mode)
"""
import theano
from theano import tensor as T
from theano import function
from theano.tensor.basic import choose
import numpy as np
x = T.tensor4(dtype='int64')
y = T.tensor4(dtype='int64')
z = T.tensor4(dtype='int64')
p = T.tensor4(dtype='int64')
w = choose(x,(y,z,p))
f = function([x,y,z,p], w)
a = np.array([0, 1, 2]).reshape((3,1,1,1))
c1 = np.array([1, 2, 3]).reshape((1,3,1,1))
c2 = np.array([-1, -2, -3, -4, -5]).reshape((1,1,5,1))
c3 = np.array([1, 2, 4]).reshape((1,1,1,3))
"""
......@@ -7073,6 +7073,20 @@ class T_Choose(utt.InferShapeTester):
# Op that should be removed from the graph.
self.op_class)
def test_infer_shape_tuple(self):
a = tensor.tensor3(dtype='int64')
b = tensor.tensor3(dtype='int64')
c = tensor.tensor3(dtype='int64')
A = numpy.asarray([1, 0], dtype='int64').reshape((2,1,1))
B = numpy.asarray(numpy.random.rand(1, 4, 1), dtype='int64')
C = numpy.asarray(numpy.random.rand(1, 1, 7), dtype='int64')
f = function([a, b, c], choose(a, (b,c)))
shape = (2, 4, 7)
assert numpy.allclose(f(A, B, C).shape, shape)
"""
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论