提交 9300a882 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

merge

...@@ -2705,6 +2705,8 @@ class Rebroadcast(Op): ...@@ -2705,6 +2705,8 @@ class Rebroadcast(Op):
broadcast_pattern[k] = str(int(v)) broadcast_pattern[k] = str(int(v))
return '%s{%s}' % (self.__class__.__name__, ','.join(broadcast_pattern)) return '%s{%s}' % (self.__class__.__name__, ','.join(broadcast_pattern))
def make_node(self, x): def make_node(self, x):
if x.ndim <= numpy.max(self.axis.keys()):
raise ValueError('Trying to rebroadcast inexistant dimension')
t = x.type.__class__(dtype = x.type.dtype, t = x.type.__class__(dtype = x.type.dtype,
broadcastable = [self.axis.get(i, b) broadcastable = [self.axis.get(i, b)
for i, b in enumerate(x.type.broadcastable)]) for i, b in enumerate(x.type.broadcastable)])
......
...@@ -27,7 +27,6 @@ utt.seed_rng() ...@@ -27,7 +27,6 @@ utt.seed_rng()
def inplace_func(inputs, outputs, mode=get_default_mode()): def inplace_func(inputs, outputs, mode=get_default_mode()):
return function(inputs, outputs, mode=mode, accept_inplace=True) return function(inputs, outputs, mode=mode, accept_inplace=True)
def eval_outputs(outputs): def eval_outputs(outputs):
variables = inplace_func([], outputs)() variables = inplace_func([], outputs)()
if len(variables) == 1: if len(variables) == 1:
...@@ -2611,7 +2610,14 @@ def test_autocast(): ...@@ -2611,7 +2610,14 @@ def test_autocast():
finally: finally:
ac.__exit__() ac.__exit__()
def test_unbroadcast_addbroadcast(): class test_broadcast(unittest.TestCase):
def test_broadcast_bigdim(self):
def f():
x = matrix()
addbroadcast(x,2)
self.failUnlessRaises(ValueError, f)
def test_unbroadcast_addbroadcast(self):
""" """
test that the unbroadcast fct don't insert not needed broadcast test that the unbroadcast fct don't insert not needed broadcast
and fuse consecutive Rebroadcast op and fuse consecutive Rebroadcast op
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论