提交 c971bfd9 authored 作者: Frederic Bastien's avatar Frederic Bastien

Following code review optimized Rebroadcast infer_shape.

Insert constant 1 when the dimensions is broadcastable.
上级 dafbb7fa
...@@ -3292,7 +3292,16 @@ class Rebroadcast(Op): ...@@ -3292,7 +3292,16 @@ class Rebroadcast(Op):
# restore the broadcasting pattern of the input # restore the broadcasting pattern of the input
return Rebroadcast(*[(axis, x.type.broadcastable[axis]) for axis, value in self.axis.iteritems()])(gz), return Rebroadcast(*[(axis, x.type.broadcastable[axis]) for axis, value in self.axis.iteritems()])(gz),
def infer_shape(self, node, ishapes): def infer_shape(self, node, ishapes):
return ishapes assert len(ishapes)==1
l = []
one = constant(1)
for ax in range(len(ishapes[0])):
if self.axis.get(ax, False):
l.append(one)
else:
l.append(ishapes[0][ax])
return [tuple(l)]
def addbroadcast(x, *axes): def addbroadcast(x, *axes):
""" """
......
...@@ -4173,9 +4173,19 @@ class test_broadcast(unittest.TestCase): ...@@ -4173,9 +4173,19 @@ class test_broadcast(unittest.TestCase):
def test_infer_shape(self): def test_infer_shape(self):
x = matrix() x = matrix()
y = addbroadcast(x,0) y = addbroadcast(x, 0)
f = theano.function([x], y.shape) f = theano.function([x], y.shape)
f(numpy.zeros((1,5), dtype=config.floatX)) assert (f(numpy.zeros((1,5), dtype=config.floatX)) == [1,5]).all()
topo = f.maker.env.toposort()
if theano.config.mode != 'FAST_COMPILE':
assert len(topo) == 2
assert isinstance(topo[0].op, opt.Shape_i)
assert isinstance(topo[1].op, opt.MakeVector)
x = matrix()
y = unbroadcast(x, 0)
f = theano.function([x], y.shape)
assert (f(numpy.zeros((2,5), dtype=config.floatX)) == [2,5]).all()
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
if theano.config.mode != 'FAST_COMPILE': if theano.config.mode != 'FAST_COMPILE':
assert len(topo) == 3 assert len(topo) == 3
...@@ -4183,6 +4193,18 @@ class test_broadcast(unittest.TestCase): ...@@ -4183,6 +4193,18 @@ class test_broadcast(unittest.TestCase):
assert isinstance(topo[1].op, opt.Shape_i) assert isinstance(topo[1].op, opt.Shape_i)
assert isinstance(topo[2].op, opt.MakeVector) assert isinstance(topo[2].op, opt.MakeVector)
x = row()
y = unbroadcast(x, 0)
f = theano.function([x], y.shape)
assert (f(numpy.zeros((1,5), dtype=config.floatX)) == [1,5]).all()
topo = f.maker.env.toposort()
if theano.config.mode != 'FAST_COMPILE':
assert len(topo) == 2
assert isinstance(topo[0].op, opt.Shape_i)
assert isinstance(topo[1].op, opt.MakeVector)
def test_mod(): def test_mod():
""" """
We add this test as not all language and C implementation give the same We add this test as not all language and C implementation give the same
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论