提交 ce827d20 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Apparently working version of the general subtensor-subtensor merge.

上级 321bdafa
......@@ -1232,38 +1232,58 @@ def merge_two_slices(slice1, len1, slice2, len2):
# according to the two steps we have 4 different combinations of
# positive/negative. I will denote the case I'm looking at by
# suffixes to the variables (nn,np,pn,pp):
pp_start = sl1.start + sl2.start * sl1.step
pp_stop = sl1.start + sl2.stop * sl1.step
pp_step = sl1.step * sl2.step
flen = sl2.stop - sl2.start
p_step = sl1.step * sl2.step
n_step = sl1.step * sl2.step * -1
pp_start = T.minimum(sl1.start + sl2.start * sl1.step, sl1.stop)
pp_stop = T.minimum(sl1.start + sl2.stop * sl1.step, sl1.stop)
pn_stop = sl1.start + (sl2.start -1) * sl1.step
pn_stop = T.switch(T.and_(T.lt(pn_stop,0)
, T.gt(flen,0))
, -len1 -1
, T.minimum(pn_stop, sl1.stop))
pn_start = sl1.start + (sl2.stop -1) * sl1.step
pn_start = T.minimum( pn_start, sl1.stop )
pn_start = T.maximum( pn_start, 0 )
pn_stop = sl1.start + sl2.start * sl1.step
pn_start = sl1.start + sl2.stop * sl1.step
pn_step = sl1.step * sl2.step * -1
pn_stop = T.switch(T.eq(pn_stop,-1), -len1 -1, pn_stop)
np_stop = sl1.stop - sl2.stop * sl1.step -1
np_start = sl1.stop - sl2.start * sl1.step -1
np_step = sl1.step * sl2.step * -1
np_stop = T.switch(T.eq(np_stop,-1), -len1 -1, np_stop)
np_stop = T.switch(T.and_(T.lt(np_stop,0)
, T.gt(flen,0))
,-len1-1
, T.maximum(sl1.start-1, np_stop))
np_start = T.maximum(sl1.start,sl1.stop - sl2.start * sl1.step -1)
nn_start = T.maximum(sl1.start,(sl1.stop -1)- (sl2.stop-1) * sl1.step)
nn_stop = T.maximum(sl1.start,sl1.stop - sl2.start * sl1.step)
nn_start = sl1.stop - sl2.start * sl1.step
nn_stop = sl1.stop - sl2.stop * sl1.step
nn_step = sl1.step * sl2.step
start = const_fold(T.switch(T.lt(reverse2*reverse1,0),
start = T.switch(T.lt(reverse2*reverse1,0),
T.switch(T.lt(reverse1,0), np_start, pn_start),
T.switch(T.lt(reverse1,0), nn_start,
pp_start)).owner)[0]
pp_start))
stop = const_fold(T.switch(T.lt(reverse2*reverse1,0),
stop = T.switch(T.lt(reverse2*reverse1,0),
T.switch(T.lt(reverse1,0), np_stop , pn_stop ),
T.switch(T.lt(reverse1,0), nn_stop , pp_stop
)).owner)[0]
))
step = T.switch( T.lt(reverse2*reverse1,0),n_step, p_step)
start = T.switch(T.le(flen,0), 0, start)
stop = T.switch(T.le(flen,0), 0, stop)
start = const_fold(start.owner)[0]
stop = const_fold(stop.owner)[0]
step = const_fold(step.owner)[0]
step = const_fold( T.switch(T.lt(reverse2*reverse1,0),
T.switch(T.lt(reverse1,0), np_step , pn_step ),
T.switch(T.lt(reverse1,0), nn_step , pp_step
)).owner)[0]
start = theano.printing.Print('start')(start)
stop = theano.printing.Print('stop')(stop)
step = theano.printing.Print('step')(step)
return slice(start, stop, step)
@register_canonicalize
......
......@@ -1462,6 +1462,42 @@ class test_local_subtensor_merge(unittest.TestCase):
    def test_scalar5(self):
        # Exhaustively exercise the merge of two chained, fully general
        # slices: var[b1:e1:s1][b2:e2:s2] with symbolic scalar bounds/steps.
        # Checks both that the optimizer collapses the two Subtensor ops
        # into one, and that the merged graph still runs without error on
        # a grid of concrete (begin, end, step) values.
        x = TT.matrix('x')
        # Symbolic int scalars for the two slices' start/stop/step.
        b1 = TT.iscalar('b1')
        e1 = TT.iscalar('e1')
        s1 = TT.iscalar('s1')
        b2 = TT.iscalar('b2')
        e2 = TT.iscalar('e2')
        s2 = TT.iscalar('s2')
        # mode_opt presumably enables the local_subtensor_merge
        # optimization under test — defined elsewhere in this file.
        f = function([x,b1,e1,s1,b2,e2,s2], x[b1:e1:s1][b2:e2:s2], mode=mode_opt)
        #theano.printing.debugprint(f, print_type=True)
        topo=f.maker.env.toposort()
        #print [t for t in topo if isinstance(t.op, TT.Subtensor)]
        # The two chained Subtensor ops must have been merged into one.
        assert len([t for t in topo if isinstance(t.op, TT.Subtensor)]) == 1
        #print topo[-1].op
        # The final node should be the function's output deep-copy, i.e.
        # nothing else was appended after the merged subtensor.
        assert isinstance(topo[-1].op, theano.compile.function_module.DeepCopyOp)
        # Sample 4 random values each for the starts/stops in [-8, 8) ...
        b1r = self.rng.permutation(range(-8,8))[:4]
        e1r = self.rng.permutation(range(-8,8))[:4]
        b2r = self.rng.permutation(range(-8,8))[:4]
        e2r = self.rng.permutation(range(-8,8))[:4]
        # ... and 4 random nonzero steps in [-7, 7] (0 is an invalid step).
        s1r = self.rng.permutation([-7,-6,-5,-4,-3,-2,-1,1,2,3,4,5,6,7])[:4]
        s2r = self.rng.permutation([-7,-6,-5,-4,-3,-2,-1,1,2,3,4,5,6,7])[:4]
        # Run the compiled function over the full cartesian grid of
        # sampled parameters for every test shape; a crash or a wrong
        # merged slice would surface here.
        for x_s in self.x_shapes:
            x_val = self.rng.uniform(size=x_s).astype(config.floatX)
            for b1 in b1r:
                for e1 in e1r:
                    for s1 in s1r:
                        for b2 in b2r:
                            for e2 in e2r:
                                for s2 in s2r:
                                    # Python-2 style debug print of the case being run.
                                    print >>sys.stderr, x_s,b1,e1,s1,b2,e2,s2
                                    f(x_val, b1,e1,s1,b2,e2,s2)
def test_local_fill_useless():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论