提交 9300a882 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

merge

...@@ -2705,6 +2705,8 @@ class Rebroadcast(Op): ...@@ -2705,6 +2705,8 @@ class Rebroadcast(Op):
broadcast_pattern[k] = str(int(v)) broadcast_pattern[k] = str(int(v))
return '%s{%s}' % (self.__class__.__name__, ','.join(broadcast_pattern)) return '%s{%s}' % (self.__class__.__name__, ','.join(broadcast_pattern))
def make_node(self, x): def make_node(self, x):
if x.ndim <= numpy.max(self.axis.keys()):
raise ValueError('Trying to rebroadcast inexistant dimension')
t = x.type.__class__(dtype = x.type.dtype, t = x.type.__class__(dtype = x.type.dtype,
broadcastable = [self.axis.get(i, b) broadcastable = [self.axis.get(i, b)
for i, b in enumerate(x.type.broadcastable)]) for i, b in enumerate(x.type.broadcastable)])
......
...@@ -27,7 +27,6 @@ utt.seed_rng() ...@@ -27,7 +27,6 @@ utt.seed_rng()
def inplace_func(inputs, outputs, mode=get_default_mode()): def inplace_func(inputs, outputs, mode=get_default_mode()):
return function(inputs, outputs, mode=mode, accept_inplace=True) return function(inputs, outputs, mode=mode, accept_inplace=True)
def eval_outputs(outputs): def eval_outputs(outputs):
variables = inplace_func([], outputs)() variables = inplace_func([], outputs)()
if len(variables) == 1: if len(variables) == 1:
...@@ -2611,48 +2610,55 @@ def test_autocast(): ...@@ -2611,48 +2610,55 @@ def test_autocast():
finally: finally:
ac.__exit__() ac.__exit__()
def test_unbroadcast_addbroadcast(): class test_broadcast(unittest.TestCase):
""" def test_broadcast_bigdim(self):
test that the unbroadcast fct don't insert not needed broadcast def f():
and fuse consecutive Rebroadcast op x = matrix()
""" addbroadcast(x,2)
self.failUnlessRaises(ValueError, f)
x=matrix()
assert unbroadcast(x,0) is x def test_unbroadcast_addbroadcast(self):
assert unbroadcast(x,1) is x """
assert unbroadcast(x,1,0) is x test that the unbroadcast fct don't insert not needed broadcast
assert unbroadcast(x,0,1) is x and fuse consecutive Rebroadcast op
"""
assert addbroadcast(x,0) is not x
assert addbroadcast(x,1) is not x x=matrix()
assert addbroadcast(x,1,0).owner.inputs[0] is x assert unbroadcast(x,0) is x
assert unbroadcast(x,1) is x
assert unbroadcast(addbroadcast(x,0),0) is x assert unbroadcast(x,1,0) is x
assert addbroadcast(unbroadcast(x,0),0) is not x assert unbroadcast(x,0,1) is x
x=row()
assert unbroadcast(x,0) is not x assert addbroadcast(x,0) is not x
assert unbroadcast(x,1) is x assert addbroadcast(x,1) is not x
assert unbroadcast(x,1,0) is not x assert addbroadcast(x,1,0).owner.inputs[0] is x
assert unbroadcast(x,0,1) is not x
assert unbroadcast(addbroadcast(x,0),0) is x
assert addbroadcast(x,0) is x assert addbroadcast(unbroadcast(x,0),0) is not x
assert addbroadcast(x,1).owner.inputs[0] is x x=row()
assert addbroadcast(x,1,0).owner.inputs[0] is x assert unbroadcast(x,0) is not x
assert addbroadcast(x,0,1).owner.inputs[0] is x assert unbroadcast(x,1) is x
assert unbroadcast(x,1,0) is not x
assert unbroadcast(addbroadcast(x,1),1) is x assert unbroadcast(x,0,1) is not x
assert addbroadcast(unbroadcast(x,1),1) is not x
assert addbroadcast(x,0) is x
#the first broadcast is remove the broadcast, so the second assert addbroadcast(x,1).owner.inputs[0] is x
#should not make one assert addbroadcast(x,1,0).owner.inputs[0] is x
assert unbroadcast(unbroadcast(x,0),0).owner.inputs[0] is x assert addbroadcast(x,0,1).owner.inputs[0] is x
#test that consecutive Rebroadcast op are fused assert unbroadcast(addbroadcast(x,1),1) is x
x=TensorType(dtype = 'float64', broadcastable = (True,True))() assert addbroadcast(unbroadcast(x,1),1) is not x
assert unbroadcast(unbroadcast(x,1),0).owner.inputs[0] is x
assert addbroadcast(unbroadcast(x,1),0).owner.inputs[0] is x #the first broadcast is remove the broadcast, so the second
assert addbroadcast(unbroadcast(x,0),0) is x #should not make one
assert unbroadcast(unbroadcast(x,0),0).owner.inputs[0] is x
#test that consecutive Rebroadcast op are fused
x=TensorType(dtype = 'float64', broadcastable = (True,True))()
assert unbroadcast(unbroadcast(x,1),0).owner.inputs[0] is x
assert addbroadcast(unbroadcast(x,1),0).owner.inputs[0] is x
assert addbroadcast(unbroadcast(x,0),0) is x
def test_mod(): def test_mod():
""" """
We add this test as not all language and C implementation give the same We add this test as not all language and C implementation give the same
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论