提交 9300a882 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

merge

......@@ -2705,6 +2705,8 @@ class Rebroadcast(Op):
broadcast_pattern[k] = str(int(v))
return '%s{%s}' % (self.__class__.__name__, ','.join(broadcast_pattern))
def make_node(self, x):
if x.ndim <= numpy.max(self.axis.keys()):
raise ValueError('Trying to rebroadcast inexistant dimension')
t = x.type.__class__(dtype = x.type.dtype,
broadcastable = [self.axis.get(i, b)
for i, b in enumerate(x.type.broadcastable)])
......
......@@ -27,7 +27,6 @@ utt.seed_rng()
def inplace_func(inputs, outputs, mode=get_default_mode()):
return function(inputs, outputs, mode=mode, accept_inplace=True)
def eval_outputs(outputs):
variables = inplace_func([], outputs)()
if len(variables) == 1:
......@@ -2611,48 +2610,55 @@ def test_autocast():
finally:
ac.__exit__()
def test_unbroadcast_addbroadcast():
"""
test that the unbroadcast fct don't insert not needed broadcast
and fuse consecutive Rebroadcast op
"""
x=matrix()
assert unbroadcast(x,0) is x
assert unbroadcast(x,1) is x
assert unbroadcast(x,1,0) is x
assert unbroadcast(x,0,1) is x
assert addbroadcast(x,0) is not x
assert addbroadcast(x,1) is not x
assert addbroadcast(x,1,0).owner.inputs[0] is x
assert unbroadcast(addbroadcast(x,0),0) is x
assert addbroadcast(unbroadcast(x,0),0) is not x
x=row()
assert unbroadcast(x,0) is not x
assert unbroadcast(x,1) is x
assert unbroadcast(x,1,0) is not x
assert unbroadcast(x,0,1) is not x
assert addbroadcast(x,0) is x
assert addbroadcast(x,1).owner.inputs[0] is x
assert addbroadcast(x,1,0).owner.inputs[0] is x
assert addbroadcast(x,0,1).owner.inputs[0] is x
assert unbroadcast(addbroadcast(x,1),1) is x
assert addbroadcast(unbroadcast(x,1),1) is not x
#the first broadcast is remove the broadcast, so the second
#should not make one
assert unbroadcast(unbroadcast(x,0),0).owner.inputs[0] is x
#test that consecutive Rebroadcast op are fused
x=TensorType(dtype = 'float64', broadcastable = (True,True))()
assert unbroadcast(unbroadcast(x,1),0).owner.inputs[0] is x
assert addbroadcast(unbroadcast(x,1),0).owner.inputs[0] is x
assert addbroadcast(unbroadcast(x,0),0) is x
class test_broadcast(unittest.TestCase):
def test_broadcast_bigdim(self):
def f():
x = matrix()
addbroadcast(x,2)
self.failUnlessRaises(ValueError, f)
def test_unbroadcast_addbroadcast(self):
"""
test that the unbroadcast fct don't insert not needed broadcast
and fuse consecutive Rebroadcast op
"""
x=matrix()
assert unbroadcast(x,0) is x
assert unbroadcast(x,1) is x
assert unbroadcast(x,1,0) is x
assert unbroadcast(x,0,1) is x
assert addbroadcast(x,0) is not x
assert addbroadcast(x,1) is not x
assert addbroadcast(x,1,0).owner.inputs[0] is x
assert unbroadcast(addbroadcast(x,0),0) is x
assert addbroadcast(unbroadcast(x,0),0) is not x
x=row()
assert unbroadcast(x,0) is not x
assert unbroadcast(x,1) is x
assert unbroadcast(x,1,0) is not x
assert unbroadcast(x,0,1) is not x
assert addbroadcast(x,0) is x
assert addbroadcast(x,1).owner.inputs[0] is x
assert addbroadcast(x,1,0).owner.inputs[0] is x
assert addbroadcast(x,0,1).owner.inputs[0] is x
assert unbroadcast(addbroadcast(x,1),1) is x
assert addbroadcast(unbroadcast(x,1),1) is not x
#the first broadcast is remove the broadcast, so the second
#should not make one
assert unbroadcast(unbroadcast(x,0),0).owner.inputs[0] is x
#test that consecutive Rebroadcast op are fused
x=TensorType(dtype = 'float64', broadcastable = (True,True))()
assert unbroadcast(unbroadcast(x,1),0).owner.inputs[0] is x
assert addbroadcast(unbroadcast(x,1),0).owner.inputs[0] is x
assert addbroadcast(unbroadcast(x,0),0) is x
def test_mod():
"""
We add this test as not all language and C implementation give the same
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论