提交 0003c2ae authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Added infer_shape for join op + a simple test of it

上级 6b3b780c
...@@ -3477,6 +3477,51 @@ class Join(Op): ...@@ -3477,6 +3477,51 @@ class Join(Op):
else: else:
return node.owner.tag.shape_zero return node.owner.tag.shape_zero
def infer_shape(self, node, ishapes):
# Join op should get at least two inputs to join
assert len(ishapes) > 1
# Not sure this is needed anymore :( ... basically the apply_shape
# version of the apply node (i.e. the one defined in
# gof/apply_shape) calls infer_shape methods passing None to unknown
# inputs. It can handle NotImplementedError, so for now I just raise
# that whenever I get a None. Should we just remove gof/apply_shape
# if it is depricated ??
if ishapes[1] is None:
raise NotImplementedError
n_dim = len(ishapes[1])
for shape in ishapes[1:]:
if shape is None:
raise NotImplementedError
for shape_i in shape:
if shape_i is None:
raise NotImplementedError
# at this point the inputs have been broadcasted so they should
# all have the same shape
assert abs(len(shape) - n_dim) == 0
out_shapes = []
for dim in xrange(n_dim):
# we have to deal with 2 possible cases in here :
# a) we are dealing with the dimension for which we join
# (called t_side from true side of the if, where the if
# compares current dimension with the joining dimension)
# b) a non joining dimension ( in which maybe a symbolic
# assertion can be used to make sure all tensors have
# the same number of elements on this non-joined dimension
# this is f_side
# initialize
t_side = ishapes[1][dim]
f_side = ishapes[1][dim]
# loop over tensors and sum for the joining dimension
for shape in ishapes[2:]:
t_side = t_side + shape[dim]
# return the dimensions found
out_shapes.append( switch(eq(dim, node.inputs[0]),
t_side, f_side))
return [tuple(out_shapes)]
@_redefine_asRoutine(Join()) @_redefine_asRoutine(Join())
def join(axis, *tensors): def join(axis, *tensors):
""" """
......
...@@ -2447,6 +2447,21 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -2447,6 +2447,21 @@ class T_Join_and_Split(unittest.TestCase):
self.assertRaises(ValueError, g, a_val, b_val, c_val, bad_d_val, e_val) self.assertRaises(ValueError, g, a_val, b_val, c_val, bad_d_val, e_val)
self.assertRaises(ValueError, g, a_val, b_val, c_val, d_val, bad_e_val) self.assertRaises(ValueError, g, a_val, b_val, c_val, d_val, bad_e_val)
def test_infer_shape_join(self):
x1 = matrix()
x2 = matrix()
x3 = matrix()
z = join(0,x1,x2,x3)
def get_mat(s1,s2):
return numpy.asarray( numpy.random.uniform(size=(s1,s2)),
dtype= config.floatX)
f = theano.function([x1,x2,x3], z.shape)
f( get_mat(3,4), get_mat(2,4), get_mat(1,5))
if theano.config.mode != 'FAST_COMPILE':
for node in f.maker.env.toposort():
assert not isinstance(node.op, tensor.Join)
class test_comparison(unittest.TestCase): class test_comparison(unittest.TestCase):
def test_gt(self): def test_gt(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论