提交 d8fbe1e9 authored 作者: Frederic's avatar Frederic

Some infer_shape changes.

上级 a0eb7306
......@@ -5103,14 +5103,23 @@ def choose(a, choices, out=None, mode='raise'):
class Choose(Op):
__props__ = ('mode',)
def __init__(self, mode):
assert mode in ("raise", "wrap", "clip")
self.mode = mode
def infer_shape(self, node, shapes):
if isinstance(node.inputs[1], tuple):
if isinstance(node.inputs[1], TensorVariable):
return[(shapes[0])]
else:
import theano.typed_list
assert isinstance(node.inputs[1], theano.typed_list.TypedListVariable)
#import pdb;pdb.set_trace()
raise ShapeError("")
shape = shapes[0]
for i in range(len(shapes[0])-1):
shape[i] = shapes[1][i]
return [(shape)]
else:
return[(shapes[0])]
def make_node(self, a, choices):
from theano import typed_list
......@@ -5124,4 +5133,5 @@ class Choose(Op):
def perform(self, node, inputs, (z, )):
a = inputs[0]
choice = inputs[1]
# TODO reuse out?
z[0] = numpy.choose(a, choice, mode=self.mode)
......@@ -7052,7 +7052,7 @@ class T_Choose(utt.InferShapeTester):
a = tensor.matrix(dtype='int64')
b = tensor.vector(dtype='int64')
c = tensor.matrix(dtype='int64')
d = tensor.vector(dtype='int64')
d = tensor.vector(dtype='int64')
A = numpy.asarray(numpy.random.rand(5, 4), dtype='int64')
B = numpy.asarray(numpy.random.rand(4), dtype='int64')
......@@ -7073,20 +7073,29 @@ class T_Choose(utt.InferShapeTester):
# Op that should be removed from the graph.
self.op_class)
def test_infer_shape_tuple(self):
# Disabled as it isn't implemented.
def ___test_infer_shape_tuple(self):
a = tensor.tensor3(dtype='int64')
b = tensor.tensor3(dtype='int64')
c = tensor.tensor3(dtype='int64')
A = numpy.asarray([1, 0], dtype='int64').reshape((2,1,1))
A = numpy.asarray([1, 0], dtype='int64').reshape((2, 1, 1))
B = numpy.asarray(numpy.random.rand(1, 4, 1), dtype='int64')
C = numpy.asarray(numpy.random.rand(1, 1, 7), dtype='int64')
f = function([a, b, c], choose(a, (b,c)))
f = function([a, b, c], choose(a, (b, c)))
shape = (2, 4, 7)
assert numpy.allclose(f(A, B, C).shape, shape)
self._compile_and_check([a, b, c], # theano.function inputs
[self.op(a, (b, c))], # theano.function outputs
# Always use not square matrix!
# inputs data
[A, B, C],
# Op that should be removed from the graph.
self.op_class)
"""
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论