提交 174b2093 authored 作者: Frederic Bastien's avatar Frederic Bastien

Add infer_shape to Rebroadcast and test it.

上级 9810017d
......@@ -3289,6 +3289,8 @@ class Rebroadcast(Op):
gz, = grads
# restore the broadcasting pattern of the input
return Rebroadcast(*[(axis, x.type.broadcastable[axis]) for axis, value in self.axis.iteritems()])(gz),
def infer_shape(self, node, ishapes):
return ishapes
def addbroadcast(x, *axes):
"""
......
......@@ -4156,6 +4156,18 @@ class test_broadcast(unittest.TestCase):
assert addbroadcast(unbroadcast(x,1),0).owner.inputs[0] is x
assert addbroadcast(unbroadcast(x,0),0) is x
def test_infer_shape(self):
x = matrix()
y = addbroadcast(x,0)
f = theano.function([x], y.shape)
f(numpy.zeros((1,5)))
topo = f.maker.env.toposort()
if theano.config.mode != 'FAST_COMPILE':
assert len(topo) == 3
assert isinstance(topo[0].op, opt.Shape_i)
assert isinstance(topo[1].op, opt.Shape_i)
assert isinstance(topo[2].op, opt.MakeVector)
def test_mod():
"""
We add this test as not all language and C implementation give the same
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论