提交 d8fbe1e9 authored 作者: Frederic's avatar Frederic

Some infer_shape changes.

上级 a0eb7306
...@@ -5103,14 +5103,23 @@ def choose(a, choices, out=None, mode='raise'): ...@@ -5103,14 +5103,23 @@ def choose(a, choices, out=None, mode='raise'):
class Choose(Op): class Choose(Op):
__props__ = ('mode',) __props__ = ('mode',)
def __init__(self, mode):
assert mode in ("raise", "wrap", "clip")
self.mode = mode
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
if isinstance(node.inputs[1], tuple):
if isinstance(node.inputs[1], TensorVariable):
return[(shapes[0])]
else:
import theano.typed_list
assert isinstance(node.inputs[1], theano.typed_list.TypedListVariable)
#import pdb;pdb.set_trace()
raise ShapeError("")
shape = shapes[0] shape = shapes[0]
for i in range(len(shapes[0])-1): for i in range(len(shapes[0])-1):
shape[i] = shapes[1][i] shape[i] = shapes[1][i]
return [(shape)] return [(shape)]
else:
return[(shapes[0])]
def make_node(self, a, choices): def make_node(self, a, choices):
from theano import typed_list from theano import typed_list
...@@ -5124,4 +5133,5 @@ class Choose(Op): ...@@ -5124,4 +5133,5 @@ class Choose(Op):
def perform(self, node, inputs, (z, )): def perform(self, node, inputs, (z, )):
a = inputs[0] a = inputs[0]
choice = inputs[1] choice = inputs[1]
# TODO reuse out?
z[0] = numpy.choose(a, choice, mode=self.mode) z[0] = numpy.choose(a, choice, mode=self.mode)
...@@ -7073,20 +7073,29 @@ class T_Choose(utt.InferShapeTester): ...@@ -7073,20 +7073,29 @@ class T_Choose(utt.InferShapeTester):
# Op that should be removed from the graph. # Op that should be removed from the graph.
self.op_class) self.op_class)
def test_infer_shape_tuple(self): # Disabled as it isn't implemented.
def ___test_infer_shape_tuple(self):
a = tensor.tensor3(dtype='int64') a = tensor.tensor3(dtype='int64')
b = tensor.tensor3(dtype='int64') b = tensor.tensor3(dtype='int64')
c = tensor.tensor3(dtype='int64') c = tensor.tensor3(dtype='int64')
A = numpy.asarray([1, 0], dtype='int64').reshape((2,1,1)) A = numpy.asarray([1, 0], dtype='int64').reshape((2, 1, 1))
B = numpy.asarray(numpy.random.rand(1, 4, 1), dtype='int64') B = numpy.asarray(numpy.random.rand(1, 4, 1), dtype='int64')
C = numpy.asarray(numpy.random.rand(1, 1, 7), dtype='int64') C = numpy.asarray(numpy.random.rand(1, 1, 7), dtype='int64')
f = function([a, b, c], choose(a, (b,c))) f = function([a, b, c], choose(a, (b, c)))
shape = (2, 4, 7) shape = (2, 4, 7)
assert numpy.allclose(f(A, B, C).shape, shape) assert numpy.allclose(f(A, B, C).shape, shape)
self._compile_and_check([a, b, c], # theano.function inputs
[self.op(a, (b, c))], # theano.function outputs
# Always use not square matrix!
# inputs data
[A, B, C],
# Op that should be removed from the graph.
self.op_class)
""" """
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论