提交 9df5e50b authored 作者: Frederic's avatar Frederic

make local_subtensor_of_alloc detect the right broadcast pattern more frequently.

reported by Ira Korshunova
上级 082f352a
...@@ -1998,7 +1998,14 @@ def local_subtensor_of_alloc(node): ...@@ -1998,7 +1998,14 @@ def local_subtensor_of_alloc(node):
# That dimension is removed. # That dimension is removed.
pass pass
else: else:
nw_dims += [T.ceil_intdiv((csl.stop - csl.start), csl.step)] nw_dim = csl.stop - csl.start
if csl.step != 1:
# Do not add the ceil_intdiv() graphs in the graphs
# when this is not needed as it prevent detecting the
# correct broadcast pattern.
nw_dim = T.ceil_intdiv(nw_dim, csl.step)
nw_dims += [nw_dim]
nw_val = val[tuple(val_slices)] nw_val = val[tuple(val_slices)]
nw_dims += dims[len(slices):] nw_dims += dims[len(slices):]
...@@ -2007,7 +2014,9 @@ def local_subtensor_of_alloc(node): ...@@ -2007,7 +2014,9 @@ def local_subtensor_of_alloc(node):
rval = T.alloc(nw_val, *nw_dims) rval = T.alloc(nw_val, *nw_dims)
if type(rval) not in (list, tuple): if type(rval) not in (list, tuple):
rval = [rval] rval = [rval]
if rval[0].type != node.outputs[0].type:
#This happen from time to time, we need to discover why
return
return rval return rval
......
...@@ -2366,11 +2366,13 @@ class Test_alloc_zero(unittest.TestCase): ...@@ -2366,11 +2366,13 @@ class Test_alloc_zero(unittest.TestCase):
def test_local_subtensor_of_alloc(): def test_local_subtensor_of_alloc():
x = tensor.matrix('x')
# DebugMode should detect if something goes wrong. # DebugMode should detect if something goes wrong.
# test shape combination of odd and event shape. # test shape combination of odd and event shape.
for shape in [(3, 5), (4, 6), (3, 8), (4, 7)]: for shape in [(3, 5), (4, 6), (3, 8), (4, 7),
(1, 5), (5, 1)]:
x = tensor.tensor(dtype=theano.config.floatX,
broadcastable=(shape[0] == 1, shape[1] == 1))
xval = numpy.zeros(shape, dtype=config.floatX) xval = numpy.zeros(shape, dtype=config.floatX)
yval = numpy.arange(shape[1], dtype=config.floatX) yval = numpy.arange(shape[1], dtype=config.floatX)
...@@ -2387,21 +2389,29 @@ def test_local_subtensor_of_alloc(): ...@@ -2387,21 +2389,29 @@ def test_local_subtensor_of_alloc():
# Only one column # Only one column
z_vec = yx[:, 3] z_vec = yx[:, 3]
assert z_vec.ndim == 1 assert z_vec.ndim == 1
# results are vector
for slices in [ slicess = []
# results are vector if shape[0] != 1:
(slice(None), 3), slicess.append((2, slice(None)))
(2, slice(None)), if shape[1] != 1:
# results are matrix slicess.append((slice(None), 3))
# results are matrix
slicess += [
(slice(None), slice(3, None)), (slice(None), slice(3, None)),
(slice(3, None), ), (slice(3, None), ),
(slice(3, None), slice(3, None)), (slice(3, None), slice(3, None)),
(slice(1, 3), slice(None, -1)), (slice(1, 3), slice(None, -1)),
(slice(None, None, 2)), (slice(None, None, 2)),
(slice(1, None, 2)), (slice(1, None, 2)),
]: ]
for slices in slicess:
z = yx.__getitem__(slices) z = yx.__getitem__(slices)
f = theano.function([x], z) f = theano.function([x], z)
theano.printing.debugprint(f)
# if theano.config.mode != 'FAST_COMPILE':
# assert not any([isinstance(node.op, Subtensor)
# for node in f.maker.fgraph.toposort()])
val = f(xval) val = f(xval)
assert xval.__getitem__(slices).shape == val.shape assert xval.__getitem__(slices).shape == val.shape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论